mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Compare commits
151 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bc803e01c7 | ||
|
|
eaa2460701 | ||
|
|
c7dbcc6483 | ||
|
|
ad8a5934e1 | ||
|
|
7078e6477e | ||
|
|
69475f5bf1 | ||
|
|
ddeeb9428c | ||
|
|
780c60630c | ||
|
|
40c37b1219 | ||
|
|
c14b09376a | ||
|
|
fbcf56b2ba | ||
|
|
2d369b32f9 | ||
|
|
d52c524fc2 | ||
|
|
c2b51fbe98 | ||
|
|
7f2ac589f9 | ||
|
|
dff3872897 | ||
|
|
4f4b92da7d | ||
|
|
18f171d885 | ||
|
|
c72f8acea1 | ||
|
|
abedbc726f | ||
|
|
3e8d389e3e | ||
|
|
8810f8a728 | ||
|
|
5de91b9d81 | ||
|
|
57bc2abf41 | ||
|
|
dd50514d17 | ||
|
|
ac4935bf79 | ||
|
|
c817862cf7 | ||
|
|
c3768aaa46 | ||
|
|
a85fcfe05f | ||
|
|
1890535d1b | ||
|
|
9bb52acc14 | ||
|
|
551fdf32c3 | ||
|
|
74008ce487 | ||
|
|
852481e14d | ||
|
|
25c8279f26 | ||
|
|
05c57b9c7b | ||
|
|
46cbae088e | ||
|
|
b824bbfce6 | ||
|
|
9ba4c3edca | ||
|
|
ed2eef1625 | ||
|
|
e9a641bde7 | ||
|
|
ae3965a2a7 | ||
|
|
700af1c96d | ||
|
|
66edc5af7b | ||
|
|
ed15f6808b | ||
|
|
dc37fd2ff6 | ||
|
|
f256660780 | ||
|
|
23b261de3f | ||
|
|
884e6bff5d | ||
|
|
220436244c | ||
|
|
c430cf481a | ||
|
|
9f8f27fbad | ||
|
|
e746829b5f | ||
|
|
a69b24a069 | ||
|
|
12567f55cd | ||
|
|
8090daca40 | ||
|
|
27ffd9fe3d | ||
|
|
ee5cec7530 | ||
|
|
589a90bfbc | ||
|
|
314a364f61 | ||
|
|
f770cd96c6 | ||
|
|
01df1c0cc4 | ||
|
|
334589af4e | ||
|
|
43ef635be3 | ||
|
|
47d61e2c02 | ||
|
|
8f6fc8daa1 | ||
|
|
01ebfc41f3 | ||
|
|
87163cff8b | ||
|
|
6d5f847edc | ||
|
|
afb8700a95 | ||
|
|
e60d18cfb3 | ||
|
|
92332eb96e | ||
|
|
d5263d442f | ||
|
|
7ad7cac0c2 | ||
|
|
06a9f51431 | ||
|
|
849bc24d20 | ||
|
|
423e6c229c | ||
|
|
9fc27403b2 | ||
|
|
2de9a51591 | ||
|
|
a8632b7329 | ||
|
|
9ff32fd4c0 | ||
|
|
a097c42579 | ||
|
|
68e0767404 | ||
|
|
e09966024c | ||
|
|
893c2fc08a | ||
|
|
2e9f7b5f91 | ||
|
|
7f8e05ccad | ||
|
|
c316c63dff | ||
|
|
683680e5c8 | ||
|
|
bf8088e225 | ||
|
|
5050971ac6 | ||
|
|
08c54dcf22 | ||
|
|
6a5f87d874 | ||
|
|
a876f2d3fb | ||
|
|
a75f5898e6 | ||
|
|
dbab72153f | ||
|
|
0d54609435 | ||
|
|
07aa000750 | ||
|
|
b5c60d7d62 | ||
|
|
defefd79c5 | ||
|
|
27834df444 | ||
|
|
5c020bed49 | ||
|
|
c775ec1255 | ||
|
|
7527436549 | ||
|
|
541539a144 | ||
|
|
74220bb52c | ||
|
|
8eb60baf3a | ||
|
|
4b47e8ecb0 | ||
|
|
76bac2c1c5 | ||
|
|
0fcdda7175 | ||
|
|
e4eb3e63e6 | ||
|
|
626d4b433a | ||
|
|
83c7e03d05 | ||
|
|
959561473c | ||
|
|
7209eb74cc | ||
|
|
53cc3583df | ||
|
|
82c2553f07 | ||
|
|
6f6f9b537f | ||
|
|
f407f5a686 | ||
|
|
6134619998 | ||
|
|
817a9268ff | ||
|
|
3beddf341e | ||
|
|
1892c82a60 | ||
|
|
3f339cda6f | ||
|
|
16ba1cec69 | ||
|
|
8bfa50e283 | ||
|
|
c4a11e5a5a | ||
|
|
3cc4939dd3 | ||
|
|
b5c7937f8d | ||
|
|
b5ff4e816f | ||
|
|
a7d302e196 | ||
|
|
45381b188c | ||
|
|
054fb3308c | ||
|
|
d42431d73a | ||
|
|
c639cb7d5d | ||
|
|
97e65bf93f | ||
|
|
36c8a4aee7 | ||
|
|
19340d82e6 | ||
|
|
058e442072 | ||
|
|
9577a9f38d | ||
|
|
786971d443 | ||
|
|
1e164b6ec3 | ||
|
|
41ecccb2a9 | ||
|
|
94441fa746 | ||
|
|
ccb0ef518a | ||
|
|
3032a47af4 | ||
|
|
1b75dbd4f2 | ||
|
|
dade23a414 | ||
|
|
313f3e8286 | ||
|
|
4dacc52bde | ||
|
|
b1dffe8d9a |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -4,4 +4,5 @@ wd14_tagger_model
|
||||
venv
|
||||
*.egg-info
|
||||
build
|
||||
.vscode
|
||||
.vscode
|
||||
wandb
|
||||
|
||||
101
README.md
101
README.md
@@ -26,10 +26,11 @@ The scripts are tested with PyTorch 1.12.1 and 1.13.0, Diffusers 0.10.2.
|
||||
|
||||
## Links to how-to-use documents
|
||||
|
||||
All documents are in Japanese currently.
|
||||
Most of the documents are written in Japanese.
|
||||
|
||||
* [Training guide - common](./train_README-ja.md) : data preparation, options etc...
|
||||
* [Dataset config](./config_README-ja.md)
|
||||
* [Training guide - common](./train_README-ja.md) : data preparation, options etc...
|
||||
* [Chinese version](./train_README-zh.md)
|
||||
* [Dataset config](./config_README-ja.md)
|
||||
* [DreamBooth training guide](./train_db_README-ja.md)
|
||||
* [Step by Step fine-tuning guide](./fine_tune_README_ja.md):
|
||||
* [training LoRA](./train_network_README-ja.md)
|
||||
@@ -127,31 +128,75 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
||||
|
||||
## Change History
|
||||
|
||||
- 1 Apr. 2023, 2023/4/1:
|
||||
- Fix an issue that `merge_lora.py` does not work with the latest version.
|
||||
- Fix an issue that `merge_lora.py` does not merge Conv2d3x3 weights.
|
||||
- 最新のバージョンで`merge_lora.py` が動作しない不具合を修正しました。
|
||||
- `merge_lora.py` で `no module found for LoRA weight: ...` と表示され Conv2d3x3 拡張の重みがマージされない不具合を修正しました。
|
||||
- 31 Mar. 2023, 2023/3/31:
|
||||
- Fix an issue that the VRAM usage temporarily increases when loading a model in `train_network.py`.
|
||||
- Fix an issue that an error occurs when loading a `.safetensors` model in `train_network.py`. [#354](https://github.com/kohya-ss/sd-scripts/issues/354)
|
||||
- `train_network.py` でモデル読み込み時にVRAM使用量が一時的に大きくなる不具合を修正しました。
|
||||
- `train_network.py` で `.safetensors` 形式のモデルを読み込むとエラーになる不具合を修正しました。[#354](https://github.com/kohya-ss/sd-scripts/issues/354)
|
||||
- 30 Mar. 2023, 2023/3/30:
|
||||
- Support [P+](https://prompt-plus.github.io/) training. Thank you jakaline-dev!
|
||||
- See [#327](https://github.com/kohya-ss/sd-scripts/pull/327) for details.
|
||||
- Use `train_textual_inversion_XTI.py` for training. The usage is almost the same as `train_textual_inversion.py`. However, sample image generation during training is not supported.
|
||||
- Use `gen_img_diffusers.py` for image generation (I think Web UI is not supported). Specify the embedding with `--XTI_embeddings` option.
|
||||
- Reduce RAM usage at startup in `train_network.py`. [#332](https://github.com/kohya-ss/sd-scripts/pull/332) Thank you guaneec!
|
||||
- Support pre-merge for LoRA in `gen_img_diffusers.py`. Specify `--network_merge` option. Note that the `--am` option of the prompt option is no longer available with this option.
|
||||
### 30 Apr. 2023, 2023/04/30
|
||||
|
||||
- [P+](https://prompt-plus.github.io/) の学習に対応しました。jakaline-dev氏に感謝します。
|
||||
- 詳細は [#327](https://github.com/kohya-ss/sd-scripts/pull/327) をご参照ください。
|
||||
- 学習には `train_textual_inversion_XTI.py` を使用します。使用法は `train_textual_inversion.py` とほぼ同じです。た
|
||||
だし学習中のサンプル生成には対応していません。
|
||||
- 画像生成には `gen_img_diffusers.py` を使用してください(Web UIは対応していないと思われます)。`--XTI_embeddings` オプションで学習したembeddingを指定してください。
|
||||
- `train_network.py` で起動時のRAM使用量を削減しました。[#332](https://github.com/kohya-ss/sd-scripts/pull/332) guaneec氏に感謝します。
|
||||
- `gen_img_diffusers.py` でLoRAの事前マージに対応しました。`--network_merge` オプションを指定してください。なおプロンプトオプションの `--am` は使用できなくなります。
|
||||
- Added Chinese translation of [DreamBooth guide](./train_db_README-zh.md) and [LoRA guide](./train_network_README-zh.md). [PR #459](https://github.com/kohya-ss/sd-scripts/pull/459) Thanks to tomj2ee!
|
||||
- Added [documentation](./gen_img_README-ja.md) for image generation script `gen_img_diffusers.py` (Japanese version only).
|
||||
- 中国語版の[DreamBoothガイド](./train_db_README-zh.md)と[LoRAガイド](./train_network_README-zh.md)が追加されました。 [PR #459](https://github.com/kohya-ss/sd-scripts/pull/459) tomj2ee氏に感謝します。
|
||||
- 画像生成スクリプト `gen_img_diffusers.py`の簡単な[ドキュメント](./gen_img_README-ja.md)を追加しました(日本語版のみ)。
|
||||
|
||||
### 26 Apr. 2023, 2023/04/26
|
||||
|
||||
- Added [Chinese translation](./train_README-zh.md) of training guide. [PR #445](https://github.com/kohya-ss/sd-scripts/pull/445) Thanks to tomj2ee!
|
||||
- `tag_images_by_wd14_tagger.py` can now get arguments from outside. [PR #453](https://github.com/kohya-ss/sd-scripts/pull/453) Thanks to mio2333!
|
||||
- 学習に関するドキュメントの[中国語版](./train_README-zh.md)が追加されました。 [PR #445](https://github.com/kohya-ss/sd-scripts/pull/445) tomj2ee氏に感謝します。
|
||||
- `tag_images_by_wd14_tagger.py`の引数を外部から取得できるようになりました。 [PR #453](https://github.com/kohya-ss/sd-scripts/pull/453) mio2333氏に感謝します。
|
||||
|
||||
### 25 Apr. 2023, 2023/04/25
|
||||
|
||||
- Please do not update for a while if you cannot revert the repository to the previous version when something goes wrong, because the model saving part has been changed.
|
||||
- Added `--save_every_n_steps` option to each training script. The model is saved every specified steps.
|
||||
- `--save_last_n_steps` option can be used to save only the specified number of models (old models will be deleted).
|
||||
- If you specify the `--save_state` option, the state will also be saved at the same time. You can specify the number of steps to keep the state with the `--save_last_n_steps_state` option (the same value as `--save_last_n_steps` is used if omitted).
|
||||
- You can use the epoch-based model saving and state saving options together.
|
||||
- Not tested in multi-GPU environment. Please report any bugs.
|
||||
- `--cache_latents_to_disk` option automatically enables `--cache_latents` option when specified. [#438](https://github.com/kohya-ss/sd-scripts/issues/438)
|
||||
- Fixed a bug in `gen_img_diffusers.py` where latents upscaler would fail with a batch size of 2 or more.
|
||||
|
||||
- モデル保存部分を変更していますので、何か不具合が起きた時にリポジトリを前のバージョンに戻せない場合には、しばらく更新を控えてください。
|
||||
- 各学習スクリプトに`--save_every_n_steps`オプションを追加しました。指定ステップごとにモデルを保存します。
|
||||
- `--save_last_n_steps`オプションに数値を指定すると、そのステップ数のモデルのみを保存します(古いモデルは削除されます)。
|
||||
- `--save_state`オプションを指定するとstateも同時に保存します。`--save_last_n_steps_state`オプションでstateを残すステップ数を指定できます(省略時は`--save_last_n_steps`と同じ値が使われます)。
|
||||
- エポックごとのモデル保存、state保存のオプションと共存できます。
|
||||
- マルチGPU環境でのテストを行っていないため、不具合等あればご報告ください。
|
||||
- `--cache_latents_to_disk`オプションが指定されたとき、`--cache_latents`オプションが自動的に有効になるようにしました。 [#438](https://github.com/kohya-ss/sd-scripts/issues/438)
|
||||
- `gen_img_diffusers.py`でlatents upscalerがバッチサイズ2以上でエラーとなる不具合を修正しました。
|
||||
|
||||
|
||||
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
|
||||
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
|
||||
|
||||
### Naming of LoRA
|
||||
|
||||
The LoRA supported by `train_network.py` has been named to avoid confusion. The documentation has been updated. The following are the names of LoRA types in this repository.
|
||||
|
||||
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers)
|
||||
|
||||
LoRA for Linear layers and Conv2d layers with 1x1 kernel
|
||||
|
||||
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers)
|
||||
|
||||
In addition to 1., LoRA for Conv2d layers with 3x3 kernel
|
||||
|
||||
LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network arg). LoRA-LierLa can be used with [our extension](https://github.com/kohya-ss/sd-webui-additional-networks) for AUTOMATIC1111's Web UI, or with the built-in LoRA feature of the Web UI.
|
||||
|
||||
To use LoRA-C3Liar with Web UI, please use our extension.
|
||||
|
||||
### LoRAの名称について
|
||||
|
||||
`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
|
||||
|
||||
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
|
||||
|
||||
Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
|
||||
|
||||
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
|
||||
|
||||
1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
|
||||
|
||||
LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additional-networks)、またはAUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。
|
||||
|
||||
LoRA-C3Liarを使いWeb UIで生成するには拡張を使用してください。
|
||||
|
||||
## Sample image generation during training
|
||||
A prompt file might look like this, for example
|
||||
@@ -197,5 +242,3 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
|
||||
`( )` や `[ ]` などの重みづけも動作します。
|
||||
|
||||
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
|
||||
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
|
||||
|
||||
90
fine_tune.py
90
fine_tune.py
@@ -21,7 +21,7 @@ from library.config_util import (
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
|
||||
|
||||
|
||||
def train(args):
|
||||
@@ -142,12 +142,14 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size)
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 学習を準備する:モデルを適切な状態にする
|
||||
training_models = []
|
||||
if args.gradient_checkpointing:
|
||||
@@ -231,9 +233,7 @@ def train(args):
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
# resumeする
|
||||
if args.resume is not None:
|
||||
print(f"resume training from state: {args.resume}")
|
||||
accelerator.load_state(args.resume)
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
@@ -260,7 +260,7 @@ def train(args):
|
||||
)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("finetuning")
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)
|
||||
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
@@ -275,7 +275,7 @@ def train(args):
|
||||
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype)
|
||||
else:
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
@@ -284,10 +284,20 @@ def train(args):
|
||||
|
||||
with torch.set_grad_enabled(args.train_text_encoder):
|
||||
# Get the text embedding for conditioning
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(
|
||||
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
|
||||
)
|
||||
if args.weighted_captions:
|
||||
encoder_hidden_states = get_weighted_text_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
batch["captions"],
|
||||
accelerator.device,
|
||||
args.max_token_length // 75 if args.max_token_length else 1,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
else:
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(
|
||||
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
|
||||
)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
@@ -304,7 +314,8 @@ def train(args):
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
with accelerator.autocast():
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
@@ -341,6 +352,27 @@ def train(args):
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
False,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
unwrap_model(text_encoder),
|
||||
unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
@@ -366,21 +398,23 @@ def train(args):
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_epoch_end(
|
||||
args,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
unwrap_model(text_encoder),
|
||||
unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
if accelerator.is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
True,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
unwrap_model(text_encoder),
|
||||
unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
@@ -391,7 +425,7 @@ def train(args):
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state:
|
||||
if args.save_state and is_main_process:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
@@ -4,6 +4,7 @@ import os
|
||||
import json
|
||||
import random
|
||||
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
@@ -13,156 +14,185 @@ from torchvision.transforms.functional import InterpolationMode
|
||||
from blip.blip import blip_decoder
|
||||
import library.train_util as train_util
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
IMAGE_SIZE = 384
|
||||
|
||||
# 正方形でいいのか? という気がするがソースがそうなので
|
||||
IMAGE_TRANSFORM = transforms.Compose([
|
||||
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||
])
|
||||
IMAGE_TRANSFORM = transforms.Compose(
|
||||
[
|
||||
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# 共通化したいが微妙に処理が異なる……
|
||||
class ImageLoadingTransformDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, image_paths):
|
||||
self.images = image_paths
|
||||
def __init__(self, image_paths):
|
||||
self.images = image_paths
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img_path = self.images[idx]
|
||||
def __getitem__(self, idx):
|
||||
img_path = self.images[idx]
|
||||
|
||||
try:
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
# convert to tensor temporarily so dataloader will accept it
|
||||
tensor = IMAGE_TRANSFORM(image)
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
return None
|
||||
try:
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
# convert to tensor temporarily so dataloader will accept it
|
||||
tensor = IMAGE_TRANSFORM(image)
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
return None
|
||||
|
||||
return (tensor, img_path)
|
||||
return (tensor, img_path)
|
||||
|
||||
|
||||
def collate_fn_remove_corrupted(batch):
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
return batch
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
return batch
|
||||
|
||||
|
||||
def main(args):
|
||||
# fix the seed for reproducibility
|
||||
seed = args.seed # + utils.get_rank()
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
# fix the seed for reproducibility
|
||||
seed = args.seed # + utils.get_rank()
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
if not os.path.exists("blip"):
|
||||
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
|
||||
if not os.path.exists("blip"):
|
||||
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
|
||||
|
||||
cwd = os.getcwd()
|
||||
print('Current Working Directory is: ', cwd)
|
||||
os.chdir('finetune')
|
||||
cwd = os.getcwd()
|
||||
print("Current Working Directory is: ", cwd)
|
||||
os.chdir("finetune")
|
||||
|
||||
print(f"load images from {args.train_data_dir}")
|
||||
image_paths = train_util.glob_images(args.train_data_dir)
|
||||
print(f"found {len(image_paths)} images.")
|
||||
print(f"load images from {args.train_data_dir}")
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
print(f"loading BLIP caption: {args.caption_weights}")
|
||||
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
|
||||
model.eval()
|
||||
model = model.to(DEVICE)
|
||||
print("BLIP loaded")
|
||||
print(f"loading BLIP caption: {args.caption_weights}")
|
||||
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
|
||||
model.eval()
|
||||
model = model.to(DEVICE)
|
||||
print("BLIP loaded")
|
||||
|
||||
# captioningする
|
||||
def run_batch(path_imgs):
|
||||
imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
|
||||
# captioningする
|
||||
def run_batch(path_imgs):
|
||||
imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
if args.beam_search:
|
||||
captions = model.generate(imgs, sample=False, num_beams=args.num_beams,
|
||||
max_length=args.max_length, min_length=args.min_length)
|
||||
else:
|
||||
captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)
|
||||
with torch.no_grad():
|
||||
if args.beam_search:
|
||||
captions = model.generate(
|
||||
imgs, sample=False, num_beams=args.num_beams, max_length=args.max_length, min_length=args.min_length
|
||||
)
|
||||
else:
|
||||
captions = model.generate(
|
||||
imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length
|
||||
)
|
||||
|
||||
for (image_path, _), caption in zip(path_imgs, captions):
|
||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
|
||||
f.write(caption + "\n")
|
||||
if args.debug:
|
||||
print(image_path, caption)
|
||||
for (image_path, _), caption in zip(path_imgs, captions):
|
||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
|
||||
f.write(caption + "\n")
|
||||
if args.debug:
|
||||
print(image_path, caption)
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = ImageLoadingTransformDataset(image_paths)
|
||||
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = ImageLoadingTransformDataset(image_paths)
|
||||
data = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
drop_last=False,
|
||||
)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
|
||||
b_imgs = []
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
for data in data_entry:
|
||||
if data is None:
|
||||
continue
|
||||
b_imgs = []
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
for data in data_entry:
|
||||
if data is None:
|
||||
continue
|
||||
|
||||
img_tensor, image_path = data
|
||||
if img_tensor is None:
|
||||
try:
|
||||
raw_image = Image.open(image_path)
|
||||
if raw_image.mode != 'RGB':
|
||||
raw_image = raw_image.convert("RGB")
|
||||
img_tensor = IMAGE_TRANSFORM(raw_image)
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
img_tensor, image_path = data
|
||||
if img_tensor is None:
|
||||
try:
|
||||
raw_image = Image.open(image_path)
|
||||
if raw_image.mode != "RGB":
|
||||
raw_image = raw_image.convert("RGB")
|
||||
img_tensor = IMAGE_TRANSFORM(raw_image)
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
|
||||
b_imgs.append((image_path, img_tensor))
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
b_imgs.append((image_path, img_tensor))
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
if len(b_imgs) > 0:
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
if len(b_imgs) > 0:
|
||||
run_batch(b_imgs)
|
||||
|
||||
print("done!")
|
||||
print("done!")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
|
||||
help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
|
||||
parser.add_argument("--caption_extention", type=str, default=None,
|
||||
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
||||
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||
parser.add_argument("--beam_search", action="store_true",
|
||||
help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
|
||||
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
|
||||
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
|
||||
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
|
||||
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
|
||||
parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument(
|
||||
"--caption_weights",
|
||||
type=str,
|
||||
default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
|
||||
help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_extention",
|
||||
type=str,
|
||||
default=None,
|
||||
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
|
||||
)
|
||||
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||
parser.add_argument(
|
||||
"--beam_search",
|
||||
action="store_true",
|
||||
help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)",
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--max_data_loader_n_workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
|
||||
)
|
||||
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
|
||||
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
|
||||
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
|
||||
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
|
||||
parser.add_argument("--seed", default=42, type=int, help="seed for reproducibility / 再現性を確保するための乱数seed")
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
|
||||
|
||||
return parser
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser.parse_args()
|
||||
|
||||
# スペルミスしていたオプションを復元する
|
||||
if args.caption_extention is not None:
|
||||
args.caption_extension = args.caption_extention
|
||||
# スペルミスしていたオプションを復元する
|
||||
if args.caption_extention is not None:
|
||||
args.caption_extension = args.caption_extention
|
||||
|
||||
main(args)
|
||||
main(args)
|
||||
|
||||
@@ -2,6 +2,7 @@ import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
@@ -11,141 +12,161 @@ from transformers.generation.utils import GenerationMixin
|
||||
import library.train_util as train_util
|
||||
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
PATTERN_REPLACE = [
|
||||
re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
|
||||
re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'),
|
||||
re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"),
|
||||
re.compile(r'with the number \d+ on (it|\w+ \w+)'),
|
||||
re.compile(r"with the number \d+ on (it|\w+ \w+)"),
|
||||
re.compile(r'with the words "'),
|
||||
re.compile(r'word \w+ on it'),
|
||||
re.compile(r'that says the word \w+ on it'),
|
||||
re.compile('that says\'the word "( on it)?'),
|
||||
re.compile(r"word \w+ on it"),
|
||||
re.compile(r"that says the word \w+ on it"),
|
||||
re.compile("that says'the word \"( on it)?"),
|
||||
]
|
||||
|
||||
# 誤検知しまくりの with the word xxxx を消す
|
||||
|
||||
|
||||
def remove_words(captions, debug):
|
||||
removed_caps = []
|
||||
for caption in captions:
|
||||
cap = caption
|
||||
for pat in PATTERN_REPLACE:
|
||||
cap = pat.sub("", cap)
|
||||
if debug and cap != caption:
|
||||
print(caption)
|
||||
print(cap)
|
||||
removed_caps.append(cap)
|
||||
return removed_caps
|
||||
removed_caps = []
|
||||
for caption in captions:
|
||||
cap = caption
|
||||
for pat in PATTERN_REPLACE:
|
||||
cap = pat.sub("", cap)
|
||||
if debug and cap != caption:
|
||||
print(caption)
|
||||
print(cap)
|
||||
removed_caps.append(cap)
|
||||
return removed_caps
|
||||
|
||||
|
||||
def collate_fn_remove_corrupted(batch):
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
return batch
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
return batch
|
||||
|
||||
|
||||
def main(args):
|
||||
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
|
||||
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
|
||||
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
|
||||
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
|
||||
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
|
||||
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
|
||||
|
||||
# input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す
|
||||
# ここより上で置き換えようとするとすごく大変
|
||||
def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
|
||||
input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
|
||||
if input_ids.size()[0] != curr_batch_size[0]:
|
||||
input_ids = input_ids.repeat(curr_batch_size[0], 1)
|
||||
return input_ids
|
||||
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
|
||||
# input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す
|
||||
# ここより上で置き換えようとするとすごく大変
|
||||
def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
|
||||
input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
|
||||
if input_ids.size()[0] != curr_batch_size[0]:
|
||||
input_ids = input_ids.repeat(curr_batch_size[0], 1)
|
||||
return input_ids
|
||||
|
||||
print(f"load images from {args.train_data_dir}")
|
||||
image_paths = train_util.glob_images(args.train_data_dir)
|
||||
print(f"found {len(image_paths)} images.")
|
||||
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
|
||||
|
||||
# できればcacheに依存せず明示的にダウンロードしたい
|
||||
print(f"loading GIT: {args.model_id}")
|
||||
git_processor = AutoProcessor.from_pretrained(args.model_id)
|
||||
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
|
||||
print("GIT loaded")
|
||||
print(f"load images from {args.train_data_dir}")
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
# captioningする
|
||||
def run_batch(path_imgs):
|
||||
imgs = [im for _, im in path_imgs]
|
||||
# できればcacheに依存せず明示的にダウンロードしたい
|
||||
print(f"loading GIT: {args.model_id}")
|
||||
git_processor = AutoProcessor.from_pretrained(args.model_id)
|
||||
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
|
||||
print("GIT loaded")
|
||||
|
||||
curr_batch_size[0] = len(path_imgs)
|
||||
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
|
||||
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
|
||||
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
# captioningする
|
||||
def run_batch(path_imgs):
|
||||
imgs = [im for _, im in path_imgs]
|
||||
|
||||
if args.remove_words:
|
||||
captions = remove_words(captions, args.debug)
|
||||
curr_batch_size[0] = len(path_imgs)
|
||||
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
|
||||
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
|
||||
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
for (image_path, _), caption in zip(path_imgs, captions):
|
||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
|
||||
f.write(caption + "\n")
|
||||
if args.debug:
|
||||
print(image_path, caption)
|
||||
if args.remove_words:
|
||||
captions = remove_words(captions, args.debug)
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = train_util.ImageLoadingDataset(image_paths)
|
||||
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
for (image_path, _), caption in zip(path_imgs, captions):
|
||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
|
||||
f.write(caption + "\n")
|
||||
if args.debug:
|
||||
print(image_path, caption)
|
||||
|
||||
b_imgs = []
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
for data in data_entry:
|
||||
if data is None:
|
||||
continue
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = train_util.ImageLoadingDataset(image_paths)
|
||||
data = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
drop_last=False,
|
||||
)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
|
||||
image, image_path = data
|
||||
if image is None:
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert("RGB")
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
b_imgs = []
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
for data in data_entry:
|
||||
if data is None:
|
||||
continue
|
||||
|
||||
b_imgs.append((image_path, image))
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
image, image_path = data
|
||||
if image is None:
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
|
||||
b_imgs.append((image_path, image))
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
|
||||
if len(b_imgs) > 0:
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
|
||||
if len(b_imgs) > 0:
|
||||
run_batch(b_imgs)
|
||||
|
||||
print("done!")
|
||||
print("done!")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||
parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps",
|
||||
help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
|
||||
parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
|
||||
parser.add_argument("--remove_words", action="store_true",
|
||||
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||
parser.add_argument(
|
||||
"--model_id",
|
||||
type=str,
|
||||
default="microsoft/git-large-textcaps",
|
||||
help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID",
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--max_data_loader_n_workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
|
||||
)
|
||||
parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
|
||||
parser.add_argument(
|
||||
"--remove_words",
|
||||
action="store_true",
|
||||
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する",
|
||||
)
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
|
||||
|
||||
return parser
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@@ -2,6 +2,8 @@ import argparse
|
||||
import os
|
||||
import json
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
@@ -12,7 +14,7 @@ from torchvision import transforms
|
||||
import library.model_util as model_util
|
||||
import library.train_util as train_util
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
IMAGE_TRANSFORMS = transforms.Compose(
|
||||
[
|
||||
@@ -23,245 +25,299 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
||||
|
||||
|
||||
def collate_fn_remove_corrupted(batch):
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
return batch
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
return batch
|
||||
|
||||
|
||||
def get_latents(vae, images, weight_dtype):
|
||||
img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
|
||||
img_tensors = torch.stack(img_tensors)
|
||||
img_tensors = img_tensors.to(DEVICE, weight_dtype)
|
||||
with torch.no_grad():
|
||||
latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
|
||||
return latents
|
||||
img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
|
||||
img_tensors = torch.stack(img_tensors)
|
||||
img_tensors = img_tensors.to(DEVICE, weight_dtype)
|
||||
with torch.no_grad():
|
||||
latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
|
||||
return latents
|
||||
|
||||
|
||||
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
|
||||
if is_full_path:
|
||||
base_name = os.path.splitext(os.path.basename(image_key))[0]
|
||||
else:
|
||||
base_name = image_key
|
||||
if flip:
|
||||
base_name += '_flip'
|
||||
return os.path.join(data_dir, base_name)
|
||||
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
|
||||
if is_full_path:
|
||||
base_name = os.path.splitext(os.path.basename(image_key))[0]
|
||||
relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
|
||||
else:
|
||||
base_name = image_key
|
||||
relative_path = ""
|
||||
|
||||
if flip:
|
||||
base_name += "_flip"
|
||||
|
||||
if recursive and relative_path:
|
||||
return os.path.join(data_dir, relative_path, base_name)
|
||||
else:
|
||||
return os.path.join(data_dir, base_name)
|
||||
|
||||
|
||||
def main(args):
|
||||
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
|
||||
if args.bucket_reso_steps % 8 > 0:
|
||||
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
|
||||
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
|
||||
if args.bucket_reso_steps % 8 > 0:
|
||||
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
|
||||
|
||||
image_paths = train_util.glob_images(args.train_data_dir)
|
||||
print(f"found {len(image_paths)} images.")
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
if os.path.exists(args.in_json):
|
||||
print(f"loading existing metadata: {args.in_json}")
|
||||
with open(args.in_json, "rt", encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
|
||||
return
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if args.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif args.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
|
||||
vae.eval()
|
||||
vae.to(DEVICE, dtype=weight_dtype)
|
||||
|
||||
# bucketのサイズを計算する
|
||||
max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
|
||||
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
|
||||
|
||||
bucket_manager = train_util.BucketManager(args.bucket_no_upscale, max_reso,
|
||||
args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps)
|
||||
if not args.bucket_no_upscale:
|
||||
bucket_manager.make_buckets()
|
||||
else:
|
||||
print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
|
||||
|
||||
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
|
||||
img_ar_errors = []
|
||||
|
||||
def process_batch(is_last):
|
||||
for bucket in bucket_manager.buckets:
|
||||
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
|
||||
latents = get_latents(vae, [img for _, img in bucket], weight_dtype)
|
||||
assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, \
|
||||
f"latent shape {latents.shape}, {bucket[0][1].shape}"
|
||||
|
||||
for (image_key, _), latent in zip(bucket, latents):
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
|
||||
np.savez(npz_file_name, latent)
|
||||
|
||||
# flip
|
||||
if args.flip_aug:
|
||||
latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
|
||||
|
||||
for (image_key, _), latent in zip(bucket, latents):
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
|
||||
np.savez(npz_file_name, latent)
|
||||
else:
|
||||
# remove existing flipped npz
|
||||
for image_key, _ in bucket:
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
|
||||
if os.path.isfile(npz_file_name):
|
||||
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
|
||||
os.remove(npz_file_name)
|
||||
|
||||
bucket.clear()
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = train_util.ImageLoadingDataset(image_paths)
|
||||
data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
|
||||
bucket_counts = {}
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
if data_entry[0] is None:
|
||||
continue
|
||||
|
||||
img_tensor, image_path = data_entry[0]
|
||||
if img_tensor is not None:
|
||||
image = transforms.functional.to_pil_image(img_tensor)
|
||||
if os.path.exists(args.in_json):
|
||||
print(f"loading existing metadata: {args.in_json}")
|
||||
with open(args.in_json, "rt", encoding="utf-8") as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert("RGB")
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
|
||||
return
|
||||
|
||||
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
|
||||
if image_key not in metadata:
|
||||
metadata[image_key] = {}
|
||||
weight_dtype = torch.float32
|
||||
if args.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif args.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
# 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変
|
||||
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
|
||||
vae.eval()
|
||||
vae.to(DEVICE, dtype=weight_dtype)
|
||||
|
||||
reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height)
|
||||
img_ar_errors.append(abs(ar_error))
|
||||
bucket_counts[reso] = bucket_counts.get(reso, 0) + 1
|
||||
|
||||
# メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
|
||||
metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
|
||||
# bucketのサイズを計算する
|
||||
max_reso = tuple([int(t) for t in args.max_resolution.split(",")])
|
||||
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
|
||||
|
||||
bucket_manager = train_util.BucketManager(
|
||||
args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps
|
||||
)
|
||||
if not args.bucket_no_upscale:
|
||||
# upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
|
||||
assert resized_size[0] == reso[0] or resized_size[1] == reso[
|
||||
1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
|
||||
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
|
||||
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
|
||||
bucket_manager.make_buckets()
|
||||
else:
|
||||
print(
|
||||
"min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます"
|
||||
)
|
||||
|
||||
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
|
||||
1], f"internal error resized size is small: {resized_size}, {reso}"
|
||||
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
|
||||
img_ar_errors = []
|
||||
|
||||
# 既に存在するファイルがあればshapeを確認して同じならskipする
|
||||
if args.skip_existing:
|
||||
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
|
||||
if args.flip_aug:
|
||||
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
|
||||
def process_batch(is_last):
|
||||
for bucket in bucket_manager.buckets:
|
||||
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
|
||||
latents = get_latents(vae, [img for _, img in bucket], weight_dtype)
|
||||
assert (
|
||||
latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8
|
||||
), f"latent shape {latents.shape}, {bucket[0][1].shape}"
|
||||
|
||||
found = True
|
||||
for npz_file in npz_files:
|
||||
if not os.path.exists(npz_file):
|
||||
found = False
|
||||
break
|
||||
for (image_key, _), latent in zip(bucket, latents):
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive)
|
||||
np.savez(npz_file_name, latent)
|
||||
|
||||
dat = np.load(npz_file)['arr_0']
|
||||
if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認
|
||||
found = False
|
||||
break
|
||||
if found:
|
||||
continue
|
||||
# flip
|
||||
if args.flip_aug:
|
||||
latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
|
||||
|
||||
# 画像をリサイズしてトリミングする
|
||||
# PILにinter_areaがないのでcv2で……
|
||||
image = np.array(image)
|
||||
if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要?
|
||||
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
|
||||
for (image_key, _), latent in zip(bucket, latents):
|
||||
npz_file_name = get_npz_filename_wo_ext(
|
||||
args.train_data_dir, image_key, args.full_path, True, args.recursive
|
||||
)
|
||||
np.savez(npz_file_name, latent)
|
||||
else:
|
||||
# remove existing flipped npz
|
||||
for image_key, _ in bucket:
|
||||
npz_file_name = (
|
||||
get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
|
||||
)
|
||||
if os.path.isfile(npz_file_name):
|
||||
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
|
||||
os.remove(npz_file_name)
|
||||
|
||||
if resized_size[0] > reso[0]:
|
||||
trim_size = resized_size[0] - reso[0]
|
||||
image = image[:, trim_size//2:trim_size//2 + reso[0]]
|
||||
bucket.clear()
|
||||
|
||||
if resized_size[1] > reso[1]:
|
||||
trim_size = resized_size[1] - reso[1]
|
||||
image = image[trim_size//2:trim_size//2 + reso[1]]
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = train_util.ImageLoadingDataset(image_paths)
|
||||
data = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
drop_last=False,
|
||||
)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
|
||||
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
||||
bucket_counts = {}
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
if data_entry[0] is None:
|
||||
continue
|
||||
|
||||
# # debug
|
||||
# cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
|
||||
img_tensor, image_path = data_entry[0]
|
||||
if img_tensor is not None:
|
||||
image = transforms.functional.to_pil_image(img_tensor)
|
||||
else:
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
|
||||
# バッチへ追加
|
||||
bucket_manager.add_image(reso, (image_key, image))
|
||||
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
|
||||
if image_key not in metadata:
|
||||
metadata[image_key] = {}
|
||||
|
||||
# バッチを推論するか判定して推論する
|
||||
process_batch(False)
|
||||
# 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変
|
||||
|
||||
# 残りを処理する
|
||||
process_batch(True)
|
||||
reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height)
|
||||
img_ar_errors.append(abs(ar_error))
|
||||
bucket_counts[reso] = bucket_counts.get(reso, 0) + 1
|
||||
|
||||
bucket_manager.sort()
|
||||
for i, reso in enumerate(bucket_manager.resos):
|
||||
count = bucket_counts.get(reso, 0)
|
||||
if count > 0:
|
||||
print(f"bucket {i} {reso}: {count}")
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
print(f"mean ar error: {np.mean(img_ar_errors)}")
|
||||
# メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
|
||||
metadata[image_key]["train_resolution"] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
|
||||
|
||||
# metadataを書き出して終わり
|
||||
print(f"writing metadata: {args.out_json}")
|
||||
with open(args.out_json, "wt", encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
print("done!")
|
||||
if not args.bucket_no_upscale:
|
||||
# upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
|
||||
assert (
|
||||
resized_size[0] == reso[0] or resized_size[1] == reso[1]
|
||||
), f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
|
||||
assert (
|
||||
resized_size[0] >= reso[0] and resized_size[1] >= reso[1]
|
||||
), f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
|
||||
|
||||
assert (
|
||||
resized_size[0] >= reso[0] and resized_size[1] >= reso[1]
|
||||
), f"internal error resized size is small: {resized_size}, {reso}"
|
||||
|
||||
# 既に存在するファイルがあればshapeを確認して同じならskipする
|
||||
if args.skip_existing:
|
||||
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"]
|
||||
if args.flip_aug:
|
||||
npz_files.append(
|
||||
get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
|
||||
)
|
||||
|
||||
found = True
|
||||
for npz_file in npz_files:
|
||||
if not os.path.exists(npz_file):
|
||||
found = False
|
||||
break
|
||||
|
||||
dat = np.load(npz_file)["arr_0"]
|
||||
if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認
|
||||
found = False
|
||||
break
|
||||
if found:
|
||||
continue
|
||||
|
||||
# 画像をリサイズしてトリミングする
|
||||
# PILにinter_areaがないのでcv2で……
|
||||
image = np.array(image)
|
||||
if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要?
|
||||
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
|
||||
|
||||
if resized_size[0] > reso[0]:
|
||||
trim_size = resized_size[0] - reso[0]
|
||||
image = image[:, trim_size // 2 : trim_size // 2 + reso[0]]
|
||||
|
||||
if resized_size[1] > reso[1]:
|
||||
trim_size = resized_size[1] - reso[1]
|
||||
image = image[trim_size // 2 : trim_size // 2 + reso[1]]
|
||||
|
||||
assert (
|
||||
image.shape[0] == reso[1] and image.shape[1] == reso[0]
|
||||
), f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
||||
|
||||
# # debug
|
||||
# cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
|
||||
|
||||
# バッチへ追加
|
||||
bucket_manager.add_image(reso, (image_key, image))
|
||||
|
||||
# バッチを推論するか判定して推論する
|
||||
process_batch(False)
|
||||
|
||||
# 残りを処理する
|
||||
process_batch(True)
|
||||
|
||||
bucket_manager.sort()
|
||||
for i, reso in enumerate(bucket_manager.resos):
|
||||
count = bucket_counts.get(reso, 0)
|
||||
if count > 0:
|
||||
print(f"bucket {i} {reso}: {count}")
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
print(f"mean ar error: {np.mean(img_ar_errors)}")
|
||||
|
||||
# metadataを書き出して終わり
|
||||
print(f"writing metadata: {args.out_json}")
|
||||
with open(args.out_json, "wt", encoding="utf-8") as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
print("done!")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)')
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
|
||||
parser.add_argument("--max_resolution", type=str, default="512,512",
|
||||
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
|
||||
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
||||
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
|
||||
parser.add_argument("--bucket_reso_steps", type=int, default=64,
|
||||
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
|
||||
parser.add_argument("--bucket_no_upscale", action="store_true",
|
||||
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
|
||||
parser.add_argument("--mixed_precision", type=str, default="no",
|
||||
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
||||
parser.add_argument("--full_path", action="store_true",
|
||||
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
|
||||
parser.add_argument("--flip_aug", action="store_true",
|
||||
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
|
||||
parser.add_argument("--skip_existing", action="store_true",
|
||||
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
|
||||
parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--max_data_loader_n_workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_resolution",
|
||||
type=str,
|
||||
default="512,512",
|
||||
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)",
|
||||
)
|
||||
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
||||
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
|
||||
parser.add_argument(
|
||||
"--bucket_reso_steps",
|
||||
type=int,
|
||||
default=64,
|
||||
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full_path",
|
||||
action="store_true",
|
||||
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_existing",
|
||||
action="store_true",
|
||||
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recursive",
|
||||
action="store_true",
|
||||
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す",
|
||||
)
|
||||
|
||||
return parser
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@@ -10,6 +10,7 @@ import numpy as np
|
||||
from tensorflow.keras.models import load_model
|
||||
from huggingface_hub import hf_hub_download
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
import library.train_util as train_util
|
||||
|
||||
@@ -17,7 +18,7 @@ import library.train_util as train_util
|
||||
IMAGE_SIZE = 448
|
||||
|
||||
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
|
||||
DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
|
||||
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
||||
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
|
||||
SUB_DIR = "variables"
|
||||
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
|
||||
@@ -25,182 +26,278 @@ CSV_FILE = FILES[-1]
|
||||
|
||||
|
||||
def preprocess_image(image):
|
||||
image = np.array(image)
|
||||
image = image[:, :, ::-1] # RGB->BGR
|
||||
image = np.array(image)
|
||||
image = image[:, :, ::-1] # RGB->BGR
|
||||
|
||||
# pad to square
|
||||
size = max(image.shape[0:2])
|
||||
pad_x = size - image.shape[1]
|
||||
pad_y = size - image.shape[0]
|
||||
pad_l = pad_x // 2
|
||||
pad_t = pad_y // 2
|
||||
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)
|
||||
# pad to square
|
||||
size = max(image.shape[0:2])
|
||||
pad_x = size - image.shape[1]
|
||||
pad_y = size - image.shape[0]
|
||||
pad_l = pad_x // 2
|
||||
pad_t = pad_y // 2
|
||||
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
|
||||
|
||||
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
|
||||
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
|
||||
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
|
||||
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
|
||||
|
||||
image = image.astype(np.float32)
|
||||
return image
|
||||
image = image.astype(np.float32)
|
||||
return image
|
||||
|
||||
|
||||
class ImageLoadingPrepDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, image_paths):
|
||||
self.images = image_paths
|
||||
def __init__(self, image_paths):
|
||||
self.images = image_paths
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img_path = self.images[idx]
|
||||
def __getitem__(self, idx):
|
||||
img_path = str(self.images[idx])
|
||||
|
||||
try:
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
image = preprocess_image(image)
|
||||
tensor = torch.tensor(image)
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
return None
|
||||
try:
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
image = preprocess_image(image)
|
||||
tensor = torch.tensor(image)
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
return None
|
||||
|
||||
return (tensor, img_path)
|
||||
return (tensor, img_path)
|
||||
|
||||
|
||||
def collate_fn_remove_corrupted(batch):
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
return batch
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
return batch
|
||||
|
||||
|
||||
def main(args):
|
||||
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
|
||||
# depreacatedの警告が出るけどなくなったらその時
|
||||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
||||
if not os.path.exists(args.model_dir) or args.force_download:
|
||||
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
||||
for file in FILES:
|
||||
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
|
||||
for file in SUB_DIR_FILES:
|
||||
hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
|
||||
args.model_dir, SUB_DIR), force_download=True, force_filename=file)
|
||||
else:
|
||||
print("using existing wd14 tagger model")
|
||||
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
|
||||
# depreacatedの警告が出るけどなくなったらその時
|
||||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
||||
if not os.path.exists(args.model_dir) or args.force_download:
|
||||
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
||||
for file in FILES:
|
||||
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
|
||||
for file in SUB_DIR_FILES:
|
||||
hf_hub_download(
|
||||
args.repo_id,
|
||||
file,
|
||||
subfolder=SUB_DIR,
|
||||
cache_dir=os.path.join(args.model_dir, SUB_DIR),
|
||||
force_download=True,
|
||||
force_filename=file,
|
||||
)
|
||||
else:
|
||||
print("using existing wd14 tagger model")
|
||||
|
||||
# 画像を読み込む
|
||||
image_paths = train_util.glob_images(args.train_data_dir)
|
||||
print(f"found {len(image_paths)} images.")
|
||||
# 画像を読み込む
|
||||
model = load_model(args.model_dir)
|
||||
|
||||
print("loading model and labels")
|
||||
model = load_model(args.model_dir)
|
||||
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
|
||||
# 依存ライブラリを増やしたくないので自力で読むよ
|
||||
|
||||
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
|
||||
# 依存ライブラリを増やしたくないので自力で読むよ
|
||||
with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
|
||||
reader = csv.reader(f)
|
||||
l = [row for row in reader]
|
||||
header = l[0] # tag_id,name,category,count
|
||||
rows = l[1:]
|
||||
assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"
|
||||
with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
|
||||
reader = csv.reader(f)
|
||||
l = [row for row in reader]
|
||||
header = l[0] # tag_id,name,category,count
|
||||
rows = l[1:]
|
||||
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
|
||||
|
||||
tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、つまり通常のタグのみ
|
||||
general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
|
||||
character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
|
||||
|
||||
# 推論する
|
||||
def run_batch(path_imgs):
|
||||
imgs = np.array([im for _, im in path_imgs])
|
||||
# 画像を読み込む
|
||||
|
||||
probs = model(imgs, training=False)
|
||||
probs = probs.numpy()
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
for (image_path, _), prob in zip(path_imgs, probs):
|
||||
# 最初の4つはratingなので無視する
|
||||
# # First 4 labels are actually ratings: pick one with argmax
|
||||
# ratings_names = label_names[:4]
|
||||
# rating_index = ratings_names["probs"].argmax()
|
||||
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
|
||||
tag_freq = {}
|
||||
|
||||
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
|
||||
# Everything else is tags: pick any where prediction confidence > threshold
|
||||
tag_text = ""
|
||||
for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
|
||||
if p >= args.thresh and i < len(tags):
|
||||
tag_text += ", " + tags[i]
|
||||
undesired_tags = set(args.undesired_tags.split(","))
|
||||
|
||||
if len(tag_text) > 0:
|
||||
tag_text = tag_text[2:] # 最初の ", " を消す
|
||||
def run_batch(path_imgs):
|
||||
imgs = np.array([im for _, im in path_imgs])
|
||||
|
||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
|
||||
f.write(tag_text + '\n')
|
||||
if args.debug:
|
||||
print(image_path, tag_text)
|
||||
probs = model(imgs, training=False)
|
||||
probs = probs.numpy()
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = ImageLoadingPrepDataset(image_paths)
|
||||
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
for (image_path, _), prob in zip(path_imgs, probs):
|
||||
# 最初の4つはratingなので無視する
|
||||
# # First 4 labels are actually ratings: pick one with argmax
|
||||
# ratings_names = label_names[:4]
|
||||
# rating_index = ratings_names["probs"].argmax()
|
||||
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
|
||||
|
||||
b_imgs = []
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
for data in data_entry:
|
||||
if data is None:
|
||||
continue
|
||||
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
|
||||
# Everything else is tags: pick any where prediction confidence > threshold
|
||||
combined_tags = []
|
||||
general_tag_text = ""
|
||||
character_tag_text = ""
|
||||
for i, p in enumerate(prob[4:]):
|
||||
if i < len(general_tags) and p >= args.general_threshold:
|
||||
tag_name = general_tags[i]
|
||||
if args.remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
|
||||
tag_name = tag_name.replace("_", " ")
|
||||
|
||||
image, image_path = data
|
||||
if image is not None:
|
||||
image = image.detach().numpy()
|
||||
else:
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert("RGB")
|
||||
image = preprocess_image(image)
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
b_imgs.append((image_path, image))
|
||||
if tag_name not in undesired_tags:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
general_tag_text += ", " + tag_name
|
||||
combined_tags.append(tag_name)
|
||||
elif i >= len(general_tags) and p >= args.character_threshold:
|
||||
tag_name = character_tags[i - len(general_tags)]
|
||||
if args.remove_underscore and len(tag_name) > 3:
|
||||
tag_name = tag_name.replace("_", " ")
|
||||
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
if tag_name not in undesired_tags:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
character_tag_text += ", " + tag_name
|
||||
combined_tags.append(tag_name)
|
||||
|
||||
# 先頭のカンマを取る
|
||||
if len(general_tag_text) > 0:
|
||||
general_tag_text = general_tag_text[2:]
|
||||
if len(character_tag_text) > 0:
|
||||
character_tag_text = character_tag_text[2:]
|
||||
|
||||
tag_text = ", ".join(combined_tags)
|
||||
|
||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
|
||||
f.write(tag_text + "\n")
|
||||
if args.debug:
|
||||
print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = ImageLoadingPrepDataset(image_paths)
|
||||
data = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
drop_last=False,
|
||||
)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
|
||||
b_imgs = []
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
for data in data_entry:
|
||||
if data is None:
|
||||
continue
|
||||
|
||||
image, image_path = data
|
||||
if image is not None:
|
||||
image = image.detach().numpy()
|
||||
else:
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
image = preprocess_image(image)
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
b_imgs.append((image_path, image))
|
||||
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
|
||||
if len(b_imgs) > 0:
|
||||
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
|
||||
if len(b_imgs) > 0:
|
||||
run_batch(b_imgs)
|
||||
if args.frequency_tags:
|
||||
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
|
||||
print("\nTag frequencies:")
|
||||
for tag, freq in sorted_tags:
|
||||
print(f"{tag}: {freq}")
|
||||
|
||||
print("done!")
|
||||
print("done!")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
|
||||
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
|
||||
parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
|
||||
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
|
||||
parser.add_argument("--force_download", action='store_true',
|
||||
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
|
||||
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
|
||||
parser.add_argument("--caption_extention", type=str, default=None,
|
||||
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
||||
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument(
|
||||
"--repo_id",
|
||||
type=str,
|
||||
default=DEFAULT_WD14_TAGGER_REPO,
|
||||
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_dir",
|
||||
type=str,
|
||||
default="wd14_tagger_model",
|
||||
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force_download", action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします"
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--max_data_loader_n_workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_extention",
|
||||
type=str,
|
||||
default=None,
|
||||
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
|
||||
)
|
||||
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
|
||||
parser.add_argument(
|
||||
"--general_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for general category, same as --thresh if omitted / generalカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--character_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ",
|
||||
)
|
||||
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
|
||||
parser.add_argument(
|
||||
"--remove_underscore",
|
||||
action="store_true",
|
||||
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
|
||||
)
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
parser.add_argument(
|
||||
"--undesired_tags",
|
||||
type=str,
|
||||
default="",
|
||||
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
|
||||
)
|
||||
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
|
||||
|
||||
return parser
|
||||
return parser
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
# スペルミスしていたオプションを復元する
|
||||
if args.caption_extention is not None:
|
||||
args.caption_extension = args.caption_extention
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.general_threshold is None:
|
||||
args.general_threshold = args.thresh
|
||||
if args.character_threshold is None:
|
||||
args.character_threshold = args.thresh
|
||||
|
||||
# スペルミスしていたオプションを復元する
|
||||
if args.caption_extention is not None:
|
||||
args.caption_extension = args.caption_extention
|
||||
|
||||
main(args)
|
||||
main(args)
|
||||
|
||||
452
gen_img_README-ja.md
Normal file
452
gen_img_README-ja.md
Normal file
@@ -0,0 +1,452 @@
|
||||
SD 1.xおよび2.xのモデル、当リポジトリで学習したLoRA、ControlNet(v1.0のみ動作確認)などに対応した、Diffusersベースの推論(画像生成)スクリプトです。コマンドラインから用います。
|
||||
|
||||
# 概要
|
||||
|
||||
* Diffusers (v0.10.2) ベースの推論(画像生成)スクリプト。
|
||||
* SD 1.xおよび2.x (base/v-parameterization)モデルに対応。
|
||||
* txt2img、img2img、inpaintingに対応。
|
||||
* 対話モード、およびファイルからのプロンプト読み込み、連続生成に対応。
|
||||
* プロンプト1行あたりの生成枚数を指定可能。
|
||||
* 全体の繰り返し回数を指定可能。
|
||||
* `fp16`だけでなく`bf16`にも対応。
|
||||
* xformersに対応し高速生成が可能。
|
||||
* xformersにより省メモリ生成を行いますが、Automatic 1111氏のWeb UIほど最適化していないため、512*512の画像生成でおおむね6GB程度のVRAMを使用します。
|
||||
* プロンプトの225トークンへの拡張。ネガティブプロンプト、重みづけに対応。
|
||||
* Diffusersの各種samplerに対応(Web UIよりもsampler数は少ないです)。
|
||||
* Text Encoderのclip skip(最後からn番目の層の出力を用いる)に対応。
|
||||
* VAEの別途読み込み。
|
||||
* CLIP Guided Stable Diffusion、VGG16 Guided Stable Diffusion、Highres. fix、upscale対応。
|
||||
* Highres. fixはWeb UIの実装を全く確認していない独自実装のため、出力結果は異なるかもしれません。
|
||||
* LoRA対応。適用率指定、複数LoRA同時利用、重みのマージに対応。
|
||||
* Text EncoderとU-Netで別の適用率を指定することはできません。
|
||||
* Attention Coupleに対応。
|
||||
* ControlNet v1.0に対応。
|
||||
* 途中でモデルを切り替えることはできませんが、バッチファイルを組むことで対応できます。
|
||||
* 個人的に欲しくなった機能をいろいろ追加。
|
||||
|
||||
機能追加時にすべてのテストを行っているわけではないため、以前の機能に影響が出て一部機能が動かない可能性があります。何か問題があればお知らせください。
|
||||
|
||||
# 基本的な使い方
|
||||
|
||||
## 対話モードでの画像生成
|
||||
|
||||
以下のように入力してください。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先> --xformers --fp16 --interactive
|
||||
```
|
||||
|
||||
`--ckpt`オプションにモデル(Stable Diffusionのcheckpointファイル、またはDiffusersのモデルフォルダ)、`--outdir`オプションに画像の出力先フォルダを指定します。
|
||||
|
||||
`--xformers`オプションでxformersの使用を指定します(xformersを使わない場合は外してください)。`--fp16`オプションでfp16(単精度)での推論を行います。RTX 30系のGPUでは `--bf16`オプションでbf16(bfloat16)での推論を行うこともできます。
|
||||
|
||||
`--interactive`オプションで対話モードを指定しています。
|
||||
|
||||
Stable Diffusion 2.0(またはそこからの追加学習モデル)を使う場合は`--v2`オプションを追加してください。v-parameterizationを使うモデル(`768-v-ema.ckpt`およびそこからの追加学習モデル)を使う場合はさらに`--v_parameterization`を追加してください。
|
||||
|
||||
`--v2`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。
|
||||
|
||||
`Type prompt:`と表示されたらプロンプトを入力してください。
|
||||
|
||||

|
||||
|
||||
※画像が表示されずエラーになる場合、headless(画面表示機能なし)のOpenCVがインストールされているかもしれません。`pip install opencv-python`として通常のOpenCVを入れてください。または`--no_preview`オプションで画像表示を止めてください。
|
||||
|
||||
画像ウィンドウを選択してから何らかのキーを押すとウィンドウが閉じ、次のプロンプトが入力できます。プロンプトでCtrl+Z、エンターの順に打鍵するとスクリプトを閉じます。
|
||||
|
||||
## 単一のプロンプトで画像を一括生成
|
||||
|
||||
以下のように入力します(実際には1行で入力します)。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
--xformers --fp16 --images_per_prompt <生成枚数> --prompt "<プロンプト>"
|
||||
```
|
||||
|
||||
`--images_per_prompt`オプションで、プロンプト1件当たりの生成枚数を指定します。`--prompt`オプションでプロンプトを指定します。スペースを含む場合はダブルクォーテーションで囲んでください。
|
||||
|
||||
`--batch_size`オプションでバッチサイズを指定できます(後述)。
|
||||
|
||||
## ファイルからプロンプトを読み込み一括生成
|
||||
|
||||
以下のように入力します。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
--xformers --fp16 --from_file <プロンプトファイル名>
|
||||
```
|
||||
|
||||
`--from_file`オプションで、プロンプトが記述されたファイルを指定します。1行1プロンプトで記述してください。`--images_per_prompt`オプションを指定して1行あたり生成枚数を指定できます。
|
||||
|
||||
## ネガティブプロンプト、重みづけの使用
|
||||
|
||||
プロンプトオプション(プロンプト内で`--x`のように指定、後述)で`--n`を書くと、以降がネガティブプロンプトとなります。
|
||||
|
||||
またAUTOMATIC1111氏のWeb UIと同様の `()` や` []` 、`(xxx:1.3)` などによる重みづけが可能です(実装はDiffusersの[Long Prompt Weighting Stable Diffusion](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#long-prompt-weighting-stable-diffusion)からコピーしたものです)。
|
||||
|
||||
コマンドラインからのプロンプト指定、ファイルからのプロンプト読み込みでも同様に指定できます。
|
||||
|
||||

|
||||
|
||||
# 主なオプション
|
||||
|
||||
コマンドラインから指定してください。
|
||||
|
||||
## モデルの指定
|
||||
|
||||
- `--ckpt <モデル名>`:モデル名を指定します。`--ckpt`オプションは必須です。Stable Diffusionのcheckpointファイル、またはDiffusersのモデルフォルダ、Hugging FaceのモデルIDを指定できます。
|
||||
|
||||
- `--v2`:Stable Diffusion 2.x系のモデルを使う場合に指定します。1.x系の場合には指定不要です。
|
||||
|
||||
- `--v_parameterization`:v-parameterizationを使うモデルを使う場合に指定します(`768-v-ema.ckpt`およびそこからの追加学習モデル、Waifu Diffusion v1.5など)。
|
||||
|
||||
`--v2`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。
|
||||
|
||||
- `--vae`:使用するVAEを指定します。未指定時はモデル内のVAEを使用します。
|
||||
|
||||
## 画像生成と出力
|
||||
|
||||
- `--interactive`:インタラクティブモードで動作します。プロンプトを入力すると画像が生成されます。
|
||||
|
||||
- `--prompt <プロンプト>`:プロンプトを指定します。スペースを含む場合はダブルクォーテーションで囲んでください。
|
||||
|
||||
- `--from_file <プロンプトファイル名>`:プロンプトが記述されたファイルを指定します。1行1プロンプトで記述してください。なお画像サイズやguidance scaleはプロンプトオプション(後述)で指定できます。
|
||||
|
||||
- `--W <画像幅>`:画像の幅を指定します。デフォルトは`512`です。
|
||||
|
||||
- `--H <画像高さ>`:画像の高さを指定します。デフォルトは`512`です。
|
||||
|
||||
- `--steps <ステップ数>`:サンプリングステップ数を指定します。デフォルトは`50`です。
|
||||
|
||||
- `--scale <ガイダンススケール>`:unconditionalガイダンススケールを指定します。デフォルトは`7.5`です。
|
||||
|
||||
- `--sampler <サンプラー名>`:サンプラーを指定します。デフォルトは`ddim`です。Diffusersで提供されているddim、pndm、dpmsolver、dpmsolver+++、lms、euler、euler_a、が指定可能です(後ろの三つはk_lms、k_euler、k_euler_aでも指定できます)。
|
||||
|
||||
- `--outdir <画像出力先フォルダ>`:画像の出力先を指定します。
|
||||
|
||||
- `--images_per_prompt <生成枚数>`:プロンプト1件当たりの生成枚数を指定します。デフォルトは`1`です。
|
||||
|
||||
- `--clip_skip <スキップ数>`:CLIPの後ろから何番目の層を使うかを指定します。省略時は最後の層を使います。
|
||||
|
||||
- `--max_embeddings_multiples <倍数>`:CLIPの入出力長をデフォルト(75)の何倍にするかを指定します。未指定時は75のままです。たとえば3を指定すると入出力長が225になります。
|
||||
|
||||
- `--negative_scale` : uncoditioningのguidance scaleを個別に指定します。[gcem156氏のこちらの記事](https://note.com/gcem156/n/ne9a53e4a6f43)を参考に実装したものです。
|
||||
|
||||
## メモリ使用量や生成速度の調整
|
||||
|
||||
- `--batch_size <バッチサイズ>`:バッチサイズを指定します。デフォルトは`1`です。バッチサイズが大きいとメモリを多く消費しますが、生成速度が速くなります。
|
||||
|
||||
- `--vae_batch_size <VAEのバッチサイズ>`:VAEのバッチサイズを指定します。デフォルトはバッチサイズと同じです。
|
||||
VAEのほうがメモリを多く消費するため、デノイジング後(stepが100%になった後)でメモリ不足になる場合があります。このような場合にはVAEのバッチサイズを小さくしてください。
|
||||
|
||||
- `--xformers`:xformersを使う場合に指定します。
|
||||
|
||||
- `--fp16`:fp16(単精度)での推論を行います。`fp16`と`bf16`をどちらも指定しない場合はfp32(単精度)での推論を行います。
|
||||
|
||||
- `--bf16`:bf16(bfloat16)での推論を行います。RTX 30系のGPUでのみ指定可能です。`--bf16`オプションはRTX 30系以外のGPUではエラーになります。`fp16`よりも`bf16`のほうが推論結果がNaNになる(真っ黒の画像になる)可能性が低いようです。
|
||||
|
||||
## 追加ネットワーク(LoRA等)の使用
|
||||
|
||||
- `--network_module`:使用する追加ネットワークを指定します。LoRAの場合は`--network_module networks.lora`と指定します。複数のLoRAを使用する場合は`--network_module networks.lora networks.lora networks.lora`のように指定します。
|
||||
|
||||
- `--network_weights`:使用する追加ネットワークの重みファイルを指定します。`--network_weights model.safetensors`のように指定します。複数のLoRAを使用する場合は`--network_weights model1.safetensors model2.safetensors model3.safetensors`のように指定します。引数の数は`--network_module`で指定した数と同じにしてください。
|
||||
|
||||
- `--network_mul`:使用する追加ネットワークの重みを何倍にするかを指定します。デフォルトは`1`です。`--network_mul 0.8`のように指定します。複数のLoRAを使用する場合は`--network_mul 0.4 0.5 0.7`のように指定します。引数の数は`--network_module`で指定した数と同じにしてください。
|
||||
|
||||
- `--network_merge`:使用する追加ネットワークの重みを`--network_mul`に指定した重みであらかじめマージします。プロンプトオプションの`--am`は使用できなくなりますが、LoRA未使用時と同じ程度まで生成が高速化されます。
|
||||
|
||||
# 主なオプションの指定例
|
||||
|
||||
次は同一プロンプトで64枚をバッチサイズ4で一括生成する例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
|
||||
--xformers --fp16 --W 512 --H 704 --scale 12.5 --sampler k_euler_a
|
||||
--steps 32 --batch_size 4 --images_per_prompt 64
|
||||
--prompt "beautiful flowers --n monochrome"
|
||||
```
|
||||
|
||||
次はファイルに書かれたプロンプトを、それぞれ10枚ずつ、バッチサイズ4で一括生成する例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
|
||||
--xformers --fp16 --W 512 --H 704 --scale 12.5 --sampler k_euler_a
|
||||
--steps 32 --batch_size 4 --images_per_prompt 10
|
||||
--from_file prompts.txt
|
||||
```
|
||||
|
||||
Textual Inversion(後述)およびLoRAの使用例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt model.safetensors
|
||||
--scale 8 --steps 48 --outdir txt2img --xformers
|
||||
--W 512 --H 768 --fp16 --sampler k_euler_a
|
||||
--textual_inversion_embeddings goodembed.safetensors negprompt.pt
|
||||
--network_module networks.lora networks.lora
|
||||
--network_weights model1.safetensors model2.safetensors
|
||||
--network_mul 0.4 0.8
|
||||
--clip_skip 2 --max_embeddings_multiples 1
|
||||
--batch_size 8 --images_per_prompt 1 --interactive
|
||||
```
|
||||
|
||||
# プロンプトオプション
|
||||
|
||||
プロンプト内で、`--n`のように「ハイフンふたつ+アルファベットn文字」でプロンプトから各種オプションの指定が可能です。対話モード、コマンドライン、ファイル、いずれからプロンプトを指定する場合でも有効です。
|
||||
|
||||
プロンプトのオプション指定`--n`の前後にはスペースを入れてください。
|
||||
|
||||
- `--n`:ネガティブプロンプトを指定します。
|
||||
|
||||
- `--w`:画像幅を指定します。コマンドラインからの指定を上書きします。
|
||||
|
||||
- `--h`:画像高さを指定します。コマンドラインからの指定を上書きします。
|
||||
|
||||
- `--s`:ステップ数を指定します。コマンドラインからの指定を上書きします。
|
||||
|
||||
- `--d`:この画像の乱数seedを指定します。`--images_per_prompt`を指定している場合は「--d 1,2,3,4」のようにカンマ区切りで複数指定してください。
|
||||
※様々な理由により、Web UIとは同じ乱数seedでも生成される画像が異なる場合があります。
|
||||
|
||||
- `--l`:guidance scaleを指定します。コマンドラインからの指定を上書きします。
|
||||
|
||||
- `--t`:img2img(後述)のstrengthを指定します。コマンドラインからの指定を上書きします。
|
||||
|
||||
- `--nl`:ネガティブプロンプトのguidance scaleを指定します(後述)。コマンドラインからの指定を上書きします。
|
||||
|
||||
- `--am`:追加ネットワークの重みを指定します。コマンドラインからの指定を上書きします。複数の追加ネットワークを使用する場合は`--am 0.8,0.5,0.3`のように __カンマ区切りで__ 指定します。
|
||||
|
||||
※これらのオプションを指定すると、バッチサイズよりも小さいサイズでバッチが実行される場合があります(これらの値が異なると一括生成できないため)。(あまり気にしなくて大丈夫ですが、ファイルからプロンプトを読み込み生成する場合は、これらの値が同一のプロンプトを並べておくと効率が良くなります。)
|
||||
|
||||
例:
|
||||
```
|
||||
(masterpiece, best quality), 1girl, in shirt and plated skirt, standing at street under cherry blossoms, upper body, [from below], kind smile, looking at another, [goodembed] --n realistic, real life, (negprompt), (lowres:1.1), (worst quality:1.2), (low quality:1.1), bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, normal quality, jpeg artifacts, signature, watermark, username, blurry --w 960 --h 640 --s 28 --d 1
|
||||
```
|
||||
|
||||

|
||||
|
||||
# img2img
|
||||
|
||||
## オプション
|
||||
|
||||
- `--image_path`:img2imgに利用する画像を指定します。`--image_path template.png`のように指定します。フォルダを指定すると、そのフォルダの画像を順次利用します。
|
||||
|
||||
- `--strength`:img2imgのstrengthを指定します。`--strength 0.8`のように指定します。デフォルトは`0.8`です。
|
||||
|
||||
- `--sequential_file_name`:ファイル名を連番にするかどうかを指定します。指定すると生成されるファイル名が`im_000001.png`からの連番になります。
|
||||
|
||||
- `--use_original_file_name`:指定すると生成ファイル名がオリジナルのファイル名と同じになります。
|
||||
|
||||
## コマンドラインからの実行例
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
||||
--outdir outputs --xformers --fp16 --scale 12.5 --sampler k_euler --steps 32
|
||||
--image_path template.png --strength 0.8
|
||||
--prompt "1girl, cowboy shot, brown hair, pony tail, brown eyes,
|
||||
sailor school uniform, outdoors
|
||||
--n lowres, bad anatomy, bad hands, error, missing fingers, cropped,
|
||||
worst quality, low quality, normal quality, jpeg artifacts, (blurry),
|
||||
hair ornament, glasses"
|
||||
--batch_size 8 --images_per_prompt 32
|
||||
```
|
||||
|
||||
`--image_path`オプションにフォルダを指定すると、そのフォルダの画像を順次読み込みます。生成される枚数は画像枚数ではなく、プロンプト数になりますので、`--images_per_promptPPオプションを指定してimg2imgする画像の枚数とプロンプト数を合わせてください。
|
||||
|
||||
ファイルはファイル名でソートして読み込みます。なおソート順は文字列順となりますので(`1.jpg→2.jpg→10.jpg`ではなく`1.jpg→10.jpg→2.jpg`の順)、頭を0埋めするなどしてご対応ください(`01.jpg→02.jpg→10.jpg`)。
|
||||
|
||||
## img2imgを利用したupscale
|
||||
|
||||
img2img時にコマンドラインオプションの`--W`と`--H`で生成画像サイズを指定すると、元画像をそのサイズにリサイズしてからimg2imgを行います。
|
||||
|
||||
またimg2imgの元画像がこのスクリプトで生成した画像の場合、プロンプトを省略すると、元画像のメタデータからプロンプトを取得しそのまま用います。これによりHighres. fixの2nd stageの動作だけを行うことができます。
|
||||
|
||||
## img2img時のinpainting
|
||||
|
||||
画像およびマスク画像を指定してinpaintingできます(inpaintingモデルには対応しておらず、単にマスク領域を対象にimg2imgするだけです)。
|
||||
|
||||
オプションは以下の通りです。
|
||||
|
||||
- `--mask_image`:マスク画像を指定します。`--img_path`と同様にフォルダを指定すると、そのフォルダの画像を順次利用します。
|
||||
|
||||
マスク画像はグレースケール画像で、白の部分がinpaintingされます。境界をグラデーションしておくとなんとなく滑らかになりますのでお勧めです。
|
||||
|
||||

|
||||
|
||||
# その他の機能
|
||||
|
||||
## Textual Inversion
|
||||
|
||||
`--textual_inversion_embeddings`オプションで使用するembeddingsを指定します(複数指定可)。拡張子を除いたファイル名をプロンプト内で使用することで、そのembeddingsを利用します(Web UIと同様の使用法です)。ネガティブプロンプト内でも使用できます。
|
||||
|
||||
モデルとして、当リポジトリで学習したTextual Inversionモデル、およびWeb UIで学習したTextual Inversionモデル(画像埋め込みは非対応)を利用できます
|
||||
|
||||
## Extended Textual Inversion
|
||||
|
||||
`--textual_inversion_embeddings`の代わりに`--XTI_embeddings`オプションを指定してください。使用法は`--textual_inversion_embeddings`と同じです。
|
||||
|
||||
## Highres. fix
|
||||
|
||||
AUTOMATIC1111氏のWeb UIにある機能の類似機能です(独自実装のためもしかしたらいろいろ異なるかもしれません)。最初に小さめの画像を生成し、その画像を元にimg2imgすることで、画像全体の破綻を防ぎつつ大きな解像度の画像を生成します。
|
||||
|
||||
2nd stageのstep数は`--steps` と`--strength`オプションの値から計算されます(`steps*strength`)。
|
||||
|
||||
img2imgと併用できません。
|
||||
|
||||
以下のオプションがあります。
|
||||
|
||||
- `--highres_fix_scale`:Highres. fixを有効にして、1st stageで生成する画像のサイズを、倍率で指定します。最終出力が1024x1024で、最初に512x512の画像を生成する場合は`--highres_fix_scale 0.5`のように指定します。Web UI出の指定の逆数になっていますのでご注意ください。
|
||||
|
||||
- `--highres_fix_steps`:1st stageの画像のステップ数を指定します。デフォルトは`28`です。
|
||||
|
||||
- `--highres_fix_save_1st`:1st stageの画像を保存するかどうかを指定します。
|
||||
|
||||
- `--highres_fix_latents_upscaling`:指定すると2nd stageの画像生成時に1st stageの画像をlatentベースでupscalingします(bilinearのみ対応)。未指定時は画像をLANCZOS4でupscalingします。
|
||||
|
||||
- `--highres_fix_upscaler`:2nd stageに任意のupscalerを利用します。現在は`--highres_fix_upscaler tools.latent_upscaler` のみ対応しています。
|
||||
|
||||
- `--highres_fix_upscaler_args`:`--highres_fix_upscaler`で指定したupscalerに渡す引数を指定します。
|
||||
`tools.latent_upscaler`の場合は、`--highres_fix_upscaler_args "weights=D:\Work\SD\Models\others\etc\upscaler-v1-e100-220.safetensors"`のように重みファイルを指定します。
|
||||
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
||||
--n_iter 1 --scale 7.5 --W 1024 --H 1024 --batch_size 1 --outdir ../txt2img
|
||||
--steps 48 --sampler ddim --fp16
|
||||
--xformers
|
||||
--images_per_prompt 1 --interactive
|
||||
--highres_fix_scale 0.5 --highres_fix_steps 28 --strength 0.5
|
||||
```
|
||||
|
||||
## ControlNet
|
||||
|
||||
現在はControlNet 1.0のみ動作確認しています。プリプロセスはCannyのみサポートしています。
|
||||
|
||||
以下のオプションがあります。
|
||||
|
||||
- `--control_net_models`:ControlNetのモデルファイルを指定します。
|
||||
複数指定すると、それらをstepごとに切り替えて利用します(Web UIのControlNet拡張の実装と異なります)。diffと通常の両方をサポートします。
|
||||
|
||||
- `--guide_image_path`:ControlNetに使うヒント画像を指定します。`--img_path`と同様にフォルダを指定すると、そのフォルダの画像を順次利用します。Canny以外のモデルの場合には、あらかじめプリプロセスを行っておいてください。
|
||||
|
||||
- `--control_net_preps`:ControlNetのプリプロセスを指定します。`--control_net_models`と同様に複数指定可能です。現在はcannyのみ対応しています。対象モデルでプリプロセスを使用しない場合は `none` を指定します。
|
||||
cannyの場合 `--control_net_preps canny_63_191`のように、閾値1と2を'_'で区切って指定できます。
|
||||
|
||||
- `--control_net_weights`:ControlNetの適用時の重みを指定します(`1.0`で通常、`0.5`なら半分の影響力で適用)。`--control_net_models`と同様に複数指定可能です。
|
||||
|
||||
- `--control_net_ratios`:ControlNetを適用するstepの範囲を指定します。`0.5`の場合は、step数の半分までControlNetを適用します。`--control_net_models`と同様に複数指定可能です。
|
||||
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt model_ckpt --scale 8 --steps 48 --outdir txt2img --xformers
|
||||
--W 512 --H 768 --bf16 --sampler k_euler_a
|
||||
--control_net_models diff_control_sd15_canny.safetensors --control_net_weights 1.0
|
||||
--guide_image_path guide.png --control_net_ratios 1.0 --interactive
|
||||
```
|
||||
|
||||
## Attention Couple + Reginal LoRA
|
||||
|
||||
プロンプトをいくつかの部分に分割し、それぞれのプロンプトを画像内のどの領域に適用するかを指定できる機能です。個別のオプションはありませんが、`mask_path`とプロンプトで指定します。
|
||||
|
||||
まず、プロンプトで` AND `を利用して、複数部分を定義します。最初の3つに対して領域指定ができ、以降の部分は画像全体へ適用されます。ネガティブプロンプトは画像全体に適用されます。
|
||||
|
||||
以下ではANDで3つの部分を定義しています。
|
||||
|
||||
```
|
||||
shs 2girls, looking at viewer, smile AND bsb 2girls, looking back AND 2girls --n bad quality, worst quality
|
||||
```
|
||||
|
||||
次にマスク画像を用意します。マスク画像はカラーの画像で、RGBの各チャネルがプロンプトのANDで区切られた部分に対応します。またあるチャネルの値がすべて0の場合、画像全体に適用されます。
|
||||
|
||||
上記の例では、Rチャネルが`shs 2girls, looking at viewer, smile`、Gチャネルが`bsb 2girls, looking back`に、Bチャネルが`2girls`に対応します。次のようなマスク画像を使用すると、Bチャネルに指定がありませんので、`2girls`は画像全体に適用されます。
|
||||
|
||||

|
||||
|
||||
マスク画像は`--mask_path`で指定します。現在は1枚のみ対応しています。指定した画像サイズに自動的にリサイズされ適用されます。
|
||||
|
||||
ControlNetと組み合わせることも可能です(細かい位置指定にはControlNetとの組み合わせを推奨します)。
|
||||
|
||||
LoRAを指定すると、`--network_weights`で指定した複数のLoRAがそれぞれANDの各部分に対応します。現在の制約として、LoRAの数はANDの部分の数と同じである必要があります。
|
||||
|
||||
## CLIP Guided Stable Diffusion
|
||||
|
||||
DiffusersのCommunity Examplesの[こちらのcustom pipeline](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#clip-guided-stable-diffusion)からソースをコピー、変更したものです。
|
||||
|
||||
通常のプロンプトによる生成指定に加えて、追加でより大規模のCLIPでプロンプトのテキストの特徴量を取得し、生成中の画像の特徴量がそのテキストの特徴量に近づくよう、生成される画像をコントロールします(私のざっくりとした理解です)。大きめのCLIPを使いますのでVRAM使用量はかなり増加し(VRAM 8GBでは512*512でも厳しいかもしれません)、生成時間も掛かります。
|
||||
|
||||
なお選択できるサンプラーはDDIM、PNDM、LMSのみとなります。
|
||||
|
||||
`--clip_guidance_scale`オプションにどの程度、CLIPの特徴量を反映するかを数値で指定します。先のサンプルでは100になっていますので、そのあたりから始めて増減すると良いようです。
|
||||
|
||||
デフォルトではプロンプトの先頭75トークン(重みづけの特殊文字を除く)がCLIPに渡されます。プロンプトの`--c`オプションで、通常のプロンプトではなく、CLIPに渡すテキストを別に指定できます(たとえばCLIPはDreamBoothのidentifier(識別子)や「1girl」などのモデル特有の単語は認識できないと思われますので、それらを省いたテキストが良いと思われます)。
|
||||
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt v1-5-pruned-emaonly.ckpt --n_iter 1
|
||||
--scale 2.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img --steps 36
|
||||
--sampler ddim --fp16 --opt_channels_last --xformers --images_per_prompt 1
|
||||
--interactive --clip_guidance_scale 100
|
||||
```
|
||||
|
||||
## CLIP Image Guided Stable Diffusion
|
||||
|
||||
テキストではなくCLIPに別の画像を渡し、その特徴量に近づくよう生成をコントロールする機能です。`--clip_image_guidance_scale`オプションで適用量の数値を、`--guide_image_path`オプションでguideに使用する画像(ファイルまたはフォルダ)を指定してください。
|
||||
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
||||
--n_iter 1 --scale 7.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img
|
||||
--steps 80 --sampler ddim --fp16 --opt_channels_last --xformers
|
||||
--images_per_prompt 1 --interactive --clip_image_guidance_scale 100
|
||||
--guide_image_path YUKA160113420I9A4104_TP_V.jpg
|
||||
```
|
||||
|
||||
### VGG16 Guided Stable Diffusion
|
||||
|
||||
指定した画像に近づくように画像生成する機能です。通常のプロンプトによる生成指定に加えて、追加でVGG16の特徴量を取得し、生成中の画像が指定したガイド画像に近づくよう、生成される画像をコントロールします。img2imgでの使用をお勧めします(通常の生成では画像がぼやけた感じになります)。CLIP Guided Stable Diffusionの仕組みを流用した独自の機能です。またアイデアはVGGを利用したスタイル変換から拝借しています。
|
||||
|
||||
なお選択できるサンプラーはDDIM、PNDM、LMSのみとなります。
|
||||
|
||||
`--vgg16_guidance_scale`オプションにどの程度、VGG16特徴量を反映するかを数値で指定します。試した感じでは100くらいから始めて増減すると良いようです。`--guide_image_path`オプションでguideに使用する画像(ファイルまたはフォルダ)を指定してください。
|
||||
|
||||
複数枚の画像を一括でimg2img変換し、元画像をガイド画像とする場合、`--guide_image_path`と`--image_path`に同じ値を指定すればOKです。
|
||||
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt
|
||||
--n_iter 1 --scale 5.5 --steps 60 --outdir ../txt2img
|
||||
--xformers --sampler ddim --fp16 --W 512 --H 704
|
||||
--batch_size 1 --images_per_prompt 1
|
||||
--prompt "picturesque, 1girl, solo, anime face, skirt, beautiful face
|
||||
--n lowres, bad anatomy, bad hands, error, missing fingers,
|
||||
cropped, worst quality, low quality, normal quality,
|
||||
jpeg artifacts, blurry, 3d, bad face, monochrome --d 1"
|
||||
--strength 0.8 --image_path ..\src_image
|
||||
--vgg16_guidance_scale 100 --guide_image_path ..\src_image
|
||||
```
|
||||
|
||||
`--vgg16_guidance_layerPで特徴量取得に使用するVGG16のレイヤー番号を指定できます(デフォルトは20でconv4-2のReLUです)。上の層ほど画風を表現し、下の層ほどコンテンツを表現するといわれています。
|
||||
|
||||

|
||||
|
||||
# その他のオプション
|
||||
|
||||
- `--no_preview` : 対話モードでプレビュー画像を表示しません。OpenCVがインストールされていない場合や、出力されたファイルを直接確認する場合に指定してください。
|
||||
|
||||
- `--n_iter` : 生成を繰り返す回数を指定します。デフォルトは1です。プロンプトをファイルから読み込むとき、複数回の生成を行いたい場合に指定します。
|
||||
|
||||
- `--tokenizer_cache_dir` : トークナイザーのキャッシュディレクトリを指定します。(作業中)
|
||||
|
||||
- `--seed` : 乱数seedを指定します。1枚生成時はその画像のseed、複数枚生成時は各画像のseedを生成するための乱数のseedになります(`--from_file`で複数画像生成するとき、`--seed`オプションを指定すると複数回実行したときに各画像が同じseedになります)。
|
||||
|
||||
- `--iter_same_seed` : プロンプトに乱数seedの指定がないとき、`--n_iter`の繰り返し内ではすべて同じseedを使います。`--from_file`で指定した複数のプロンプト間でseedを統一して比較するときに使います。
|
||||
|
||||
- `--diffusers_xformers` : Diffuserのxformersを使用します。
|
||||
|
||||
- `--opt_channels_last` : 推論時にテンソルのチャンネルを最後に配置します。場合によっては高速化されることがあります。
|
||||
|
||||
- `--network_show_meta` : 追加ネットワークのメタデータを表示します。
|
||||
|
||||
@@ -92,6 +92,7 @@ from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
import library.model_util as model_util
|
||||
import library.train_util as train_util
|
||||
from networks.lora import LoRANetwork
|
||||
import tools.original_control_net as original_control_net
|
||||
from tools.original_control_net import ControlNetInfo
|
||||
|
||||
@@ -634,6 +635,7 @@ class PipelineLike:
|
||||
img2img_noise=None,
|
||||
clip_prompts=None,
|
||||
clip_guide_images=None,
|
||||
networks: Optional[List[LoRANetwork]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -717,6 +719,7 @@ class PipelineLike:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
reginonal_network = " AND " in prompt[0]
|
||||
|
||||
vae_batch_size = (
|
||||
batch_size
|
||||
@@ -942,7 +945,7 @@ class PipelineLike:
|
||||
|
||||
# encode the init image into latents and scale the latents
|
||||
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
||||
if init_image.size()[2:] == (height // 8, width // 8):
|
||||
if init_image.size()[-2:] == (height // 8, width // 8):
|
||||
init_latents = init_image
|
||||
else:
|
||||
if vae_batch_size >= batch_size:
|
||||
@@ -1010,6 +1013,11 @@ class PipelineLike:
|
||||
|
||||
# predict the noise residual
|
||||
if self.control_nets:
|
||||
if reginonal_network:
|
||||
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
|
||||
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt
|
||||
else:
|
||||
text_emb_last = text_embeddings
|
||||
noise_pred = original_control_net.call_unet_and_control_net(
|
||||
i,
|
||||
num_latent_input,
|
||||
@@ -1019,7 +1027,7 @@ class PipelineLike:
|
||||
i / len(timesteps),
|
||||
latent_model_input,
|
||||
t,
|
||||
text_embeddings,
|
||||
text_emb_last,
|
||||
).sample
|
||||
else:
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
@@ -1890,6 +1898,12 @@ def get_weighted_text_embeddings(
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
# split the prompts with "AND". each prompt must have the same number of splits
|
||||
new_prompts = []
|
||||
for p in prompt:
|
||||
new_prompts.extend(p.split(" AND "))
|
||||
prompt = new_prompts
|
||||
|
||||
if not skip_parsing:
|
||||
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer)
|
||||
if uncond_prompt is not None:
|
||||
@@ -2059,6 +2073,7 @@ class BatchDataExt(NamedTuple):
|
||||
negative_scale: float
|
||||
strength: float
|
||||
network_muls: Tuple[float]
|
||||
num_sub_prompts: int
|
||||
|
||||
|
||||
class BatchData(NamedTuple):
|
||||
@@ -2275,16 +2290,22 @@ def main(args):
|
||||
if metadata is not None:
|
||||
print(f"metadata for: {network_weight}: {metadata}")
|
||||
|
||||
network = imported_module.create_network_from_weights(
|
||||
network_mul, network_weight, vae, text_encoder, unet, **net_kwargs
|
||||
network, weights_sd = imported_module.create_network_from_weights(
|
||||
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError("No weight. Weight is required.")
|
||||
if network is None:
|
||||
return
|
||||
|
||||
if not args.network_merge:
|
||||
mergiable = hasattr(network, "merge_to")
|
||||
if args.network_merge and not mergiable:
|
||||
print("network is not mergiable. ignore merge option.")
|
||||
|
||||
if not args.network_merge or not mergiable:
|
||||
network.apply_to(text_encoder, unet)
|
||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||
print(f"weights are loaded: {info}")
|
||||
|
||||
if args.opt_channels_last:
|
||||
network.to(memory_format=torch.channels_last)
|
||||
@@ -2292,11 +2313,27 @@ def main(args):
|
||||
|
||||
networks.append(network)
|
||||
else:
|
||||
network.merge_to(text_encoder, unet, dtype, device)
|
||||
network.merge_to(text_encoder, unet, weights_sd, dtype, device)
|
||||
|
||||
else:
|
||||
networks = []
|
||||
|
||||
# upscalerの指定があれば取得する
|
||||
upscaler = None
|
||||
if args.highres_fix_upscaler:
|
||||
print("import upscaler module:", args.highres_fix_upscaler)
|
||||
imported_module = importlib.import_module(args.highres_fix_upscaler)
|
||||
|
||||
us_kwargs = {}
|
||||
if args.highres_fix_upscaler_args:
|
||||
for net_arg in args.highres_fix_upscaler_args.split(";"):
|
||||
key, value = net_arg.split("=")
|
||||
us_kwargs[key] = value
|
||||
|
||||
print("create upscaler")
|
||||
upscaler = imported_module.create_upscaler(**us_kwargs)
|
||||
upscaler.to(dtype).to(device)
|
||||
|
||||
# ControlNetの処理
|
||||
control_nets: List[ControlNetInfo] = []
|
||||
if args.control_net_models:
|
||||
@@ -2347,12 +2384,12 @@ def main(args):
|
||||
if args.diffusers_xformers:
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# Extended Textual Inversion および Textual Inversionを処理する
|
||||
if args.XTI_embeddings:
|
||||
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
|
||||
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
|
||||
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
|
||||
|
||||
# Textual Inversionを処理する
|
||||
if args.textual_inversion_embeddings:
|
||||
token_ids_embeds = []
|
||||
for embeds_file in args.textual_inversion_embeddings:
|
||||
@@ -2556,16 +2593,22 @@ def main(args):
|
||||
print(f"resize img2img mask images to {args.W}*{args.H}")
|
||||
mask_images = resize_images(mask_images, (args.W, args.H))
|
||||
|
||||
regional_network = False
|
||||
if networks and mask_images:
|
||||
# mask を領域情報として流用する、現在は1枚だけ対応
|
||||
# TODO 複数のnetwork classの混在時の考慮
|
||||
# mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
|
||||
regional_network = True
|
||||
print("use mask as region")
|
||||
# import cv2
|
||||
# for i in range(3):
|
||||
# cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
|
||||
# cv2.waitKey()
|
||||
# cv2.destroyAllWindows()
|
||||
networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
|
||||
|
||||
size = None
|
||||
for i, network in enumerate(networks):
|
||||
if i < 3:
|
||||
np_mask = np.array(mask_images[0])
|
||||
np_mask = np_mask[:, :, i]
|
||||
size = np_mask.shape
|
||||
else:
|
||||
np_mask = np.full(size, 255, dtype=np.uint8)
|
||||
mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0)
|
||||
network.set_region(i, i == len(networks) - 1, mask)
|
||||
mask_images = None
|
||||
|
||||
prev_image = None # for VGG16 guided
|
||||
@@ -2612,6 +2655,8 @@ def main(args):
|
||||
# highres_fixの処理
|
||||
if highres_fix and not highres_1st:
|
||||
# 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
|
||||
is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling
|
||||
|
||||
print("process 1st stage")
|
||||
batch_1st = []
|
||||
for _, base, ext in batch:
|
||||
@@ -2621,14 +2666,41 @@ def main(args):
|
||||
height_1st = height_1st - height_1st % 32
|
||||
|
||||
ext_1st = BatchDataExt(
|
||||
width_1st, height_1st, args.highres_fix_steps, ext.scale, ext.negative_scale, ext.strength, ext.network_muls
|
||||
width_1st,
|
||||
height_1st,
|
||||
args.highres_fix_steps,
|
||||
ext.scale,
|
||||
ext.negative_scale,
|
||||
ext.strength,
|
||||
ext.network_muls,
|
||||
ext.num_sub_prompts,
|
||||
)
|
||||
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
|
||||
batch_1st.append(BatchData(is_1st_latent, base, ext_1st))
|
||||
images_1st = process_batch(batch_1st, True, True)
|
||||
|
||||
# 2nd stageのバッチを作成して以下処理する
|
||||
print("process 2nd stage")
|
||||
if args.highres_fix_latents_upscaling:
|
||||
width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height
|
||||
|
||||
if upscaler:
|
||||
# upscalerを使って画像を拡大する
|
||||
lowreso_imgs = None if is_1st_latent else images_1st
|
||||
lowreso_latents = None if not is_1st_latent else images_1st
|
||||
|
||||
# 戻り値はPIL.Image.Imageかtorch.Tensorのlatents
|
||||
batch_size = len(images_1st)
|
||||
vae_batch_size = (
|
||||
batch_size
|
||||
if args.vae_batch_size is None
|
||||
else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size)
|
||||
)
|
||||
vae_batch_size = int(vae_batch_size)
|
||||
images_1st = upscaler.upscale(
|
||||
vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size
|
||||
)
|
||||
|
||||
elif args.highres_fix_latents_upscaling:
|
||||
# latentを拡大する
|
||||
org_dtype = images_1st.dtype
|
||||
if images_1st.dtype == torch.bfloat16:
|
||||
images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
|
||||
@@ -2637,10 +2709,12 @@ def main(args):
|
||||
) # , antialias=True)
|
||||
images_1st = images_1st.to(org_dtype)
|
||||
|
||||
else:
|
||||
# 画像をLANCZOSで拡大する
|
||||
images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st]
|
||||
|
||||
batch_2nd = []
|
||||
for i, (bd, image) in enumerate(zip(batch, images_1st)):
|
||||
if not args.highres_fix_latents_upscaling:
|
||||
image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
|
||||
bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext)
|
||||
batch_2nd.append(bd_2nd)
|
||||
batch = batch_2nd
|
||||
@@ -2649,7 +2723,7 @@ def main(args):
|
||||
(
|
||||
return_latents,
|
||||
(step_first, _, _, _, init_image, mask_image, _, guide_image),
|
||||
(width, height, steps, scale, negative_scale, strength, network_muls),
|
||||
(width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts),
|
||||
) = batch[0]
|
||||
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
||||
|
||||
@@ -2741,8 +2815,11 @@ def main(args):
|
||||
|
||||
# generate
|
||||
if networks:
|
||||
shared = {}
|
||||
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
||||
n.set_multiplier(m)
|
||||
if regional_network:
|
||||
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
|
||||
|
||||
images = pipe(
|
||||
prompts,
|
||||
@@ -2967,11 +3044,26 @@ def main(args):
|
||||
print("Use previous image as guide image.")
|
||||
guide_image = prev_image
|
||||
|
||||
if regional_network:
|
||||
num_sub_prompts = len(prompt.split(" AND "))
|
||||
assert (
|
||||
len(networks) <= num_sub_prompts
|
||||
), "Number of networks must be less than or equal to number of sub prompts."
|
||||
else:
|
||||
num_sub_prompts = None
|
||||
|
||||
b1 = BatchData(
|
||||
False,
|
||||
BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
||||
BatchDataExt(
|
||||
width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None
|
||||
width,
|
||||
height,
|
||||
steps,
|
||||
scale,
|
||||
negative_scale,
|
||||
strength,
|
||||
tuple(network_muls) if network_muls else None,
|
||||
num_sub_prompts,
|
||||
),
|
||||
)
|
||||
if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
|
||||
@@ -3177,6 +3269,16 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="use latents upscaling for highres fix / highres fixでlatentで拡大する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--highres_fix_upscaler_args",
|
||||
type=str,
|
||||
default=None,
|
||||
help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する"
|
||||
)
|
||||
@@ -3195,6 +3297,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
nargs="*",
|
||||
help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
||||
# )
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -445,7 +445,7 @@ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str]
|
||||
try:
|
||||
n_repeats = int(tokens[0])
|
||||
except ValueError as e:
|
||||
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
|
||||
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
|
||||
return 0, ""
|
||||
caption_by_folder = '_'.join(tokens[1:])
|
||||
return n_repeats, caption_by_folder
|
||||
@@ -486,7 +486,8 @@ def load_user_config(file: str) -> dict:
|
||||
|
||||
if file.name.lower().endswith('.json'):
|
||||
try:
|
||||
config = json.load(file)
|
||||
with open(file, 'r') as f:
|
||||
config = json.load(f)
|
||||
except Exception:
|
||||
print(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
|
||||
raise
|
||||
|
||||
@@ -1,18 +1,344 @@
|
||||
import torch
|
||||
import argparse
|
||||
import re
|
||||
from typing import List, Optional, Union
|
||||
|
||||
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
|
||||
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
|
||||
alpha = sqrt_alphas_cumprod
|
||||
sigma = sqrt_one_minus_alphas_cumprod
|
||||
all_snr = (alpha / sigma) ** 2
|
||||
snr = torch.stack([all_snr[t] for t in timesteps])
|
||||
gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr)
|
||||
snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper
|
||||
loss = loss * snr_weight
|
||||
return loss
|
||||
|
||||
def add_custom_train_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨")
|
||||
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
|
||||
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
|
||||
alpha = sqrt_alphas_cumprod
|
||||
sigma = sqrt_one_minus_alphas_cumprod
|
||||
all_snr = (alpha / sigma) ** 2
|
||||
snr = torch.stack([all_snr[t] for t in timesteps])
|
||||
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
|
||||
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() # from paper
|
||||
loss = loss * snr_weight
|
||||
return loss
|
||||
|
||||
|
||||
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
|
||||
parser.add_argument(
|
||||
"--min_snr_gamma",
|
||||
type=float,
|
||||
default=None,
|
||||
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
|
||||
)
|
||||
if support_weighted_captions:
|
||||
parser.add_argument(
|
||||
"--weighted_captions",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
|
||||
)
|
||||
|
||||
|
||||
re_attention = re.compile(
|
||||
r"""
|
||||
\\\(|
|
||||
\\\)|
|
||||
\\\[|
|
||||
\\]|
|
||||
\\\\|
|
||||
\\|
|
||||
\(|
|
||||
\[|
|
||||
:([+-]?[.\d]+)\)|
|
||||
\)|
|
||||
]|
|
||||
[^\\()\[\]:]+|
|
||||
:
|
||||
""",
|
||||
re.X,
|
||||
)
|
||||
|
||||
|
||||
def parse_prompt_attention(text):
|
||||
"""
|
||||
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
||||
Accepted tokens are:
|
||||
(abc) - increases attention to abc by a multiplier of 1.1
|
||||
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
||||
[abc] - decreases attention to abc by a multiplier of 1.1
|
||||
\( - literal character '('
|
||||
\[ - literal character '['
|
||||
\) - literal character ')'
|
||||
\] - literal character ']'
|
||||
\\ - literal character '\'
|
||||
anything else - just text
|
||||
>>> parse_prompt_attention('normal text')
|
||||
[['normal text', 1.0]]
|
||||
>>> parse_prompt_attention('an (important) word')
|
||||
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
||||
>>> parse_prompt_attention('(unbalanced')
|
||||
[['unbalanced', 1.1]]
|
||||
>>> parse_prompt_attention('\(literal\]')
|
||||
[['(literal]', 1.0]]
|
||||
>>> parse_prompt_attention('(unnecessary)(parens)')
|
||||
[['unnecessaryparens', 1.1]]
|
||||
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
||||
[['a ', 1.0],
|
||||
['house', 1.5730000000000004],
|
||||
[' ', 1.1],
|
||||
['on', 1.0],
|
||||
[' a ', 1.1],
|
||||
['hill', 0.55],
|
||||
[', sun, ', 1.1],
|
||||
['sky', 1.4641000000000006],
|
||||
['.', 1.1]]
|
||||
"""
|
||||
|
||||
res = []
|
||||
round_brackets = []
|
||||
square_brackets = []
|
||||
|
||||
round_bracket_multiplier = 1.1
|
||||
square_bracket_multiplier = 1 / 1.1
|
||||
|
||||
def multiply_range(start_position, multiplier):
|
||||
for p in range(start_position, len(res)):
|
||||
res[p][1] *= multiplier
|
||||
|
||||
for m in re_attention.finditer(text):
|
||||
text = m.group(0)
|
||||
weight = m.group(1)
|
||||
|
||||
if text.startswith("\\"):
|
||||
res.append([text[1:], 1.0])
|
||||
elif text == "(":
|
||||
round_brackets.append(len(res))
|
||||
elif text == "[":
|
||||
square_brackets.append(len(res))
|
||||
elif weight is not None and len(round_brackets) > 0:
|
||||
multiply_range(round_brackets.pop(), float(weight))
|
||||
elif text == ")" and len(round_brackets) > 0:
|
||||
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
||||
elif text == "]" and len(square_brackets) > 0:
|
||||
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
||||
else:
|
||||
res.append([text, 1.0])
|
||||
|
||||
for pos in round_brackets:
|
||||
multiply_range(pos, round_bracket_multiplier)
|
||||
|
||||
for pos in square_brackets:
|
||||
multiply_range(pos, square_bracket_multiplier)
|
||||
|
||||
if len(res) == 0:
|
||||
res = [["", 1.0]]
|
||||
|
||||
# merge runs of identical weights
|
||||
i = 0
|
||||
while i + 1 < len(res):
|
||||
if res[i][1] == res[i + 1][1]:
|
||||
res[i][0] += res[i + 1][0]
|
||||
res.pop(i + 1)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
|
||||
r"""
|
||||
Tokenize a list of prompts and return its tokens with weights of each token.
|
||||
|
||||
No padding, starting or ending token is included.
|
||||
"""
|
||||
tokens = []
|
||||
weights = []
|
||||
truncated = False
|
||||
for text in prompt:
|
||||
texts_and_weights = parse_prompt_attention(text)
|
||||
text_token = []
|
||||
text_weight = []
|
||||
for word, weight in texts_and_weights:
|
||||
# tokenize and discard the starting and the ending token
|
||||
token = tokenizer(word).input_ids[1:-1]
|
||||
text_token += token
|
||||
# copy the weight by length of token
|
||||
text_weight += [weight] * len(token)
|
||||
# stop if the text is too long (longer than truncation limit)
|
||||
if len(text_token) > max_length:
|
||||
truncated = True
|
||||
break
|
||||
# truncate
|
||||
if len(text_token) > max_length:
|
||||
truncated = True
|
||||
text_token = text_token[:max_length]
|
||||
text_weight = text_weight[:max_length]
|
||||
tokens.append(text_token)
|
||||
weights.append(text_weight)
|
||||
if truncated:
|
||||
print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
||||
return tokens, weights
|
||||
|
||||
|
||||
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
|
||||
r"""
|
||||
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
||||
"""
|
||||
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
||||
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
||||
for i in range(len(tokens)):
|
||||
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
|
||||
if no_boseos_middle:
|
||||
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
||||
else:
|
||||
w = []
|
||||
if len(weights[i]) == 0:
|
||||
w = [1.0] * weights_length
|
||||
else:
|
||||
for j in range(max_embeddings_multiples):
|
||||
w.append(1.0) # weight for starting token in this chunk
|
||||
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
|
||||
w.append(1.0) # weight for ending token in this chunk
|
||||
w += [1.0] * (weights_length - len(w))
|
||||
weights[i] = w[:]
|
||||
|
||||
return tokens, weights
|
||||
|
||||
|
||||
def get_unweighted_text_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
text_input: torch.Tensor,
|
||||
chunk_length: int,
|
||||
clip_skip: int,
|
||||
eos: int,
|
||||
pad: int,
|
||||
no_boseos_middle: Optional[bool] = True,
|
||||
):
|
||||
"""
|
||||
When the length of tokens is a multiple of the capacity of the text encoder,
|
||||
it should be split into chunks and sent to the text encoder individually.
|
||||
"""
|
||||
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
|
||||
if max_embeddings_multiples > 1:
|
||||
text_embeddings = []
|
||||
for i in range(max_embeddings_multiples):
|
||||
# extract the i-th chunk
|
||||
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
||||
|
||||
# cover the head and the tail by the starting and the ending tokens
|
||||
text_input_chunk[:, 0] = text_input[0, 0]
|
||||
if pad == eos: # v1
|
||||
text_input_chunk[:, -1] = text_input[0, -1]
|
||||
else: # v2
|
||||
for j in range(len(text_input_chunk)):
|
||||
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
|
||||
text_input_chunk[j, -1] = eos
|
||||
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
|
||||
text_input_chunk[j, 1] = eos
|
||||
|
||||
if clip_skip is None or clip_skip == 1:
|
||||
text_embedding = text_encoder(text_input_chunk)[0]
|
||||
else:
|
||||
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
|
||||
text_embedding = enc_out["hidden_states"][-clip_skip]
|
||||
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
|
||||
|
||||
# cover the head and the tail by the starting and the ending tokens
|
||||
text_input_chunk[:, 0] = text_input[0, 0]
|
||||
text_input_chunk[:, -1] = text_input[0, -1]
|
||||
text_embedding = text_encoder(text_input_chunk, attention_mask=None)[0]
|
||||
|
||||
if no_boseos_middle:
|
||||
if i == 0:
|
||||
# discard the ending token
|
||||
text_embedding = text_embedding[:, :-1]
|
||||
elif i == max_embeddings_multiples - 1:
|
||||
# discard the starting token
|
||||
text_embedding = text_embedding[:, 1:]
|
||||
else:
|
||||
# discard both starting and ending tokens
|
||||
text_embedding = text_embedding[:, 1:-1]
|
||||
|
||||
text_embeddings.append(text_embedding)
|
||||
text_embeddings = torch.concat(text_embeddings, axis=1)
|
||||
else:
|
||||
text_embeddings = text_encoder(text_input)[0]
|
||||
return text_embeddings
|
||||
|
||||
|
||||
def get_weighted_text_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
prompt: Union[str, List[str]],
|
||||
device,
|
||||
max_embeddings_multiples: Optional[int] = 3,
|
||||
no_boseos_middle: Optional[bool] = False,
|
||||
clip_skip=None,
|
||||
):
|
||||
r"""
|
||||
Prompts can be assigned with local weights using brackets. For example,
|
||||
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
|
||||
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
|
||||
|
||||
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
||||
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
||||
no_boseos_middle (`bool`, *optional*, defaults to `False`):
|
||||
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
|
||||
ending token in each of the chunk in the middle.
|
||||
skip_parsing (`bool`, *optional*, defaults to `False`):
|
||||
Skip the parsing of brackets.
|
||||
skip_weighting (`bool`, *optional*, defaults to `False`):
|
||||
Skip the weighting. When the parsing is skipped, it is forced True.
|
||||
"""
|
||||
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
|
||||
|
||||
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
||||
max_length = max([len(token) for token in prompt_tokens])
|
||||
|
||||
max_embeddings_multiples = min(
|
||||
max_embeddings_multiples,
|
||||
(max_length - 1) // (tokenizer.model_max_length - 2) + 1,
|
||||
)
|
||||
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
||||
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
||||
|
||||
# pad the length of tokens and weights
|
||||
bos = tokenizer.bos_token_id
|
||||
eos = tokenizer.eos_token_id
|
||||
pad = tokenizer.pad_token_id
|
||||
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
||||
prompt_tokens,
|
||||
prompt_weights,
|
||||
max_length,
|
||||
bos,
|
||||
eos,
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
chunk_length=tokenizer.model_max_length,
|
||||
)
|
||||
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
|
||||
|
||||
# get the embeddings
|
||||
text_embeddings = get_unweighted_text_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
prompt_tokens,
|
||||
tokenizer.model_max_length,
|
||||
clip_skip,
|
||||
eos,
|
||||
pad,
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
)
|
||||
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
|
||||
|
||||
# assign weights to the prompts and normalize in the sense of mean
|
||||
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||
text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
|
||||
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||
text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
return text_embeddings
|
||||
|
||||
78
library/huggingface_util.py
Normal file
78
library/huggingface_util.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from typing import *
|
||||
from huggingface_hub import HfApi
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from library.utils import fire_in_thread
|
||||
|
||||
|
||||
def exists_repo(
|
||||
repo_id: str, repo_type: str, revision: str = "main", token: str = None
|
||||
):
|
||||
api = HfApi(
|
||||
token=token,
|
||||
)
|
||||
try:
|
||||
api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
def upload(
|
||||
args: argparse.Namespace,
|
||||
src: Union[str, Path, bytes, BinaryIO],
|
||||
dest_suffix: str = "",
|
||||
force_sync_upload: bool = False,
|
||||
):
|
||||
repo_id = args.huggingface_repo_id
|
||||
repo_type = args.huggingface_repo_type
|
||||
token = args.huggingface_token
|
||||
path_in_repo = args.huggingface_path_in_repo + dest_suffix
|
||||
private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
|
||||
api = HfApi(token=token)
|
||||
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
|
||||
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
|
||||
|
||||
is_folder = (type(src) == str and os.path.isdir(src)) or (
|
||||
isinstance(src, Path) and src.is_dir()
|
||||
)
|
||||
|
||||
def uploader():
|
||||
if is_folder:
|
||||
api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
folder_path=src,
|
||||
path_in_repo=path_in_repo,
|
||||
)
|
||||
else:
|
||||
api.upload_file(
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
path_or_fileobj=src,
|
||||
path_in_repo=path_in_repo,
|
||||
)
|
||||
|
||||
if args.async_upload and not force_sync_upload:
|
||||
fire_in_thread(uploader)
|
||||
else:
|
||||
uploader()
|
||||
|
||||
|
||||
def list_dir(
|
||||
repo_id: str,
|
||||
subfolder: str,
|
||||
repo_type: str,
|
||||
revision: str = "main",
|
||||
token: str = None,
|
||||
):
|
||||
api = HfApi(
|
||||
token=token,
|
||||
)
|
||||
repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
|
||||
file_list = [
|
||||
file for file in repo_info.siblings if file.rfilename.startswith(subfolder)
|
||||
]
|
||||
return file_list
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import asyncio
|
||||
import importlib
|
||||
import json
|
||||
import pathlib
|
||||
@@ -49,6 +50,7 @@ from diffusers import (
|
||||
KDPM2DiscreteScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
)
|
||||
from huggingface_hub import hf_hub_download
|
||||
import albumentations as albu
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
@@ -58,6 +60,7 @@ from torch import einsum
|
||||
import safetensors.torch
|
||||
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
||||
import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
|
||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||
@@ -71,6 +74,11 @@ LAST_STATE_NAME = "{}-state"
|
||||
DEFAULT_EPOCH_NAME = "epoch"
|
||||
DEFAULT_LAST_OUTPUT_NAME = "last"
|
||||
|
||||
DEFAULT_STEP_NAME = "at"
|
||||
STEP_STATE_NAME = "{}-step{:08d}-state"
|
||||
STEP_FILE_NAME = "{}-step{:08d}"
|
||||
STEP_DIFFUSERS_DIR_NAME = "{}-step{:08d}"
|
||||
|
||||
# region dataset
|
||||
|
||||
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
|
||||
@@ -487,7 +495,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||
tokens = [t.strip() for t in caption.strip().split(",")]
|
||||
if subset.token_warmup_step < 1: # 初回に上書きする
|
||||
if subset.token_warmup_step < 1: # 初回に上書きする
|
||||
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
||||
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
||||
tokens_len = (
|
||||
@@ -719,7 +727,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
def is_latent_cacheable(self):
|
||||
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
||||
|
||||
def cache_latents(self, vae, vae_batch_size=1):
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
||||
# ちょっと速くした
|
||||
print("caching latents.")
|
||||
|
||||
@@ -737,11 +745,38 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if info.latents_npz is not None:
|
||||
info.latents = self.load_latents_from_npz(info, False)
|
||||
info.latents = torch.FloatTensor(info.latents)
|
||||
info.latents_flipped = self.load_latents_from_npz(info, True) # might be None
|
||||
|
||||
# might be None, but that's ok because check is done in dataset
|
||||
info.latents_flipped = self.load_latents_from_npz(info, True)
|
||||
if info.latents_flipped is not None:
|
||||
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
|
||||
continue
|
||||
|
||||
# check disk cache exists and size of latents
|
||||
if cache_to_disk:
|
||||
# TODO: refactor to unify with FineTuningDataset
|
||||
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
|
||||
info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz"
|
||||
if not is_main_process:
|
||||
continue
|
||||
|
||||
cache_available = False
|
||||
expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8) # bucket_resoはWxHなので注意
|
||||
if os.path.exists(info.latents_npz):
|
||||
cached_latents = np.load(info.latents_npz)["arr_0"]
|
||||
if cached_latents.shape[1:3] == expected_latents_size:
|
||||
cache_available = True
|
||||
|
||||
if subset.flip_aug:
|
||||
cache_available = False
|
||||
if os.path.exists(info.latents_npz_flipped):
|
||||
cached_latents_flipped = np.load(info.latents_npz_flipped)["arr_0"]
|
||||
if cached_latents_flipped.shape[1:3] == expected_latents_size:
|
||||
cache_available = True
|
||||
|
||||
if cache_available:
|
||||
continue
|
||||
|
||||
# if last member of batch has different resolution, flush the batch
|
||||
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
|
||||
batches.append(batch)
|
||||
@@ -757,6 +792,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if len(batch) > 0:
|
||||
batches.append(batch)
|
||||
|
||||
if cache_to_disk and not is_main_process: # don't cache latents in non-main process, set to info only
|
||||
return
|
||||
|
||||
# iterate batches
|
||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
images = []
|
||||
@@ -770,14 +808,21 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
|
||||
|
||||
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||
|
||||
for info, latent in zip(batch, latents):
|
||||
info.latents = latent
|
||||
if cache_to_disk:
|
||||
np.savez(info.latents_npz, latent.float().numpy())
|
||||
else:
|
||||
info.latents = latent
|
||||
|
||||
if subset.flip_aug:
|
||||
img_tensors = torch.flip(img_tensors, dims=[3])
|
||||
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||
for info, latent in zip(batch, latents):
|
||||
info.latents_flipped = latent
|
||||
if cache_to_disk:
|
||||
np.savez(info.latents_npz_flipped, latent.float().numpy())
|
||||
else:
|
||||
info.latents_flipped = latent
|
||||
|
||||
def get_image_size(self, image_path):
|
||||
image = Image.open(image_path)
|
||||
@@ -805,9 +850,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# 画像サイズはsizeより大きいのでリサイズする
|
||||
face_size = max(face_w, face_h)
|
||||
size = min(self.height, self.width) # 短いほう
|
||||
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
|
||||
min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
|
||||
max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
|
||||
min_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
|
||||
max_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
|
||||
if min_scale >= max_scale: # range指定がmin==max
|
||||
scale = min_scale
|
||||
else:
|
||||
@@ -832,7 +878,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
# range指定があるときのみ、すこしだけランダムに(わりと適当)
|
||||
if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
|
||||
if face_size > self.size // 10 and face_size >= 40:
|
||||
if face_size > size // 10 and face_size >= 40:
|
||||
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
|
||||
|
||||
p1 = max(0, min(p1, length - target_size))
|
||||
@@ -870,10 +916,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
||||
|
||||
# image/latentsを処理する
|
||||
if image_info.latents is not None:
|
||||
if image_info.latents is not None: # cache_latents=Trueの場合
|
||||
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
|
||||
image = None
|
||||
elif image_info.latents_npz is not None:
|
||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
|
||||
latents = torch.FloatTensor(latents)
|
||||
image = None
|
||||
@@ -950,10 +996,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
example["images"] = images
|
||||
|
||||
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
|
||||
example["captions"] = captions
|
||||
|
||||
if self.debug_dataset:
|
||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
||||
example["captions"] = captions
|
||||
return example
|
||||
|
||||
|
||||
@@ -1160,19 +1206,27 @@ class FineTuningDataset(BaseDataset):
|
||||
tags_list = []
|
||||
for image_key, img_md in metadata.items():
|
||||
# path情報を作る
|
||||
abs_path = None
|
||||
|
||||
# まず画像を優先して探す
|
||||
if os.path.exists(image_key):
|
||||
abs_path = image_key
|
||||
elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
|
||||
abs_path = os.path.splitext(image_key)[0] + ".npz"
|
||||
else:
|
||||
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
if os.path.exists(npz_path):
|
||||
abs_path = npz_path
|
||||
# わりといい加減だがいい方法が思いつかん
|
||||
paths = glob_images(subset.image_dir, image_key)
|
||||
if len(paths) > 0:
|
||||
abs_path = paths[0]
|
||||
|
||||
# なければnpzを探す
|
||||
if abs_path is None:
|
||||
if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
|
||||
abs_path = os.path.splitext(image_key)[0] + ".npz"
|
||||
else:
|
||||
# わりといい加減だがいい方法が思いつかん
|
||||
abs_path = glob_images(subset.image_dir, image_key)
|
||||
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
||||
abs_path = abs_path[0]
|
||||
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
if os.path.exists(npz_path):
|
||||
abs_path = npz_path
|
||||
|
||||
assert abs_path is not None, f"no image / 画像がありません: {image_key}"
|
||||
|
||||
caption = img_md.get("caption")
|
||||
tags = img_md.get("tags")
|
||||
@@ -1337,10 +1391,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
for dataset in self.datasets:
|
||||
dataset.enable_XTI(*args, **kwargs)
|
||||
|
||||
def cache_latents(self, vae, vae_batch_size=1):
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
||||
for i, dataset in enumerate(self.datasets):
|
||||
print(f"[Dataset {i}]")
|
||||
dataset.cache_latents(vae, vae_batch_size)
|
||||
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
|
||||
|
||||
def is_latent_cacheable(self) -> bool:
|
||||
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
|
||||
@@ -1397,8 +1451,8 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
|
||||
if os.name == "nt": # only windows
|
||||
cv2.imshow("img", im)
|
||||
k = cv2.waitKey()
|
||||
cv2.destroyAllWindows()
|
||||
k = cv2.waitKey()
|
||||
cv2.destroyAllWindows()
|
||||
if k == 27 or k == ord("s") or k == ord("e"):
|
||||
break
|
||||
steps += 1
|
||||
@@ -1441,7 +1495,6 @@ def glob_images_pathlib(dir_path, recursive):
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region モジュール入れ替え部
|
||||
"""
|
||||
高速化のためのモジュール入れ替え
|
||||
@@ -1896,6 +1949,38 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
||||
parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
||||
parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
|
||||
parser.add_argument(
|
||||
"--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--huggingface_path_in_repo",
|
||||
type=str,
|
||||
default=None,
|
||||
help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
|
||||
)
|
||||
parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
|
||||
parser.add_argument(
|
||||
"--huggingface_repo_visibility",
|
||||
type=str,
|
||||
default=None,
|
||||
help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_huggingface",
|
||||
action="store_true",
|
||||
help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--async_upload",
|
||||
action="store_true",
|
||||
help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_precision",
|
||||
type=str,
|
||||
@@ -1906,18 +1991,38 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument(
|
||||
"--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_every_n_steps", type=int, default=None, help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_n_epoch_ratio",
|
||||
type=int,
|
||||
default=None,
|
||||
help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)",
|
||||
)
|
||||
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
|
||||
parser.add_argument(
|
||||
"--save_last_n_epochs",
|
||||
type=int,
|
||||
default=None,
|
||||
help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_last_n_epochs_state",
|
||||
type=int,
|
||||
default=None,
|
||||
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)",
|
||||
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_last_n_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_last_n_steps_state",
|
||||
type=int,
|
||||
default=None,
|
||||
help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_state",
|
||||
@@ -1988,7 +2093,26 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log_with",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["tensorboard", "wandb", "all"],
|
||||
help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
|
||||
)
|
||||
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
||||
parser.add_argument(
|
||||
"--log_tracker_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wandb_api_key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--noise_offset",
|
||||
type=float,
|
||||
@@ -2061,6 +2185,12 @@ def verify_training_args(args: argparse.Namespace):
|
||||
if args.v2 and args.clip_skip is not None:
|
||||
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
||||
|
||||
if args.cache_latents_to_disk and not args.cache_latents:
|
||||
args.cache_latents = True
|
||||
print(
|
||||
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
|
||||
)
|
||||
|
||||
|
||||
def add_dataset_arguments(
|
||||
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
|
||||
@@ -2110,9 +2240,14 @@ def add_dataset_arguments(
|
||||
parser.add_argument(
|
||||
"--cache_latents",
|
||||
action="store_true",
|
||||
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)",
|
||||
help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheする(augmentationは使用不可) ",
|
||||
)
|
||||
parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--cache_latents_to_disk",
|
||||
action="store_true",
|
||||
help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
|
||||
)
|
||||
@@ -2204,7 +2339,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
|
||||
args_dict = vars(args)
|
||||
|
||||
# remove unnecessary keys
|
||||
for key in ["config_file", "output_config"]:
|
||||
for key in ["config_file", "output_config", "wandb_api_key"]:
|
||||
if key in args_dict:
|
||||
del args_dict[key]
|
||||
|
||||
@@ -2261,6 +2396,57 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
|
||||
# region utils
|
||||
|
||||
|
||||
def resume_from_local_or_hf_if_specified(accelerator, args):
|
||||
if not args.resume:
|
||||
return
|
||||
|
||||
if not args.resume_from_huggingface:
|
||||
print(f"resume training from local state: {args.resume}")
|
||||
accelerator.load_state(args.resume)
|
||||
return
|
||||
|
||||
print(f"resume training from huggingface state: {args.resume}")
|
||||
repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
|
||||
path_in_repo = "/".join(args.resume.split("/")[2:])
|
||||
revision = None
|
||||
repo_type = None
|
||||
if ":" in path_in_repo:
|
||||
divided = path_in_repo.split(":")
|
||||
if len(divided) == 2:
|
||||
path_in_repo, revision = divided
|
||||
repo_type = "model"
|
||||
else:
|
||||
path_in_repo, revision, repo_type = divided
|
||||
print(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")
|
||||
|
||||
list_files = huggingface_util.list_dir(
|
||||
repo_id=repo_id,
|
||||
subfolder=path_in_repo,
|
||||
revision=revision,
|
||||
token=args.huggingface_token,
|
||||
repo_type=repo_type,
|
||||
)
|
||||
|
||||
async def download(filename) -> str:
|
||||
def task():
|
||||
return hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename=filename,
|
||||
revision=revision,
|
||||
repo_type=repo_type,
|
||||
token=args.huggingface_token,
|
||||
)
|
||||
|
||||
return await asyncio.get_event_loop().run_in_executor(None, task)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
|
||||
if len(results) == 0:
|
||||
raise ValueError("No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした")
|
||||
dirname = os.path.dirname(results[0])
|
||||
accelerator.load_state(dirname)
|
||||
|
||||
|
||||
def get_optimizer(args, trainable_params):
|
||||
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
|
||||
|
||||
@@ -2460,7 +2646,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||
Unified API to get any scheduler from its name.
|
||||
"""
|
||||
name = args.lr_scheduler
|
||||
num_warmup_steps = args.lr_warmup_steps
|
||||
num_warmup_steps: Optional[int] = args.lr_warmup_steps
|
||||
num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps
|
||||
num_cycles = args.lr_scheduler_num_cycles
|
||||
power = args.lr_scheduler_power
|
||||
@@ -2484,6 +2670,11 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||
|
||||
lr_scheduler_kwargs[key] = value
|
||||
|
||||
def wrap_check_needless_num_warmup_steps(return_vals):
|
||||
if num_warmup_steps is not None and num_warmup_steps != 0:
|
||||
raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
|
||||
return return_vals
|
||||
|
||||
# using any lr_scheduler from other library
|
||||
if args.lr_scheduler_type:
|
||||
lr_scheduler_type = args.lr_scheduler_type
|
||||
@@ -2496,7 +2687,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||
lr_scheduler_type = values[-1]
|
||||
lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
|
||||
lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
|
||||
return lr_scheduler
|
||||
return wrap_check_needless_num_warmup_steps(lr_scheduler)
|
||||
|
||||
if name.startswith("adafactor"):
|
||||
assert (
|
||||
@@ -2504,12 +2695,12 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||
), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
|
||||
initial_lr = float(name.split(":")[1])
|
||||
# print("adafactor scheduler init lr", initial_lr)
|
||||
return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
|
||||
return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
|
||||
|
||||
name = SchedulerType(name)
|
||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||
if name == SchedulerType.CONSTANT:
|
||||
return schedule_func(optimizer)
|
||||
return wrap_check_needless_num_warmup_steps(schedule_func(optimizer))
|
||||
|
||||
# All other schedulers require `num_warmup_steps`
|
||||
if num_warmup_steps is None:
|
||||
@@ -2592,13 +2783,32 @@ def load_tokenizer(args: argparse.Namespace):
|
||||
|
||||
def prepare_accelerator(args: argparse.Namespace):
|
||||
if args.logging_dir is None:
|
||||
log_with = None
|
||||
logging_dir = None
|
||||
else:
|
||||
log_with = "tensorboard"
|
||||
log_prefix = "" if args.log_prefix is None else args.log_prefix
|
||||
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||
|
||||
if args.log_with is None:
|
||||
if logging_dir is not None:
|
||||
log_with = "tensorboard"
|
||||
else:
|
||||
log_with = None
|
||||
else:
|
||||
log_with = args.log_with
|
||||
if log_with in ["tensorboard", "all"]:
|
||||
if logging_dir is None:
|
||||
raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください")
|
||||
if log_with in ["wandb", "all"]:
|
||||
try:
|
||||
import wandb
|
||||
except ImportError:
|
||||
raise ImportError("No wandb / wandb がインストールされていないようです")
|
||||
if logging_dir is not None:
|
||||
os.makedirs(logging_dir, exist_ok=True)
|
||||
os.environ["WANDB_DIR"] = logging_dir
|
||||
if args.wandb_api_key is not None:
|
||||
wandb.login(key=args.wandb_api_key)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
@@ -2640,7 +2850,7 @@ def prepare_dtype(args: argparse.Namespace):
|
||||
return weight_dtype, save_dtype
|
||||
|
||||
|
||||
def load_target_model(args: argparse.Namespace, weight_dtype, device='cpu'):
|
||||
def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
|
||||
name_or_path = args.pretrained_model_name_or_path
|
||||
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||
@@ -2724,26 +2934,53 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
|
||||
return encoder_hidden_states
|
||||
|
||||
|
||||
def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch):
|
||||
model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
||||
ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt")
|
||||
return model_name, ckpt_name
|
||||
def default_if_none(value, default):
|
||||
return default if value is None else value
|
||||
|
||||
|
||||
def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
|
||||
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
|
||||
if saving:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
save_func()
|
||||
|
||||
if args.save_last_n_epochs is not None:
|
||||
remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
|
||||
remove_old_func(remove_epoch_no)
|
||||
return saving
|
||||
def get_epoch_ckpt_name(args: argparse.Namespace, ext: str, epoch_no: int):
|
||||
model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)
|
||||
return EPOCH_FILE_NAME.format(model_name, epoch_no) + ext
|
||||
|
||||
|
||||
def save_sd_model_on_epoch_end(
|
||||
def get_step_ckpt_name(args: argparse.Namespace, ext: str, step_no: int):
|
||||
model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)
|
||||
return STEP_FILE_NAME.format(model_name, step_no) + ext
|
||||
|
||||
|
||||
def get_last_ckpt_name(args: argparse.Namespace, ext: str):
|
||||
model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
|
||||
return model_name + ext
|
||||
|
||||
|
||||
def get_remove_epoch_no(args: argparse.Namespace, epoch_no: int):
|
||||
if args.save_last_n_epochs is None:
|
||||
return None
|
||||
|
||||
remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
|
||||
if remove_epoch_no < 0:
|
||||
return None
|
||||
return remove_epoch_no
|
||||
|
||||
|
||||
def get_remove_step_no(args: argparse.Namespace, step_no: int):
|
||||
if args.save_last_n_steps is None:
|
||||
return None
|
||||
|
||||
# last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する
|
||||
# save_every_n_steps=10, save_last_n_steps=30の場合、50step目には30step分残し、10step目を削除する
|
||||
remove_step_no = step_no - args.save_last_n_steps - 1
|
||||
remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps)
|
||||
if remove_step_no < 0:
|
||||
return None
|
||||
return remove_step_no
|
||||
|
||||
|
||||
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
|
||||
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
|
||||
def save_sd_model_on_epoch_end_or_stepwise(
|
||||
args: argparse.Namespace,
|
||||
on_epoch_end: bool,
|
||||
accelerator,
|
||||
src_path: str,
|
||||
save_stable_diffusion_format: bool,
|
||||
@@ -2756,54 +2993,92 @@ def save_sd_model_on_epoch_end(
|
||||
unet,
|
||||
vae,
|
||||
):
|
||||
epoch_no = epoch + 1
|
||||
model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no)
|
||||
if on_epoch_end:
|
||||
epoch_no = epoch + 1
|
||||
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
|
||||
if not saving:
|
||||
return
|
||||
|
||||
if save_stable_diffusion_format:
|
||||
|
||||
def save_sd():
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||
)
|
||||
|
||||
def remove_sd(old_epoch_no):
|
||||
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
save_func = save_sd
|
||||
remove_old_func = remove_sd
|
||||
model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)
|
||||
remove_no = get_remove_epoch_no(args, epoch_no)
|
||||
else:
|
||||
# 保存するか否かは呼び出し側で判断済み
|
||||
|
||||
def save_du():
|
||||
model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)
|
||||
epoch_no = epoch # 例: 最初のepochの途中で保存したら0になる、SDモデルに保存される
|
||||
remove_no = get_remove_step_no(args, global_step)
|
||||
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
if save_stable_diffusion_format:
|
||||
ext = ".safetensors" if use_safetensors else ".ckpt"
|
||||
|
||||
if on_epoch_end:
|
||||
ckpt_name = get_epoch_ckpt_name(args, ext, epoch_no)
|
||||
else:
|
||||
ckpt_name = get_step_ckpt_name(args, ext, global_step)
|
||||
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||
)
|
||||
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
|
||||
|
||||
# remove older checkpoints
|
||||
if remove_no is not None:
|
||||
if on_epoch_end:
|
||||
remove_ckpt_name = get_epoch_ckpt_name(args, ext, remove_no)
|
||||
else:
|
||||
remove_ckpt_name = get_step_ckpt_name(args, ext, remove_no)
|
||||
|
||||
remove_ckpt_file = os.path.join(args.output_dir, remove_ckpt_name)
|
||||
if os.path.exists(remove_ckpt_file):
|
||||
print(f"removing old checkpoint: {remove_ckpt_file}")
|
||||
os.remove(remove_ckpt_file)
|
||||
|
||||
else:
|
||||
if on_epoch_end:
|
||||
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no))
|
||||
print(f"saving model: {out_dir}")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
model_util.save_diffusers_checkpoint(
|
||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||
)
|
||||
else:
|
||||
out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step))
|
||||
|
||||
def remove_du(old_epoch_no):
|
||||
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
|
||||
if os.path.exists(out_dir_old):
|
||||
print(f"removing old model: {out_dir_old}")
|
||||
shutil.rmtree(out_dir_old)
|
||||
print(f"saving model: {out_dir}")
|
||||
model_util.save_diffusers_checkpoint(
|
||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||
)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, out_dir, "/" + model_name)
|
||||
|
||||
save_func = save_du
|
||||
remove_old_func = remove_du
|
||||
# remove older checkpoints
|
||||
if remove_no is not None:
|
||||
if on_epoch_end:
|
||||
remove_out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, remove_no))
|
||||
else:
|
||||
remove_out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, remove_no))
|
||||
|
||||
saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
|
||||
if saving and args.save_state:
|
||||
save_state_on_epoch_end(args, accelerator, model_name, epoch_no)
|
||||
if os.path.exists(remove_out_dir):
|
||||
print(f"removing old model: {remove_out_dir}")
|
||||
shutil.rmtree(remove_out_dir)
|
||||
|
||||
if on_epoch_end:
|
||||
save_and_remove_state_on_epoch_end(args, accelerator, epoch_no)
|
||||
else:
|
||||
save_and_remove_state_stepwise(args, accelerator, global_step)
|
||||
|
||||
|
||||
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
|
||||
print("saving state.")
|
||||
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
|
||||
def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no):
|
||||
model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)
|
||||
|
||||
print(f"saving state at epoch {epoch_no}")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
|
||||
accelerator.save_state(state_dir)
|
||||
if args.save_state_to_huggingface:
|
||||
print("uploading state to huggingface.")
|
||||
huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
|
||||
|
||||
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
||||
if last_n_epochs is not None:
|
||||
@@ -2814,6 +3089,45 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
|
||||
shutil.rmtree(state_dir_old)
|
||||
|
||||
|
||||
def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no):
|
||||
model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)
|
||||
|
||||
print(f"saving state at step {step_no}")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no))
|
||||
accelerator.save_state(state_dir)
|
||||
if args.save_state_to_huggingface:
|
||||
print("uploading state to huggingface.")
|
||||
huggingface_util.upload(args, state_dir, "/" + STEP_STATE_NAME.format(model_name, step_no))
|
||||
|
||||
last_n_steps = args.save_last_n_steps_state if args.save_last_n_steps_state else args.save_last_n_steps
|
||||
if last_n_steps is not None:
|
||||
# last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する
|
||||
remove_step_no = step_no - last_n_steps - 1
|
||||
remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps)
|
||||
|
||||
if remove_step_no > 0:
|
||||
state_dir_old = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, remove_step_no))
|
||||
if os.path.exists(state_dir_old):
|
||||
print(f"removing old state: {state_dir_old}")
|
||||
shutil.rmtree(state_dir_old)
|
||||
|
||||
|
||||
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
||||
model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
|
||||
|
||||
print("saving last state.")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))
|
||||
accelerator.save_state(state_dir)
|
||||
|
||||
if args.save_state_to_huggingface:
|
||||
print("uploading last state to huggingface.")
|
||||
huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name))
|
||||
|
||||
|
||||
def save_sd_model_on_train_end(
|
||||
args: argparse.Namespace,
|
||||
src_path: str,
|
||||
@@ -2826,7 +3140,7 @@ def save_sd_model_on_train_end(
|
||||
unet,
|
||||
vae,
|
||||
):
|
||||
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||
model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
|
||||
|
||||
if save_stable_diffusion_format:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
@@ -2838,6 +3152,8 @@ def save_sd_model_on_train_end(
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
|
||||
)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
|
||||
else:
|
||||
out_dir = os.path.join(args.output_dir, model_name)
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
@@ -2846,13 +3162,8 @@ def save_sd_model_on_train_end(
|
||||
model_util.save_diffusers_checkpoint(
|
||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||
)
|
||||
|
||||
|
||||
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
||||
print("saving last state.")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
||||
|
||||
|
||||
# scheduler:
|
||||
@@ -3041,6 +3352,18 @@ def sample_images(
|
||||
|
||||
image.save(os.path.join(save_dir, img_filename))
|
||||
|
||||
# wandb有効時のみログを送信
|
||||
try:
|
||||
wandb_tracker = accelerator.get_tracker("wandb")
|
||||
try:
|
||||
import wandb
|
||||
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
|
||||
raise ImportError("No wandb / wandb がインストールされていないようです")
|
||||
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
|
||||
except: # wandb 無効時
|
||||
pass
|
||||
|
||||
# clear pipeline and cache to reduce vram usage
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
@@ -3084,7 +3407,7 @@ class collater_class:
|
||||
def __init__(self, epoch, step, dataset):
|
||||
self.current_epoch = epoch
|
||||
self.current_step = step
|
||||
self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
|
||||
self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
|
||||
|
||||
def __call__(self, examples):
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
|
||||
6
library/utils.py
Normal file
6
library/utils.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import threading
|
||||
from typing import *
|
||||
|
||||
|
||||
def fire_in_thread(f, *args, **kwargs):
|
||||
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
||||
450
networks/dylora.py
Normal file
450
networks/dylora.py
Normal file
@@ -0,0 +1,450 @@
|
||||
# some codes are copied from:
|
||||
# https://github.com/huawei-noah/KD-NLP/blob/main/DyLoRA/
|
||||
|
||||
# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
|
||||
# Changes made to the original code:
|
||||
# 2022.08.20 - Integrate the DyLoRA layer for the LoRA Linear layer
|
||||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from typing import List, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class DyLoRAModule(torch.nn.Module):
|
||||
"""
|
||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||
"""
|
||||
|
||||
# NOTE: support dropout in future
|
||||
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1):
|
||||
super().__init__()
|
||||
self.lora_name = lora_name
|
||||
self.lora_dim = lora_dim
|
||||
self.unit = unit
|
||||
assert self.lora_dim % self.unit == 0, "rank must be a multiple of unit"
|
||||
|
||||
if org_module.__class__.__name__ == "Conv2d":
|
||||
in_dim = org_module.in_channels
|
||||
out_dim = org_module.out_channels
|
||||
else:
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
||||
self.scale = alpha / self.lora_dim
|
||||
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
||||
|
||||
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
|
||||
self.is_conv2d_3x3 = self.is_conv2d and org_module.kernel_size == (3, 3)
|
||||
|
||||
if self.is_conv2d and self.is_conv2d_3x3:
|
||||
kernel_size = org_module.kernel_size
|
||||
self.stride = org_module.stride
|
||||
self.padding = org_module.padding
|
||||
self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)])
|
||||
self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)])
|
||||
else:
|
||||
self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)])
|
||||
self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)])
|
||||
|
||||
# same as microsoft's
|
||||
for lora in self.lora_A:
|
||||
torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5))
|
||||
for lora in self.lora_B:
|
||||
torch.nn.init.zeros_(lora)
|
||||
|
||||
self.multiplier = multiplier
|
||||
self.org_module = org_module # remove in applying
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
self.org_module.forward = self.forward
|
||||
del self.org_module
|
||||
|
||||
def forward(self, x):
|
||||
result = self.org_forward(x)
|
||||
|
||||
# specify the dynamic rank
|
||||
trainable_rank = random.randint(0, self.lora_dim - 1)
|
||||
trainable_rank = trainable_rank - trainable_rank % self.unit # make sure the rank is a multiple of unit
|
||||
|
||||
# 一部のパラメータを固定して、残りのパラメータを学習する
|
||||
for i in range(0, trainable_rank):
|
||||
self.lora_A[i].requires_grad = False
|
||||
self.lora_B[i].requires_grad = False
|
||||
for i in range(trainable_rank, trainable_rank + self.unit):
|
||||
self.lora_A[i].requires_grad = True
|
||||
self.lora_B[i].requires_grad = True
|
||||
for i in range(trainable_rank + self.unit, self.lora_dim):
|
||||
self.lora_A[i].requires_grad = False
|
||||
self.lora_B[i].requires_grad = False
|
||||
|
||||
lora_A = torch.cat(tuple(self.lora_A), dim=0)
|
||||
lora_B = torch.cat(tuple(self.lora_B), dim=1)
|
||||
|
||||
# calculate with lora_A and lora_B
|
||||
if self.is_conv2d_3x3:
|
||||
ab = torch.nn.functional.conv2d(x, lora_A, stride=self.stride, padding=self.padding)
|
||||
ab = torch.nn.functional.conv2d(ab, lora_B)
|
||||
else:
|
||||
ab = x
|
||||
if self.is_conv2d:
|
||||
ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C)
|
||||
|
||||
ab = torch.nn.functional.linear(ab, lora_A)
|
||||
ab = torch.nn.functional.linear(ab, lora_B)
|
||||
|
||||
if self.is_conv2d:
|
||||
ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:]) # (N, H*W, C) -> (N, C, H, W)
|
||||
|
||||
# 最後の項は、低rankをより大きくするためのスケーリング(じゃないかな)
|
||||
result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))
|
||||
|
||||
# NOTE weightに加算してからlinear/conv2dを呼んだほうが速いかも
|
||||
return result
|
||||
|
||||
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||
# state dictを通常のLoRAと同じにする:
|
||||
# nn.ParameterListは `.lora_A.0` みたいな名前になるので、forwardと同様にcatして入れ替える
|
||||
sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
|
||||
lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
|
||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||
lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
lora_B_weight = torch.cat(tuple(self.lora_B), dim=1)
|
||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||
lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach()
|
||||
sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()
|
||||
|
||||
i = 0
|
||||
while True:
|
||||
key_a = f"{self.lora_name}.lora_A.{i}"
|
||||
key_b = f"{self.lora_name}.lora_B.{i}"
|
||||
if key_a in sd:
|
||||
sd.pop(key_a)
|
||||
sd.pop(key_b)
|
||||
else:
|
||||
break
|
||||
i += 1
|
||||
return sd
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||
# 通常のLoRAと同じstate dictを読み込めるようにする:この方法はchatGPTに聞いた
|
||||
lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
|
||||
lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)
|
||||
|
||||
if lora_A_weight is None or lora_B_weight is None:
|
||||
if strict:
|
||||
raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
|
||||
else:
|
||||
return
|
||||
|
||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||
lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
|
||||
lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)
|
||||
|
||||
state_dict.update(
|
||||
{f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i].unsqueeze(0)) for i in range(lora_A_weight.size(0))}
|
||||
)
|
||||
state_dict.update(
|
||||
{f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i].unsqueeze(1)) for i in range(lora_B_weight.size(1))}
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
|
||||
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
||||
if network_dim is None:
|
||||
network_dim = 4 # default
|
||||
if network_alpha is None:
|
||||
network_alpha = 1.0
|
||||
|
||||
# extract dim/alpha for conv2d, and block dim
|
||||
conv_dim = kwargs.get("conv_dim", None)
|
||||
conv_alpha = kwargs.get("conv_alpha", None)
|
||||
unit = kwargs.get("unit", None)
|
||||
if conv_dim is not None:
|
||||
conv_dim = int(conv_dim)
|
||||
assert conv_dim == network_dim, "conv_dim must be same as network_dim"
|
||||
if conv_alpha is None:
|
||||
conv_alpha = 1.0
|
||||
else:
|
||||
conv_alpha = float(conv_alpha)
|
||||
if unit is not None:
|
||||
unit = int(unit)
|
||||
else:
|
||||
unit = 1
|
||||
|
||||
network = DyLoRANetwork(
|
||||
text_encoder,
|
||||
unet,
|
||||
multiplier=multiplier,
|
||||
lora_dim=network_dim,
|
||||
alpha=network_alpha,
|
||||
apply_to_conv=conv_dim is not None,
|
||||
unit=unit,
|
||||
varbose=True,
|
||||
)
|
||||
return network
|
||||
|
||||
|
||||
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
||||
if weights_sd is None:
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file, safe_open
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
# get dim/alpha mapping
|
||||
modules_dim = {}
|
||||
modules_alpha = {}
|
||||
for key, value in weights_sd.items():
|
||||
if "." not in key:
|
||||
continue
|
||||
|
||||
lora_name = key.split(".")[0]
|
||||
if "alpha" in key:
|
||||
modules_alpha[lora_name] = value
|
||||
elif "lora_down" in key:
|
||||
dim = value.size()[0]
|
||||
modules_dim[lora_name] = dim
|
||||
# print(lora_name, value.size(), dim)
|
||||
|
||||
# support old LoRA without alpha
|
||||
for key in modules_dim.keys():
|
||||
if key not in modules_alpha:
|
||||
modules_alpha = modules_dim[key]
|
||||
|
||||
module_class = DyLoRAModule
|
||||
|
||||
network = DyLoRANetwork(
|
||||
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
|
||||
)
|
||||
return network, weights_sd
|
||||
|
||||
|
||||
class DyLoRANetwork(torch.nn.Module):
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_UNET = "lora_unet"
|
||||
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder,
|
||||
unet,
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
apply_to_conv=False,
|
||||
modules_dim=None,
|
||||
modules_alpha=None,
|
||||
unit=1,
|
||||
module_class=DyLoRAModule,
|
||||
varbose=False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.multiplier = multiplier
|
||||
|
||||
self.lora_dim = lora_dim
|
||||
self.alpha = alpha
|
||||
self.apply_to_conv = apply_to_conv
|
||||
|
||||
if modules_dim is not None:
|
||||
print(f"create LoRA network from weights")
|
||||
else:
|
||||
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
|
||||
if self.apply_to_conv:
|
||||
print(f"apply LoRA to Conv2d with kernel size (3,3).")
|
||||
|
||||
# create module instances
|
||||
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
|
||||
prefix = DyLoRANetwork.LORA_PREFIX_UNET if is_unet else DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER
|
||||
loras = []
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
is_linear = child_module.__class__.__name__ == "Linear"
|
||||
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
||||
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||
|
||||
if is_linear or is_conv2d:
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
dim = None
|
||||
alpha = None
|
||||
if modules_dim is not None:
|
||||
if lora_name in modules_dim:
|
||||
dim = modules_dim[lora_name]
|
||||
alpha = modules_alpha[lora_name]
|
||||
else:
|
||||
if is_linear or is_conv2d_1x1 or apply_to_conv:
|
||||
dim = self.lora_dim
|
||||
alpha = self.alpha
|
||||
|
||||
if dim is None or dim == 0:
|
||||
continue
|
||||
|
||||
# dropout and fan_in_fan_out is default
|
||||
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
|
||||
loras.append(lora)
|
||||
return loras
|
||||
|
||||
self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
|
||||
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||
target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||
if modules_dim is not None or self.apply_to_conv:
|
||||
target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
self.unet_loras = create_modules(True, unet, target_modules)
|
||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
self.multiplier = multiplier
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.multiplier = self.multiplier
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
info = self.load_state_dict(weights_sd, False)
|
||||
return info
|
||||
|
||||
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
||||
if apply_text_encoder:
|
||||
print("enable LoRA for text encoder")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
print("enable LoRA for U-Net")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.apply_to()
|
||||
self.add_module(lora.lora_name, lora)
|
||||
|
||||
"""
|
||||
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
||||
apply_text_encoder = apply_unet = False
|
||||
for key in weights_sd.keys():
|
||||
if key.startswith(DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
||||
apply_text_encoder = True
|
||||
elif key.startswith(DyLoRANetwork.LORA_PREFIX_UNET):
|
||||
apply_unet = True
|
||||
|
||||
if apply_text_encoder:
|
||||
print("enable LoRA for text encoder")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
print("enable LoRA for U-Net")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
sd_for_lora = {}
|
||||
for key in weights_sd.keys():
|
||||
if key.startswith(lora.lora_name):
|
||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
||||
lora.merge_to(sd_for_lora, dtype, device)
|
||||
|
||||
print(f"weights are merged")
|
||||
"""
|
||||
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||
self.requires_grad_(True)
|
||||
all_params = []
|
||||
|
||||
def enumerate_params(loras):
|
||||
params = []
|
||||
for lora in loras:
|
||||
params.extend(lora.parameters())
|
||||
return params
|
||||
|
||||
if self.text_encoder_loras:
|
||||
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
||||
if text_encoder_lr is not None:
|
||||
param_data["lr"] = text_encoder_lr
|
||||
all_params.append(param_data)
|
||||
|
||||
if self.unet_loras:
|
||||
param_data = {"params": enumerate_params(self.unet_loras)}
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr
|
||||
all_params.append(param_data)
|
||||
|
||||
return all_params
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
# not supported
|
||||
pass
|
||||
|
||||
def prepare_grad_etc(self, text_encoder, unet):
|
||||
self.requires_grad_(True)
|
||||
|
||||
def on_epoch_start(self, text_encoder, unet):
|
||||
self.train()
|
||||
|
||||
def get_trainable_params(self):
|
||||
return self.parameters()
|
||||
|
||||
def save_weights(self, file, dtype, metadata):
|
||||
if metadata is not None and len(metadata) == 0:
|
||||
metadata = None
|
||||
|
||||
state_dict = self.state_dict()
|
||||
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
from library import train_util
|
||||
|
||||
# Precalculate model hashes to save time on indexing
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
save_file(state_dict, file, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file)
|
||||
|
||||
# mask is a tensor with values from 0 to 1
|
||||
def set_region(self, sub_prompt_index, is_last_network, mask):
|
||||
pass
|
||||
|
||||
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
|
||||
pass
|
||||
125
networks/extract_lora_from_dylora.py
Normal file
125
networks/extract_lora_from_dylora.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
|
||||
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
||||
# Thanks to cloneofsimo
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file, safe_open
|
||||
from tqdm import tqdm
|
||||
from library import train_util, model_util
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_state_dict(file_name):
|
||||
if model_util.is_safetensors(file_name):
|
||||
sd = load_file(file_name)
|
||||
with safe_open(file_name, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
else:
|
||||
sd = torch.load(file_name, map_location="cpu")
|
||||
metadata = None
|
||||
|
||||
return sd, metadata
|
||||
|
||||
|
||||
def save_to_file(file_name, model, metadata):
|
||||
if model_util.is_safetensors(file_name):
|
||||
save_file(model, file_name, metadata)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
def split_lora_model(lora_sd, unit):
|
||||
max_rank = 0
|
||||
|
||||
# Extract loaded lora dim and alpha
|
||||
for key, value in lora_sd.items():
|
||||
if "lora_down" in key:
|
||||
rank = value.size()[0]
|
||||
if rank > max_rank:
|
||||
max_rank = rank
|
||||
print(f"Max rank: {max_rank}")
|
||||
|
||||
rank = unit
|
||||
split_models = []
|
||||
new_alpha = None
|
||||
while rank < max_rank:
|
||||
print(f"Splitting rank {rank}")
|
||||
new_sd = {}
|
||||
for key, value in lora_sd.items():
|
||||
if "lora_down" in key:
|
||||
new_sd[key] = value[:rank].contiguous()
|
||||
elif "lora_up" in key:
|
||||
new_sd[key] = value[:, :rank].contiguous()
|
||||
else:
|
||||
# なぜかscaleするとおかしくなる……
|
||||
# this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
|
||||
# scale = math.sqrt(this_rank / rank) # rank is > unit
|
||||
# print(key, value.size(), this_rank, rank, value, scale)
|
||||
# new_alpha = value * scale # always same
|
||||
# new_sd[key] = new_alpha
|
||||
new_sd[key] = value
|
||||
|
||||
split_models.append((new_sd, rank, new_alpha))
|
||||
rank += unit
|
||||
|
||||
return max_rank, split_models
|
||||
|
||||
|
||||
def split(args):
|
||||
print("loading Model...")
|
||||
lora_sd, metadata = load_state_dict(args.model)
|
||||
|
||||
print("Splitting Model...")
|
||||
original_rank, split_models = split_lora_model(lora_sd, args.unit)
|
||||
|
||||
comment = metadata.get("ss_training_comment", "")
|
||||
for state_dict, new_rank, new_alpha in split_models:
|
||||
# update metadata
|
||||
if metadata is None:
|
||||
new_metadata = {}
|
||||
else:
|
||||
new_metadata = metadata.copy()
|
||||
|
||||
new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
|
||||
new_metadata["ss_network_dim"] = str(new_rank)
|
||||
# new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy())
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
filename, ext = os.path.splitext(args.save_to)
|
||||
model_file_name = filename + f"-{new_rank:04d}{ext}"
|
||||
|
||||
print(f"saving model to: {model_file_name}")
|
||||
save_to_file(model_file_name, state_dict, new_metadata)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--unit", type=int, default=None, help="size of rank to split into / rankを分割するサイズ")
|
||||
parser.add_argument(
|
||||
"--save_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
split(args)
|
||||
@@ -145,8 +145,8 @@ def svd(args):
|
||||
lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0])
|
||||
|
||||
# load state dict to LoRA and save it
|
||||
lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
|
||||
lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict
|
||||
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
|
||||
lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict
|
||||
|
||||
info = lora_network_save.load_state_dict(lora_sd)
|
||||
print(f"Loading extracted LoRA weights: {info}")
|
||||
|
||||
766
networks/lora.py
766
networks/lora.py
@@ -5,11 +5,13 @@
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import List
|
||||
from typing import List, Tuple, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
import re
|
||||
|
||||
from library import train_util
|
||||
|
||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||
|
||||
|
||||
class LoRAModule(torch.nn.Module):
|
||||
@@ -58,8 +60,6 @@ class LoRAModule(torch.nn.Module):
|
||||
|
||||
self.multiplier = multiplier
|
||||
self.org_module = org_module # remove in applying
|
||||
self.region = None
|
||||
self.region_mask = None
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
@@ -102,44 +102,194 @@ class LoRAModule(torch.nn.Module):
|
||||
self.region_mask = None
|
||||
|
||||
def forward(self, x):
|
||||
if self.region is None:
|
||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
|
||||
# regional LoRA FIXME same as additional-network extension
|
||||
if x.size()[1] % 77 == 0:
|
||||
# print(f"LoRA for context: {self.lora_name}")
|
||||
self.region = None
|
||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
|
||||
# calculate region mask first time
|
||||
if self.region_mask is None:
|
||||
if len(x.size()) == 4:
|
||||
h, w = x.size()[2:4]
|
||||
else:
|
||||
seq_len = x.size()[1]
|
||||
ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
|
||||
h = int(self.region.size()[0] / ratio + 0.5)
|
||||
w = seq_len // h
|
||||
class LoRAInfModule(LoRAModule):
|
||||
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
|
||||
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
|
||||
|
||||
r = self.region.to(x.device)
|
||||
if r.dtype == torch.bfloat16:
|
||||
r = r.to(torch.float)
|
||||
r = r.unsqueeze(0).unsqueeze(1)
|
||||
# print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
|
||||
r = torch.nn.functional.interpolate(r, (h, w), mode="bilinear")
|
||||
r = r.to(x.dtype)
|
||||
# check regional or not by lora_name
|
||||
self.text_encoder = False
|
||||
if lora_name.startswith("lora_te_"):
|
||||
self.regional = False
|
||||
self.use_sub_prompt = True
|
||||
self.text_encoder = True
|
||||
elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
|
||||
self.regional = False
|
||||
self.use_sub_prompt = True
|
||||
elif "time_emb" in lora_name:
|
||||
self.regional = False
|
||||
self.use_sub_prompt = False
|
||||
else:
|
||||
self.regional = True
|
||||
self.use_sub_prompt = False
|
||||
|
||||
if len(x.size()) == 3:
|
||||
r = torch.reshape(r, (1, x.size()[1], -1))
|
||||
self.network: LoRANetwork = None
|
||||
|
||||
self.region_mask = r
|
||||
def set_network(self, network):
|
||||
self.network = network
|
||||
|
||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
|
||||
def default_forward(self, x):
|
||||
# print("default_forward", self.lora_name, x.size())
|
||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
|
||||
def forward(self, x):
|
||||
if self.network is None or self.network.sub_prompt_index is None:
|
||||
return self.default_forward(x)
|
||||
if not self.regional and not self.use_sub_prompt:
|
||||
return self.default_forward(x)
|
||||
|
||||
if self.regional:
|
||||
return self.regional_forward(x)
|
||||
else:
|
||||
return self.sub_prompt_forward(x)
|
||||
|
||||
def get_mask_for_x(self, x):
|
||||
# calculate size from shape of x
|
||||
if len(x.size()) == 4:
|
||||
h, w = x.size()[2:4]
|
||||
area = h * w
|
||||
else:
|
||||
area = x.size()[1]
|
||||
|
||||
mask = self.network.mask_dic[area]
|
||||
if mask is None:
|
||||
raise ValueError(f"mask is None for resolution {area}")
|
||||
if len(x.size()) != 4:
|
||||
mask = torch.reshape(mask, (1, -1, 1))
|
||||
return mask
|
||||
|
||||
def regional_forward(self, x):
|
||||
if "attn2_to_out" in self.lora_name:
|
||||
return self.to_out_forward(x)
|
||||
|
||||
if self.network.mask_dic is None: # sub_prompt_index >= 3
|
||||
return self.default_forward(x)
|
||||
|
||||
# apply mask for LoRA result
|
||||
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
mask = self.get_mask_for_x(lx)
|
||||
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
|
||||
lx = lx * mask
|
||||
|
||||
x = self.org_forward(x)
|
||||
x = x + lx
|
||||
|
||||
if "attn2_to_q" in self.lora_name and self.network.is_last_network:
|
||||
x = self.postp_to_q(x)
|
||||
|
||||
return x
|
||||
|
||||
def postp_to_q(self, x):
|
||||
# repeat x to num_sub_prompts
|
||||
has_real_uncond = x.size()[0] // self.network.batch_size == 3
|
||||
qc = self.network.batch_size # uncond
|
||||
qc += self.network.batch_size * self.network.num_sub_prompts # cond
|
||||
if has_real_uncond:
|
||||
qc += self.network.batch_size # real_uncond
|
||||
|
||||
query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
|
||||
query[: self.network.batch_size] = x[: self.network.batch_size]
|
||||
|
||||
for i in range(self.network.batch_size):
|
||||
qi = self.network.batch_size + i * self.network.num_sub_prompts
|
||||
query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
|
||||
|
||||
if has_real_uncond:
|
||||
query[-self.network.batch_size :] = x[-self.network.batch_size :]
|
||||
|
||||
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
|
||||
return query
|
||||
|
||||
def sub_prompt_forward(self, x):
|
||||
if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA
|
||||
return self.org_forward(x)
|
||||
|
||||
emb_idx = self.network.sub_prompt_index
|
||||
if not self.text_encoder:
|
||||
emb_idx += self.network.batch_size
|
||||
|
||||
# apply sub prompt of X
|
||||
lx = x[emb_idx :: self.network.num_sub_prompts]
|
||||
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
|
||||
|
||||
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
|
||||
|
||||
x = self.org_forward(x)
|
||||
x[emb_idx :: self.network.num_sub_prompts] += lx
|
||||
|
||||
return x
|
||||
|
||||
def to_out_forward(self, x):
|
||||
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
|
||||
|
||||
if self.network.is_last_network:
|
||||
masks = [None] * self.network.num_sub_prompts
|
||||
self.network.shared[self.lora_name] = (None, masks)
|
||||
else:
|
||||
lx, masks = self.network.shared[self.lora_name]
|
||||
|
||||
# call own LoRA
|
||||
x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
|
||||
lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
|
||||
|
||||
if self.network.is_last_network:
|
||||
lx = torch.zeros(
|
||||
(self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
|
||||
)
|
||||
self.network.shared[self.lora_name] = (lx, masks)
|
||||
|
||||
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
|
||||
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
|
||||
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
|
||||
|
||||
# if not last network, return x and masks
|
||||
x = self.org_forward(x)
|
||||
if not self.network.is_last_network:
|
||||
return x
|
||||
|
||||
lx, masks = self.network.shared.pop(self.lora_name)
|
||||
|
||||
# if last network, combine separated x with mask weighted sum
|
||||
has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
|
||||
|
||||
out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
|
||||
out[: self.network.batch_size] = x[: self.network.batch_size] # uncond
|
||||
if has_real_uncond:
|
||||
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
|
||||
|
||||
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
|
||||
# for i in range(len(masks)):
|
||||
# if masks[i] is None:
|
||||
# masks[i] = torch.zeros_like(masks[-1])
|
||||
|
||||
mask = torch.cat(masks)
|
||||
mask_sum = torch.sum(mask, dim=0) + 1e-4
|
||||
for i in range(self.network.batch_size):
|
||||
# 1枚の画像ごとに処理する
|
||||
lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
|
||||
lx1 = lx1 * mask
|
||||
lx1 = torch.sum(lx1, dim=0)
|
||||
|
||||
xi = self.network.batch_size + i * self.network.num_sub_prompts
|
||||
x1 = x[xi : xi + self.network.num_sub_prompts]
|
||||
x1 = x1 * mask
|
||||
x1 = torch.sum(x1, dim=0)
|
||||
x1 = x1 / mask_sum
|
||||
|
||||
x1 = x1 + lx1
|
||||
out[self.network.batch_size + i] = x1
|
||||
|
||||
# print("to_out_forward", x.size(), out.size(), has_real_uncond)
|
||||
return out
|
||||
|
||||
|
||||
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
||||
if network_dim is None:
|
||||
network_dim = 4 # default
|
||||
if network_alpha is None:
|
||||
network_alpha = 1.0
|
||||
|
||||
# extract dim/alpha for conv2d, and block dim
|
||||
conv_dim = kwargs.get("conv_dim", None)
|
||||
@@ -151,34 +301,50 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
|
||||
else:
|
||||
conv_alpha = float(conv_alpha)
|
||||
|
||||
"""
|
||||
block_dims = kwargs.get("block_dims")
|
||||
block_alphas = None
|
||||
# block dim/alpha/lr
|
||||
block_dims = kwargs.get("block_dims", None)
|
||||
down_lr_weight = kwargs.get("down_lr_weight", None)
|
||||
mid_lr_weight = kwargs.get("mid_lr_weight", None)
|
||||
up_lr_weight = kwargs.get("up_lr_weight", None)
|
||||
|
||||
# 以上のいずれかに指定があればblockごとのdim(rank)を有効にする
|
||||
if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
|
||||
block_alphas = kwargs.get("block_alphas", None)
|
||||
conv_block_dims = kwargs.get("conv_block_dims", None)
|
||||
conv_block_alphas = kwargs.get("conv_block_alphas", None)
|
||||
|
||||
block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
|
||||
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
|
||||
)
|
||||
|
||||
# extract learning rate weight for each block
|
||||
if down_lr_weight is not None:
|
||||
# if some parameters are not set, use zero
|
||||
if "," in down_lr_weight:
|
||||
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
|
||||
|
||||
if mid_lr_weight is not None:
|
||||
mid_lr_weight = float(mid_lr_weight)
|
||||
|
||||
if up_lr_weight is not None:
|
||||
if "," in up_lr_weight:
|
||||
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
|
||||
|
||||
down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
|
||||
down_lr_weight, mid_lr_weight, up_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0))
|
||||
)
|
||||
|
||||
# remove block dim/alpha without learning rate
|
||||
block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
|
||||
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
|
||||
)
|
||||
|
||||
if block_dims is not None:
|
||||
block_dims = [int(d) for d in block_dims.split(',')]
|
||||
assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
|
||||
block_alphas = kwargs.get("block_alphas")
|
||||
if block_alphas is None:
|
||||
block_alphas = [1] * len(block_dims)
|
||||
else:
|
||||
block_alphas = [int(a) for a in block_alphas(',')]
|
||||
assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
|
||||
|
||||
conv_block_dims = kwargs.get("conv_block_dims")
|
||||
conv_block_alphas = None
|
||||
|
||||
if conv_block_dims is not None:
|
||||
conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
|
||||
assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
|
||||
conv_block_alphas = kwargs.get("conv_block_alphas")
|
||||
if conv_block_alphas is None:
|
||||
conv_block_alphas = [1] * len(conv_block_dims)
|
||||
else:
|
||||
conv_block_alphas = [int(a) for a in conv_block_alphas(',')]
|
||||
assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
|
||||
"""
|
||||
block_alphas = None
|
||||
conv_block_dims = None
|
||||
conv_block_alphas = None
|
||||
|
||||
# すごく引数が多いな ( ^ω^)・・・
|
||||
network = LoRANetwork(
|
||||
text_encoder,
|
||||
unet,
|
||||
@@ -187,11 +353,220 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
|
||||
alpha=network_alpha,
|
||||
conv_lora_dim=conv_dim,
|
||||
conv_alpha=conv_alpha,
|
||||
block_dims=block_dims,
|
||||
block_alphas=block_alphas,
|
||||
conv_block_dims=conv_block_dims,
|
||||
conv_block_alphas=conv_block_alphas,
|
||||
varbose=True,
|
||||
)
|
||||
|
||||
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
|
||||
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
|
||||
|
||||
return network
|
||||
|
||||
|
||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
|
||||
# このメソッドは外部から呼び出される可能性を考慮しておく
|
||||
# network_dim, network_alpha にはデフォルト値が入っている。
|
||||
# block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている
|
||||
# conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている
|
||||
def get_block_dims_and_alphas(
|
||||
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
|
||||
):
|
||||
num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
|
||||
|
||||
def parse_ints(s):
|
||||
return [int(i) for i in s.split(",")]
|
||||
|
||||
def parse_floats(s):
|
||||
return [float(i) for i in s.split(",")]
|
||||
|
||||
# block_dimsとblock_alphasをパースする。必ず値が入る
|
||||
if block_dims is not None:
|
||||
block_dims = parse_ints(block_dims)
|
||||
assert (
|
||||
len(block_dims) == num_total_blocks
|
||||
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
||||
else:
|
||||
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
|
||||
block_dims = [network_dim] * num_total_blocks
|
||||
|
||||
if block_alphas is not None:
|
||||
block_alphas = parse_floats(block_alphas)
|
||||
assert (
|
||||
len(block_alphas) == num_total_blocks
|
||||
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
|
||||
else:
|
||||
print(
|
||||
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
|
||||
)
|
||||
block_alphas = [network_alpha] * num_total_blocks
|
||||
|
||||
# conv_block_dimsとconv_block_alphasを、指定がある場合のみパースする。指定がなければconv_dimとconv_alphaを使う
|
||||
if conv_block_dims is not None:
|
||||
conv_block_dims = parse_ints(conv_block_dims)
|
||||
assert (
|
||||
len(conv_block_dims) == num_total_blocks
|
||||
), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"
|
||||
|
||||
if conv_block_alphas is not None:
|
||||
conv_block_alphas = parse_floats(conv_block_alphas)
|
||||
assert (
|
||||
len(conv_block_alphas) == num_total_blocks
|
||||
), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
|
||||
else:
|
||||
if conv_alpha is None:
|
||||
conv_alpha = 1.0
|
||||
print(
|
||||
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
|
||||
)
|
||||
conv_block_alphas = [conv_alpha] * num_total_blocks
|
||||
else:
|
||||
if conv_dim is not None:
|
||||
print(
|
||||
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
|
||||
)
|
||||
conv_block_dims = [conv_dim] * num_total_blocks
|
||||
conv_block_alphas = [conv_alpha] * num_total_blocks
|
||||
else:
|
||||
conv_block_dims = None
|
||||
conv_block_alphas = None
|
||||
|
||||
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
|
||||
|
||||
|
||||
# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出される可能性を考慮しておく
|
||||
def get_block_lr_weight(
|
||||
down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
|
||||
) -> Tuple[List[float], List[float], List[float]]:
|
||||
# パラメータ未指定時は何もせず、今までと同じ動作とする
|
||||
if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
|
||||
return None, None, None
|
||||
|
||||
max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数
|
||||
|
||||
def get_list(name_with_suffix) -> List[float]:
|
||||
import math
|
||||
|
||||
tokens = name_with_suffix.split("+")
|
||||
name = tokens[0]
|
||||
base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
|
||||
|
||||
if name == "cosine":
|
||||
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
|
||||
elif name == "sine":
|
||||
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
|
||||
elif name == "linear":
|
||||
return [i / (max_len - 1) + base_lr for i in range(max_len)]
|
||||
elif name == "reverse_linear":
|
||||
return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
|
||||
elif name == "zeros":
|
||||
return [0.0 + base_lr] * max_len
|
||||
else:
|
||||
print(
|
||||
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
|
||||
% (name)
|
||||
)
|
||||
return None
|
||||
|
||||
if type(down_lr_weight) == str:
|
||||
down_lr_weight = get_list(down_lr_weight)
|
||||
if type(up_lr_weight) == str:
|
||||
up_lr_weight = get_list(up_lr_weight)
|
||||
|
||||
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
|
||||
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
|
||||
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
|
||||
up_lr_weight = up_lr_weight[:max_len]
|
||||
down_lr_weight = down_lr_weight[:max_len]
|
||||
|
||||
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
|
||||
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
|
||||
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
|
||||
|
||||
if down_lr_weight != None and len(down_lr_weight) < max_len:
|
||||
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
|
||||
if up_lr_weight != None and len(up_lr_weight) < max_len:
|
||||
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
|
||||
|
||||
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
|
||||
print("apply block learning rate / 階層別学習率を適用します。")
|
||||
if down_lr_weight != None:
|
||||
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
|
||||
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
|
||||
else:
|
||||
print("down_lr_weight: all 1.0, すべて1.0")
|
||||
|
||||
if mid_lr_weight != None:
|
||||
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
|
||||
print("mid_lr_weight:", mid_lr_weight)
|
||||
else:
|
||||
print("mid_lr_weight: 1.0")
|
||||
|
||||
if up_lr_weight != None:
|
||||
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
|
||||
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
|
||||
else:
|
||||
print("up_lr_weight: all 1.0, すべて1.0")
|
||||
|
||||
return down_lr_weight, mid_lr_weight, up_lr_weight
|
||||
|
||||
|
||||
# lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく
|
||||
def remove_block_dims_and_alphas(
|
||||
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
|
||||
):
|
||||
# set 0 to block dim without learning rate to remove the block
|
||||
if down_lr_weight != None:
|
||||
for i, lr in enumerate(down_lr_weight):
|
||||
if lr == 0:
|
||||
block_dims[i] = 0
|
||||
if conv_block_dims is not None:
|
||||
conv_block_dims[i] = 0
|
||||
if mid_lr_weight != None:
|
||||
if mid_lr_weight == 0:
|
||||
block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
|
||||
if conv_block_dims is not None:
|
||||
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
|
||||
if up_lr_weight != None:
|
||||
for i, lr in enumerate(up_lr_weight):
|
||||
if lr == 0:
|
||||
block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
|
||||
if conv_block_dims is not None:
|
||||
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
|
||||
|
||||
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
|
||||
|
||||
|
||||
# 外部から呼び出す可能性を考慮しておく
|
||||
def get_block_index(lora_name: str) -> int:
|
||||
block_idx = -1 # invalid lora name
|
||||
|
||||
m = RE_UPDOWN.search(lora_name)
|
||||
if m:
|
||||
g = m.groups()
|
||||
i = int(g[1])
|
||||
j = int(g[3])
|
||||
if g[2] == "resnets":
|
||||
idx = 3 * i + j
|
||||
elif g[2] == "attentions":
|
||||
idx = 3 * i + j
|
||||
elif g[2] == "upsamplers" or g[2] == "downsamplers":
|
||||
idx = 3 * i + 2
|
||||
|
||||
if g[0] == "down":
|
||||
block_idx = 1 + idx # 0に該当するLoRAは存在しない
|
||||
elif g[0] == "up":
|
||||
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
|
||||
|
||||
elif "mid_block_" in lora_name:
|
||||
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
|
||||
|
||||
return block_idx
|
||||
|
||||
|
||||
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
||||
if weights_sd is None:
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file, safe_open
|
||||
@@ -220,13 +595,18 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
||||
if key not in modules_alpha:
|
||||
modules_alpha = modules_dim[key]
|
||||
|
||||
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
||||
network.weights_sd = weights_sd
|
||||
return network
|
||||
module_class = LoRAInfModule if for_inference else LoRAModule
|
||||
|
||||
network = LoRANetwork(
|
||||
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
|
||||
)
|
||||
return network, weights_sd
|
||||
|
||||
|
||||
class LoRANetwork(torch.nn.Module):
|
||||
# is it possible to apply conv_in and conv_out?
|
||||
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
|
||||
|
||||
# is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
@@ -242,9 +622,23 @@ class LoRANetwork(torch.nn.Module):
|
||||
alpha=1,
|
||||
conv_lora_dim=None,
|
||||
conv_alpha=None,
|
||||
block_dims=None,
|
||||
block_alphas=None,
|
||||
conv_block_dims=None,
|
||||
conv_block_alphas=None,
|
||||
modules_dim=None,
|
||||
modules_alpha=None,
|
||||
module_class=LoRAModule,
|
||||
varbose=False,
|
||||
) -> None:
|
||||
"""
|
||||
LoRA network: すごく引数が多いが、パターンは以下の通り
|
||||
1. lora_dimとalphaを指定
|
||||
2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定
|
||||
3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない
|
||||
4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する
|
||||
5. modules_dimとmodules_alphaを指定 (推論用)
|
||||
"""
|
||||
super().__init__()
|
||||
self.multiplier = multiplier
|
||||
|
||||
@@ -255,62 +649,88 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
if modules_dim is not None:
|
||||
print(f"create LoRA network from weights")
|
||||
elif block_dims is not None:
|
||||
print(f"create LoRA network from block_dims")
|
||||
print(f"block_dims: {block_dims}")
|
||||
print(f"block_alphas: {block_alphas}")
|
||||
if conv_block_dims is not None:
|
||||
print(f"conv_block_dims: {conv_block_dims}")
|
||||
print(f"conv_block_alphas: {conv_block_alphas}")
|
||||
else:
|
||||
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||
|
||||
self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
|
||||
if self.apply_to_conv2d_3x3:
|
||||
if self.conv_alpha is None:
|
||||
self.conv_alpha = self.alpha
|
||||
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
||||
if self.conv_lora_dim is not None:
|
||||
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
||||
|
||||
# create module instances
|
||||
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
|
||||
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
|
||||
prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER
|
||||
loras = []
|
||||
skipped = []
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
# TODO get block index here
|
||||
for child_name, child_module in module.named_modules():
|
||||
is_linear = child_module.__class__.__name__ == "Linear"
|
||||
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
||||
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||
|
||||
if is_linear or is_conv2d:
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
dim = None
|
||||
alpha = None
|
||||
if modules_dim is not None:
|
||||
if lora_name not in modules_dim:
|
||||
continue # no LoRA module in this weights file
|
||||
dim = modules_dim[lora_name]
|
||||
alpha = modules_alpha[lora_name]
|
||||
if lora_name in modules_dim:
|
||||
dim = modules_dim[lora_name]
|
||||
alpha = modules_alpha[lora_name]
|
||||
elif is_unet and block_dims is not None:
|
||||
block_idx = get_block_index(lora_name)
|
||||
if is_linear or is_conv2d_1x1:
|
||||
dim = block_dims[block_idx]
|
||||
alpha = block_alphas[block_idx]
|
||||
elif conv_block_dims is not None:
|
||||
dim = conv_block_dims[block_idx]
|
||||
alpha = conv_block_alphas[block_idx]
|
||||
else:
|
||||
if is_linear or is_conv2d_1x1:
|
||||
dim = self.lora_dim
|
||||
alpha = self.alpha
|
||||
elif self.apply_to_conv2d_3x3:
|
||||
elif self.conv_lora_dim is not None:
|
||||
dim = self.conv_lora_dim
|
||||
alpha = self.conv_alpha
|
||||
else:
|
||||
continue
|
||||
|
||||
lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
|
||||
if dim is None or dim == 0:
|
||||
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
|
||||
skipped.append(lora_name)
|
||||
continue
|
||||
|
||||
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha)
|
||||
loras.append(lora)
|
||||
return loras
|
||||
return loras, skipped
|
||||
|
||||
self.text_encoder_loras = create_modules(
|
||||
LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
||||
)
|
||||
self.text_encoder_loras, skipped_te = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
|
||||
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||
if modules_dim is not None or self.conv_lora_dim is not None:
|
||||
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
|
||||
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
|
||||
self.unet_loras, skipped_un = create_modules(True, unet, target_modules)
|
||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
|
||||
self.weights_sd = None
|
||||
skipped = skipped_te + skipped_un
|
||||
if varbose and len(skipped) > 0:
|
||||
print(
|
||||
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
|
||||
)
|
||||
for name in skipped:
|
||||
print(f"\t{name}")
|
||||
|
||||
self.up_lr_weight: List[float] = None
|
||||
self.down_lr_weight: List[float] = None
|
||||
self.mid_lr_weight: float = None
|
||||
self.block_lr = False
|
||||
|
||||
# assertion
|
||||
names = set()
|
||||
@@ -325,37 +745,16 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file, safe_open
|
||||
from safetensors.torch import load_file
|
||||
|
||||
self.weights_sd = load_file(file)
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
self.weights_sd = torch.load(file, map_location="cpu")
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
|
||||
if self.weights_sd:
|
||||
weights_has_text_encoder = weights_has_unet = False
|
||||
for key in self.weights_sd.keys():
|
||||
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
||||
weights_has_text_encoder = True
|
||||
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
||||
weights_has_unet = True
|
||||
|
||||
if apply_text_encoder is None:
|
||||
apply_text_encoder = weights_has_text_encoder
|
||||
else:
|
||||
assert (
|
||||
apply_text_encoder == weights_has_text_encoder
|
||||
), f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
|
||||
|
||||
if apply_unet is None:
|
||||
apply_unet = weights_has_unet
|
||||
else:
|
||||
assert (
|
||||
apply_unet == weights_has_unet
|
||||
), f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
|
||||
else:
|
||||
assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
|
||||
info = self.load_state_dict(weights_sd, False)
|
||||
return info
|
||||
|
||||
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
||||
if apply_text_encoder:
|
||||
print("enable LoRA for text encoder")
|
||||
else:
|
||||
@@ -370,17 +769,10 @@ class LoRANetwork(torch.nn.Module):
|
||||
lora.apply_to()
|
||||
self.add_module(lora.lora_name, lora)
|
||||
|
||||
if self.weights_sd:
|
||||
# if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
|
||||
info = self.load_state_dict(self.weights_sd, False)
|
||||
print(f"weights are loaded: {info}")
|
||||
|
||||
# TODO refactor to common function with apply_to
|
||||
def merge_to(self, text_encoder, unet, dtype, device):
|
||||
assert self.weights_sd is not None, "weights are not loaded"
|
||||
|
||||
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
||||
apply_text_encoder = apply_unet = False
|
||||
for key in self.weights_sd.keys():
|
||||
for key in weights_sd.keys():
|
||||
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
||||
apply_text_encoder = True
|
||||
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
||||
@@ -398,26 +790,53 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
sd_for_lora = {}
|
||||
for key in self.weights_sd.keys():
|
||||
for key in weights_sd.keys():
|
||||
if key.startswith(lora.lora_name):
|
||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key]
|
||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
||||
lora.merge_to(sd_for_lora, dtype, device)
|
||||
|
||||
print(f"weights are merged")
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
# not supported
|
||||
pass
|
||||
# 層別学習率用に層ごとの学習率に対する倍率を定義する
|
||||
def set_block_lr_weight(
|
||||
self,
|
||||
up_lr_weight: List[float] = None,
|
||||
mid_lr_weight: float = None,
|
||||
down_lr_weight: List[float] = None,
|
||||
):
|
||||
self.block_lr = True
|
||||
self.down_lr_weight = down_lr_weight
|
||||
self.mid_lr_weight = mid_lr_weight
|
||||
self.up_lr_weight = up_lr_weight
|
||||
|
||||
def get_lr_weight(self, lora: LoRAModule) -> float:
|
||||
lr_weight = 1.0
|
||||
block_idx = get_block_index(lora.lora_name)
|
||||
if block_idx < 0:
|
||||
return lr_weight
|
||||
|
||||
if block_idx < LoRANetwork.NUM_OF_BLOCKS:
|
||||
if self.down_lr_weight != None:
|
||||
lr_weight = self.down_lr_weight[block_idx]
|
||||
elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
|
||||
if self.mid_lr_weight != None:
|
||||
lr_weight = self.mid_lr_weight
|
||||
elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
|
||||
if self.up_lr_weight != None:
|
||||
lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
|
||||
|
||||
return lr_weight
|
||||
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||
self.requires_grad_(True)
|
||||
all_params = []
|
||||
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
|
||||
def enumerate_params(loras):
|
||||
params = []
|
||||
for lora in loras:
|
||||
params.extend(lora.parameters())
|
||||
return params
|
||||
|
||||
self.requires_grad_(True)
|
||||
all_params = []
|
||||
|
||||
if self.text_encoder_loras:
|
||||
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
||||
if text_encoder_lr is not None:
|
||||
@@ -425,13 +844,39 @@ class LoRANetwork(torch.nn.Module):
|
||||
all_params.append(param_data)
|
||||
|
||||
if self.unet_loras:
|
||||
param_data = {"params": enumerate_params(self.unet_loras)}
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr
|
||||
all_params.append(param_data)
|
||||
if self.block_lr:
|
||||
# 学習率のグラフをblockごとにしたいので、blockごとにloraを分類
|
||||
block_idx_to_lora = {}
|
||||
for lora in self.unet_loras:
|
||||
idx = get_block_index(lora.lora_name)
|
||||
if idx not in block_idx_to_lora:
|
||||
block_idx_to_lora[idx] = []
|
||||
block_idx_to_lora[idx].append(lora)
|
||||
|
||||
# blockごとにパラメータを設定する
|
||||
for idx, block_loras in block_idx_to_lora.items():
|
||||
param_data = {"params": enumerate_params(block_loras)}
|
||||
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
|
||||
elif default_lr is not None:
|
||||
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
|
||||
if ("lr" in param_data) and (param_data["lr"] == 0):
|
||||
continue
|
||||
all_params.append(param_data)
|
||||
|
||||
else:
|
||||
param_data = {"params": enumerate_params(self.unet_loras)}
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr
|
||||
all_params.append(param_data)
|
||||
|
||||
return all_params
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
# not supported
|
||||
pass
|
||||
|
||||
def prepare_grad_etc(self, text_encoder, unet):
|
||||
self.requires_grad_(True)
|
||||
|
||||
@@ -455,6 +900,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
from library import train_util
|
||||
|
||||
# Precalculate model hashes to save time on indexing
|
||||
if metadata is None:
|
||||
@@ -467,17 +913,45 @@ class LoRANetwork(torch.nn.Module):
|
||||
else:
|
||||
torch.save(state_dict, file)
|
||||
|
||||
@staticmethod
|
||||
def set_regions(networks, image):
|
||||
image = image.astype(np.float32) / 255.0
|
||||
for i, network in enumerate(networks[:3]):
|
||||
# NOTE: consider averaging overwrapping area
|
||||
region = image[:, :, i]
|
||||
if region.max() == 0:
|
||||
continue
|
||||
region = torch.tensor(region)
|
||||
network.set_region(region)
|
||||
# mask is a tensor with values from 0 to 1
|
||||
def set_region(self, sub_prompt_index, is_last_network, mask):
|
||||
if mask.max() == 0:
|
||||
mask = torch.ones_like(mask)
|
||||
|
||||
def set_region(self, region):
|
||||
for lora in self.unet_loras:
|
||||
lora.set_region(region)
|
||||
self.mask = mask
|
||||
self.sub_prompt_index = sub_prompt_index
|
||||
self.is_last_network = is_last_network
|
||||
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.set_network(self)
|
||||
|
||||
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
|
||||
self.batch_size = batch_size
|
||||
self.num_sub_prompts = num_sub_prompts
|
||||
self.current_size = (height, width)
|
||||
self.shared = shared
|
||||
|
||||
# create masks
|
||||
mask = self.mask
|
||||
mask_dic = {}
|
||||
mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w
|
||||
ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
|
||||
dtype = ref_weight.dtype
|
||||
device = ref_weight.device
|
||||
|
||||
def resize_add(mh, mw):
|
||||
# print(mh, mw, mh * mw)
|
||||
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
|
||||
m = m.to(device, dtype=dtype)
|
||||
mask_dic[mh * mw] = m
|
||||
|
||||
h = height // 8
|
||||
w = width // 8
|
||||
for _ in range(4):
|
||||
resize_add(h, w)
|
||||
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
|
||||
resize_add(h + h % 2, w + w % 2)
|
||||
h = (h + 1) // 2
|
||||
w = (w + 1) // 2
|
||||
|
||||
self.mask_dic = mask_dic
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from tqdm import tqdm
|
||||
from library import model_util
|
||||
import library.train_util as train_util
|
||||
import argparse
|
||||
from transformers import CLIPTokenizer
|
||||
import torch
|
||||
@@ -16,16 +17,20 @@ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
|
||||
def interrogate(args):
|
||||
weights_dtype = torch.float16
|
||||
|
||||
# いろいろ準備する
|
||||
print(f"loading SD model: {args.sd_model}")
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
||||
args.pretrained_model_name_or_path = args.sd_model
|
||||
args.vae = None
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args,weights_dtype, DEVICE)
|
||||
|
||||
print(f"loading LoRA: {args.model}")
|
||||
network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
||||
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
||||
|
||||
# text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい
|
||||
has_te_weight = False
|
||||
for key in network.weights_sd.keys():
|
||||
for key in weights_sd.keys():
|
||||
if 'lora_te' in key:
|
||||
has_te_weight = True
|
||||
break
|
||||
@@ -40,9 +45,9 @@ def interrogate(args):
|
||||
else:
|
||||
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
|
||||
|
||||
text_encoder.to(DEVICE)
|
||||
text_encoder.to(DEVICE, dtype=weights_dtype)
|
||||
text_encoder.eval()
|
||||
unet.to(DEVICE)
|
||||
unet.to(DEVICE, dtype=weights_dtype)
|
||||
unet.eval() # U-Netは呼び出さないので不要だけど
|
||||
|
||||
# トークンをひとつひとつ当たっていく
|
||||
@@ -78,9 +83,14 @@ def interrogate(args):
|
||||
orig_embs = get_all_embeddings(text_encoder)
|
||||
|
||||
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
|
||||
network.to(DEVICE)
|
||||
info = network.load_state_dict(weights_sd, strict=False)
|
||||
print(f"Loading LoRA weights: {info}")
|
||||
|
||||
network.to(DEVICE, dtype=weights_dtype)
|
||||
network.eval()
|
||||
|
||||
del unet
|
||||
|
||||
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
|
||||
print("get text encoder embeddings with lora.")
|
||||
lora_embs = get_all_embeddings(text_encoder)
|
||||
@@ -107,6 +117,7 @@ def interrogate(args):
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
||||
parser.add_argument("--sd_model", type=str, default=None,
|
||||
|
||||
@@ -21,6 +21,6 @@ fairscale==0.4.13
|
||||
# for WD14 captioning
|
||||
# tensorflow<2.11
|
||||
tensorflow==2.10.1
|
||||
huggingface-hub==0.12.0
|
||||
huggingface-hub==0.13.3
|
||||
# for kohya_ss library
|
||||
.
|
||||
|
||||
@@ -9,86 +9,122 @@ import library.model_util as model_util
|
||||
|
||||
|
||||
def convert(args):
|
||||
# 引数を確認する
|
||||
load_dtype = torch.float16 if args.fp16 else None
|
||||
# 引数を確認する
|
||||
load_dtype = torch.float16 if args.fp16 else None
|
||||
|
||||
save_dtype = None
|
||||
if args.fp16:
|
||||
save_dtype = torch.float16
|
||||
elif args.bf16:
|
||||
save_dtype = torch.bfloat16
|
||||
elif args.float:
|
||||
save_dtype = torch.float
|
||||
save_dtype = None
|
||||
if args.fp16 or args.save_precision_as == "fp16":
|
||||
save_dtype = torch.float16
|
||||
elif args.bf16 or args.save_precision_as == "bf16":
|
||||
save_dtype = torch.bfloat16
|
||||
elif args.float or args.save_precision_as == "float":
|
||||
save_dtype = torch.float
|
||||
|
||||
is_load_ckpt = os.path.isfile(args.model_to_load)
|
||||
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
|
||||
is_load_ckpt = os.path.isfile(args.model_to_load)
|
||||
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
|
||||
|
||||
assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
|
||||
assert is_save_ckpt or args.reference_model is not None, f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
|
||||
assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
|
||||
assert (
|
||||
is_save_ckpt or args.reference_model is not None
|
||||
), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
|
||||
|
||||
# モデルを読み込む
|
||||
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
|
||||
print(f"loading {msg}: {args.model_to_load}")
|
||||
# モデルを読み込む
|
||||
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
|
||||
print(f"loading {msg}: {args.model_to_load}")
|
||||
|
||||
if is_load_ckpt:
|
||||
v2_model = args.v2
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
|
||||
else:
|
||||
pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None)
|
||||
text_encoder = pipe.text_encoder
|
||||
vae = pipe.vae
|
||||
unet = pipe.unet
|
||||
|
||||
if args.v1 == args.v2:
|
||||
# 自動判定する
|
||||
v2_model = unet.config.cross_attention_dim == 1024
|
||||
print("checking model version: model is " + ('v2' if v2_model else 'v1'))
|
||||
if is_load_ckpt:
|
||||
v2_model = args.v2
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
|
||||
else:
|
||||
v2_model = not args.v1
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None
|
||||
)
|
||||
text_encoder = pipe.text_encoder
|
||||
vae = pipe.vae
|
||||
unet = pipe.unet
|
||||
|
||||
# 変換して保存する
|
||||
msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
|
||||
print(f"converting and saving as {msg}: {args.model_to_save}")
|
||||
if args.v1 == args.v2:
|
||||
# 自動判定する
|
||||
v2_model = unet.config.cross_attention_dim == 1024
|
||||
print("checking model version: model is " + ("v2" if v2_model else "v1"))
|
||||
else:
|
||||
v2_model = not args.v1
|
||||
|
||||
if is_save_ckpt:
|
||||
original_model = args.model_to_load if is_load_ckpt else None
|
||||
key_count = model_util.save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet,
|
||||
original_model, args.epoch, args.global_step, save_dtype, vae)
|
||||
print(f"model saved. total converted state_dict keys: {key_count}")
|
||||
else:
|
||||
print(f"copy scheduler/tokenizer config from: {args.reference_model}")
|
||||
model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors)
|
||||
print(f"model saved.")
|
||||
# 変換して保存する
|
||||
msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
|
||||
print(f"converting and saving as {msg}: {args.model_to_save}")
|
||||
|
||||
if is_save_ckpt:
|
||||
original_model = args.model_to_load if is_load_ckpt else None
|
||||
key_count = model_util.save_stable_diffusion_checkpoint(
|
||||
v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, save_dtype, vae
|
||||
)
|
||||
print(f"model saved. total converted state_dict keys: {key_count}")
|
||||
else:
|
||||
print(f"copy scheduler/tokenizer config from: {args.reference_model}")
|
||||
model_util.save_diffusers_checkpoint(
|
||||
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
|
||||
)
|
||||
print(f"model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v1", action='store_true',
|
||||
help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む')
|
||||
parser.add_argument("--fp16", action='store_true',
|
||||
help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)')
|
||||
parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)')
|
||||
parser.add_argument("--float", action='store_true',
|
||||
help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)')
|
||||
parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値')
|
||||
parser.add_argument("--global_step", type=int, default=0,
|
||||
help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
|
||||
parser.add_argument("--reference_model", type=str, default=None,
|
||||
help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")
|
||||
parser.add_argument("--use_safetensors", action='store_true',
|
||||
help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)")
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--v1", action="store_true", help="load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)",
|
||||
)
|
||||
parser.add_argument("--bf16", action="store_true", help="save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)")
|
||||
parser.add_argument(
|
||||
"--float", action="store_true", help="save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_precision_as",
|
||||
type=str,
|
||||
default="no",
|
||||
choices=["fp16", "bf16", "float"],
|
||||
help="save precision, do not specify with --fp16/--bf16/--float / 保存する精度、--fp16/--bf16/--floatと併用しないでください",
|
||||
)
|
||||
parser.add_argument("--epoch", type=int, default=0, help="epoch to write to checkpoint / checkpointに記録するepoch数の値")
|
||||
parser.add_argument(
|
||||
"--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reference_model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_safetensors",
|
||||
action="store_true",
|
||||
help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)",
|
||||
)
|
||||
|
||||
parser.add_argument("model_to_load", type=str, default=None,
|
||||
help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
|
||||
parser.add_argument("model_to_save", type=str, default=None,
|
||||
help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存")
|
||||
return parser
|
||||
parser.add_argument(
|
||||
"model_to_load",
|
||||
type=str,
|
||||
default=None,
|
||||
help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"model_to_save",
|
||||
type=str,
|
||||
default=None,
|
||||
help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
convert(args)
|
||||
args = parser.parse_args()
|
||||
convert(args)
|
||||
|
||||
348
tools/latent_upscaler.py
Normal file
348
tools/latent_upscaler.py
Normal file
@@ -0,0 +1,348 @@
|
||||
# 外部から簡単にupscalerを呼ぶためのスクリプト
|
||||
# 単体で動くようにモデル定義も含めている
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import cv2
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
from typing import Dict, List
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
if out_channels is None:
|
||||
out_channels = in_channels
|
||||
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(out_channels)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_channels)
|
||||
|
||||
self.relu2 = nn.ReLU(inplace=True) # このReLUはresidualに足す前にかけるほうがいいかも
|
||||
|
||||
# initialize weights
|
||||
self._initialize_weights()
|
||||
|
||||
def _initialize_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu1(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
out += residual
|
||||
|
||||
out = self.relu2(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Upscaler(nn.Module):
|
||||
def __init__(self):
|
||||
super(Upscaler, self).__init__()
|
||||
|
||||
# define layers
|
||||
# latent has 4 channels
|
||||
|
||||
self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(128)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
|
||||
# resblocks
|
||||
# 数の暴力で20個:次元数を増やすよりもブロックを増やしたほうがreceptive fieldが広がるはずだぞ
|
||||
self.resblock1 = ResidualBlock(128)
|
||||
self.resblock2 = ResidualBlock(128)
|
||||
self.resblock3 = ResidualBlock(128)
|
||||
self.resblock4 = ResidualBlock(128)
|
||||
self.resblock5 = ResidualBlock(128)
|
||||
self.resblock6 = ResidualBlock(128)
|
||||
self.resblock7 = ResidualBlock(128)
|
||||
self.resblock8 = ResidualBlock(128)
|
||||
self.resblock9 = ResidualBlock(128)
|
||||
self.resblock10 = ResidualBlock(128)
|
||||
self.resblock11 = ResidualBlock(128)
|
||||
self.resblock12 = ResidualBlock(128)
|
||||
self.resblock13 = ResidualBlock(128)
|
||||
self.resblock14 = ResidualBlock(128)
|
||||
self.resblock15 = ResidualBlock(128)
|
||||
self.resblock16 = ResidualBlock(128)
|
||||
self.resblock17 = ResidualBlock(128)
|
||||
self.resblock18 = ResidualBlock(128)
|
||||
self.resblock19 = ResidualBlock(128)
|
||||
self.resblock20 = ResidualBlock(128)
|
||||
|
||||
# last convs
|
||||
self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(64)
|
||||
self.relu2 = nn.ReLU(inplace=True)
|
||||
|
||||
self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(64)
|
||||
self.relu3 = nn.ReLU(inplace=True)
|
||||
|
||||
# final conv: output 4 channels
|
||||
self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
|
||||
|
||||
# initialize weights
|
||||
self._initialize_weights()
|
||||
|
||||
def _initialize_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# initialize final conv weights to 0: 流行りのzero conv
|
||||
nn.init.constant_(self.conv_final.weight, 0)
|
||||
|
||||
def forward(self, x):
|
||||
inp = x
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu1(x)
|
||||
|
||||
# いくつかのresblockを通した後に、residualを足すことで精度向上と学習速度向上が見込めるはず
|
||||
residual = x
|
||||
x = self.resblock1(x)
|
||||
x = self.resblock2(x)
|
||||
x = self.resblock3(x)
|
||||
x = self.resblock4(x)
|
||||
x = x + residual
|
||||
residual = x
|
||||
x = self.resblock5(x)
|
||||
x = self.resblock6(x)
|
||||
x = self.resblock7(x)
|
||||
x = self.resblock8(x)
|
||||
x = x + residual
|
||||
residual = x
|
||||
x = self.resblock9(x)
|
||||
x = self.resblock10(x)
|
||||
x = self.resblock11(x)
|
||||
x = self.resblock12(x)
|
||||
x = x + residual
|
||||
residual = x
|
||||
x = self.resblock13(x)
|
||||
x = self.resblock14(x)
|
||||
x = self.resblock15(x)
|
||||
x = self.resblock16(x)
|
||||
x = x + residual
|
||||
residual = x
|
||||
x = self.resblock17(x)
|
||||
x = self.resblock18(x)
|
||||
x = self.resblock19(x)
|
||||
x = self.resblock20(x)
|
||||
x = x + residual
|
||||
|
||||
x = self.conv2(x)
|
||||
x = self.bn2(x)
|
||||
x = self.relu2(x)
|
||||
x = self.conv3(x)
|
||||
x = self.bn3(x)
|
||||
|
||||
# ここにreluを入れないほうがいい気がする
|
||||
|
||||
x = self.conv_final(x)
|
||||
|
||||
# network estimates the difference between the input and the output
|
||||
x = x + inp
|
||||
|
||||
return x
|
||||
|
||||
def support_latents(self) -> bool:
|
||||
return False
|
||||
|
||||
def upscale(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
lowreso_images: List[Image.Image],
|
||||
lowreso_latents: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
width: int,
|
||||
height: int,
|
||||
batch_size: int = 1,
|
||||
vae_batch_size: int = 1,
|
||||
):
|
||||
# assertion
|
||||
assert lowreso_images is not None, "Upscaler requires lowreso image"
|
||||
|
||||
# make upsampled image with lanczos4
|
||||
upsampled_images = []
|
||||
for lowreso_image in lowreso_images:
|
||||
upsampled_image = np.array(lowreso_image.resize((width, height), Image.LANCZOS))
|
||||
upsampled_images.append(upsampled_image)
|
||||
|
||||
# convert to tensor: this tensor is too large to be converted to cuda
|
||||
upsampled_images = [torch.from_numpy(upsampled_image).permute(2, 0, 1).float() for upsampled_image in upsampled_images]
|
||||
upsampled_images = torch.stack(upsampled_images, dim=0)
|
||||
upsampled_images = upsampled_images.to(dtype)
|
||||
|
||||
# normalize to [-1, 1]
|
||||
upsampled_images = upsampled_images / 127.5 - 1.0
|
||||
|
||||
# convert upsample images to latents with batch size
|
||||
# print("Encoding upsampled (LANCZOS4) images...")
|
||||
upsampled_latents = []
|
||||
for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
|
||||
batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
|
||||
with torch.no_grad():
|
||||
batch = vae.encode(batch).latent_dist.sample()
|
||||
upsampled_latents.append(batch)
|
||||
|
||||
upsampled_latents = torch.cat(upsampled_latents, dim=0)
|
||||
|
||||
# upscale (refine) latents with this model with batch size
|
||||
print("Upscaling latents...")
|
||||
upscaled_latents = []
|
||||
for i in range(0, upsampled_latents.shape[0], batch_size):
|
||||
with torch.no_grad():
|
||||
upscaled_latents.append(self.forward(upsampled_latents[i : i + batch_size]))
|
||||
upscaled_latents = torch.cat(upscaled_latents, dim=0)
|
||||
|
||||
return upscaled_latents * 0.18215
|
||||
|
||||
|
||||
# external interface: returns a model
|
||||
def create_upscaler(**kwargs):
|
||||
weights = kwargs["weights"]
|
||||
model = Upscaler()
|
||||
|
||||
print(f"Loading weights from {weights}...")
|
||||
if os.path.splitext(weights)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
sd = load_file(weights)
|
||||
else:
|
||||
sd = torch.load(weights, map_location=torch.device("cpu"))
|
||||
model.load_state_dict(sd)
|
||||
return model
|
||||
|
||||
|
||||
# another interface: upscale images with a model for given images from command line
|
||||
def upscale_images(args: argparse.Namespace):
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
us_dtype = torch.float16 # TODO: support fp32/bf16
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# load VAE with Diffusers
|
||||
assert args.vae_path is not None, "VAE path is required"
|
||||
print(f"Loading VAE from {args.vae_path}...")
|
||||
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
|
||||
vae.to(DEVICE, dtype=us_dtype)
|
||||
|
||||
# prepare model
|
||||
print("Preparing model...")
|
||||
upscaler: Upscaler = create_upscaler(weights=args.weights)
|
||||
# print("Loading weights from", args.weights)
|
||||
# upscaler.load_state_dict(torch.load(args.weights))
|
||||
upscaler.eval()
|
||||
upscaler.to(DEVICE, dtype=us_dtype)
|
||||
|
||||
# load images
|
||||
image_paths = glob.glob(args.image_pattern)
|
||||
images = []
|
||||
for image_path in image_paths:
|
||||
image = Image.open(image_path)
|
||||
image = image.convert("RGB")
|
||||
|
||||
# make divisible by 8
|
||||
width = image.width
|
||||
height = image.height
|
||||
if width % 8 != 0:
|
||||
width = width - (width % 8)
|
||||
if height % 8 != 0:
|
||||
height = height - (height % 8)
|
||||
if width != image.width or height != image.height:
|
||||
image = image.crop((0, 0, width, height))
|
||||
|
||||
images.append(image)
|
||||
|
||||
# debug output
|
||||
if args.debug:
|
||||
for image, image_path in zip(images, image_paths):
|
||||
image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS)
|
||||
|
||||
basename = os.path.basename(image_path)
|
||||
basename_wo_ext, ext = os.path.splitext(basename)
|
||||
dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}")
|
||||
image_debug.save(dest_file_name)
|
||||
|
||||
# upscale
|
||||
print("Upscaling...")
|
||||
upscaled_latents = upscaler.upscale(
|
||||
vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
|
||||
)
|
||||
upscaled_latents /= 0.18215
|
||||
|
||||
# decode with batch
|
||||
print("Decoding...")
|
||||
upscaled_images = []
|
||||
for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
|
||||
with torch.no_grad():
|
||||
batch = vae.decode(upscaled_latents[i : i + args.vae_batch_size]).sample
|
||||
batch = batch.to("cpu")
|
||||
upscaled_images.append(batch)
|
||||
upscaled_images = torch.cat(upscaled_images, dim=0)
|
||||
|
||||
# tensor to numpy
|
||||
upscaled_images = upscaled_images.permute(0, 2, 3, 1).numpy()
|
||||
upscaled_images = (upscaled_images + 1.0) * 127.5
|
||||
upscaled_images = upscaled_images.clip(0, 255).astype(np.uint8)
|
||||
|
||||
upscaled_images = upscaled_images[..., ::-1]
|
||||
|
||||
# save images
|
||||
for i, image in enumerate(upscaled_images):
|
||||
basename = os.path.basename(image_paths[i])
|
||||
basename_wo_ext, ext = os.path.splitext(basename)
|
||||
dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}")
|
||||
cv2.imwrite(dest_file_name, image)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--vae_path", type=str, default=None, help="VAE path")
|
||||
parser.add_argument("--weights", type=str, default=None, help="Weights path")
|
||||
parser.add_argument("--image_pattern", type=str, default=None, help="Image pattern")
|
||||
parser.add_argument("--output_dir", type=str, default=".", help="Output directory")
|
||||
parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
|
||||
parser.add_argument("--vae_batch_size", type=int, default=1, help="VAE batch size")
|
||||
parser.add_argument("--debug", action="store_true", help="Debug mode")
|
||||
|
||||
args = parser.parse_args()
|
||||
upscale_images(args)
|
||||
@@ -2,7 +2,7 @@ __ドキュメント更新中のため記述に誤りがあるかもしれませ
|
||||
|
||||
# 学習について、共通編
|
||||
|
||||
当リポジトリではモデルのfine tuning、DreamBooth、およびLoRAとTextual Inversionの学習をサポートします。この文書ではそれらに共通する、学習データの準備方法やオプション等について説明します。
|
||||
当リポジトリではモデルのfine tuning、DreamBooth、およびLoRAとTextual Inversion([XTI:P+](https://github.com/kohya-ss/sd-scripts/pull/327)を含む)の学習をサポートします。この文書ではそれらに共通する、学習データの準備方法やオプション等について説明します。
|
||||
|
||||
# 概要
|
||||
|
||||
@@ -535,7 +535,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
|
||||
- `--debug_dataset`
|
||||
|
||||
このオプションを付けることで学習を行う前に事前にどのような画像データ、キャプションで学習されるかを確認できます。Escキーを押すと終了してコマンドラインに戻ります。
|
||||
このオプションを付けることで学習を行う前に事前にどのような画像データ、キャプションで学習されるかを確認できます。Escキーを押すと終了してコマンドラインに戻ります。`S`キーで次のステップ(バッチ)、`E`キーで次のエポックに進みます。
|
||||
|
||||
※Linux環境(Colabを含む)では画像は表示されません。
|
||||
|
||||
@@ -545,6 +545,13 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
|
||||
DreamBoothおよびfine tuningでは、保存されるモデルはこのVAEを組み込んだものになります。
|
||||
|
||||
- `--cache_latents`
|
||||
|
||||
使用VRAMを減らすためVAEの出力をメインメモリにキャッシュします。`flip_aug` 以外のaugmentationは使えなくなります。また全体の学習速度が若干速くなります。
|
||||
|
||||
- `--min_snr_gamma`
|
||||
|
||||
Min-SNR Weighting strategyを指定します。詳細は[こちら](https://github.com/kohya-ss/sd-scripts/pull/308)を参照してください。論文では`5`が推奨されています。
|
||||
|
||||
## オプティマイザ関係
|
||||
|
||||
@@ -570,7 +577,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
|
||||
学習率のスケジューラ関連の指定です。
|
||||
|
||||
lr_schedulerオプションで学習率のスケジューラをlinear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmupから選べます。デフォルトはconstantです。
|
||||
lr_schedulerオプションで学習率のスケジューラをlinear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup, 任意のスケジューラから選べます。デフォルトはconstantです。
|
||||
|
||||
lr_warmup_stepsでスケジューラのウォームアップ(だんだん学習率を変えていく)ステップ数を指定できます。
|
||||
|
||||
@@ -578,6 +585,8 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
|
||||
詳細については各自お調べください。
|
||||
|
||||
任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--scheduler_args`でオプション引数を指定してください。
|
||||
|
||||
### オプティマイザの指定について
|
||||
|
||||
オプティマイザのオプション引数は--optimizer_argsオプションで指定してください。key=valueの形式で、複数の値が指定できます。また、valueはカンマ区切りで複数の値が指定できます。たとえばAdamWオプティマイザに引数を指定する場合は、``--optimizer_args weight_decay=0.01 betas=.9,.999``のようになります。
|
||||
@@ -801,7 +810,7 @@ model_dirオプションでモデルの保存先フォルダを指定できま
|
||||
キャプションをメタデータに入れるには、作業フォルダ内で以下を実行してください(キャプションを学習に使わない場合は実行不要です)(実際は1行で記述します、以下同様)。`--full_path` オプションを指定してメタデータに画像ファイルの場所をフルパスで格納します。このオプションを省略すると相対パスで記録されますが、フォルダ指定が `.toml` ファイル内で別途必要になります。
|
||||
|
||||
```
|
||||
python merge_captions_to_metadata.py --full_apth <教師データフォルダ>
|
||||
python merge_captions_to_metadata.py --full_path <教師データフォルダ>
|
||||
--in_json <読み込むメタデータファイル名> <メタデータファイル名>
|
||||
```
|
||||
|
||||
|
||||
900
train_README-zh.md
Normal file
900
train_README-zh.md
Normal file
@@ -0,0 +1,900 @@
|
||||
__由于文档正在更新中,描述可能有错误。__
|
||||
|
||||
# 关于本学习文档,通用描述
|
||||
本库支持模型微调(fine tuning)、DreamBooth、训练LoRA和文本反转(Textual Inversion)(包括[XTI:P+](https://github.com/kohya-ss/sd-scripts/pull/327)
|
||||
)
|
||||
本文档将说明它们通用的学习数据准备方法和选项等。
|
||||
|
||||
# 概要
|
||||
|
||||
请提前参考本仓库的README,准备好环境。
|
||||
|
||||
|
||||
以下本节说明。
|
||||
|
||||
1. 关于准备学习数据的新形式(使用设置文件)
|
||||
1. 对于在学习中使用的术语的简要解释
|
||||
1. 先前的指定格式(不使用设置文件,而是从命令行指定)
|
||||
1. 生成学习过程中的示例图像
|
||||
1. 各脚本中常用的共同选项
|
||||
1. 准备 fine tuning 方法的元数据:如说明文字(打标签)等
|
||||
|
||||
|
||||
1. 如果只执行一次,学习就可以进行(相关内容,请参阅各个脚本的文档)。如果需要,以后可以随时参考。
|
||||
|
||||
|
||||
|
||||
# 关于准备训练数据
|
||||
|
||||
在任意文件夹(也可以是多个文件夹)中准备好训练数据的图像文件。支持 `.png`, `.jpg`, `.jpeg`, `.webp`, `.bmp` 格式的文件。通常不需要进行任何预处理,如调整大小等。
|
||||
|
||||
但是请勿使用极小的图像,其尺寸比训练分辨率(稍后将提到)还小,建议事先使用超分辨率AI等进行放大。另外,请注意不要使用过大的图像(约为3000 x 3000像素以上),因为这可能会导致错误,建议事先缩小。
|
||||
|
||||
在训练时,需要整理要用于训练模型的图像数据,并将其指定给脚本。根据训练数据的数量、训练目标和说明(图像描述)是否可用等因素,可以使用几种方法指定训练数据。以下是其中的一些方法(每个名称都不是通用的,而是该存储库自定义的定义)。有关正则化图像的信息将在稍后提供。
|
||||
|
||||
1. DreamBooth、class + identifier方式(可使用正则化图像)
|
||||
|
||||
将训练目标与特定单词(identifier)相关联进行训练。无需准备说明。例如,当要学习特定角色时,由于无需准备说明,因此比较方便,但由于学习数据的所有元素都与identifier相关联,例如发型、服装、背景等,因此在生成时可能会出现无法更换服装的情况。
|
||||
|
||||
2. DreamBooth、说明方式(可使用正则化图像)
|
||||
|
||||
准备记录每个图像说明的文本文件进行训练。例如,通过将图像详细信息(如穿着白色衣服的角色A、穿着红色衣服的角色A等)记录在说明中,可以将角色和其他元素分离,并期望模型更准确地学习角色。
|
||||
|
||||
3. 微调方式(不可使用正则化图像)
|
||||
|
||||
先将说明收集到元数据文件中。支持分离标签和说明以及预先缓存latents等功能,以加速训练(这些将在另一篇文档中介绍)。(虽然名为fine tuning方式,但不仅限于fine tuning。)
|
||||
你要学的东西和你可以使用的规范方法的组合如下。
|
||||
|
||||
| 学习对象或方法 | 脚本 | DB/class+identifier | DB/caption | fine tuning |
|
||||
|----------------| ----- | ----- | ----- | ----- |
|
||||
| fine tuning微调模型 | `fine_tune.py`| x | x | o |
|
||||
| DreamBooth训练模型 | `train_db.py`| o | o | x |
|
||||
| LoRA | `train_network.py`| o | o | o |
|
||||
| Textual Invesion | `train_textual_inversion.py`| o | o | o |
|
||||
|
||||
## 选择哪一个
|
||||
|
||||
如果您想要学习LoRA、Textual Inversion而不需要准备简介文件,则建议使用DreamBooth class+identifier。如果您能够准备好,则DreamBooth Captions方法更好。如果您有大量的训练数据并且不使用规则化图像,则请考虑使用fine-tuning方法。
|
||||
|
||||
对于DreamBooth也是一样的,但不能使用fine-tuning方法。对于fine-tuning方法,只能使用fine-tuning方式。
|
||||
|
||||
# 每种方法的指定方式
|
||||
|
||||
在这里,我们只介绍每种指定方法的典型模式。有关更详细的指定方法,请参见[数据集设置](./config_README-ja.md)。
|
||||
|
||||
# DreamBooth,class+identifier方法(可使用规则化图像)
|
||||
|
||||
在该方法中,每个图像将被视为使用与 `class identifier` 相同的标题进行训练(例如 `shs dog`)。
|
||||
|
||||
这样一来,每张图片都相当于使用标题“分类标识”(例如“shs dog”)进行训练。
|
||||
|
||||
## step 1.确定identifier和class
|
||||
|
||||
要将学习的目标与identifier和属于该目标的class相关联。
|
||||
|
||||
(虽然有很多称呼,但暂时按照原始论文的说法。)
|
||||
|
||||
以下是简要说明(请查阅详细信息)。
|
||||
|
||||
class是学习目标的一般类别。例如,如果要学习特定品种的狗,则class将是“dog”。对于动漫角色,根据模型不同,可能是“boy”或“girl”,也可能是“1boy”或“1girl”。
|
||||
|
||||
identifier是用于识别学习目标并进行学习的单词。可以使用任何单词,但是根据原始论文,“Tokenizer生成的3个或更少字符的罕见单词”是最好的选择。
|
||||
|
||||
使用identifier和class,例如,“shs dog”可以将模型训练为从class中识别并学习所需的目标。
|
||||
|
||||
在图像生成时,使用“shs dog”将生成所学习狗种的图像。
|
||||
|
||||
(作为identifier,我最近使用的一些参考是“shs sts scs cpc coc cic msm usu ici lvl cic dii muk ori hru rik koo yos wny”等。最好是不包含在Danbooru标签中的单词。)
|
||||
|
||||
## step 2. 决定是否使用正则化图像,并生成正则化图像
|
||||
|
||||
正则化图像是为防止前面提到的语言漂移,即整个类别被拉扯成为学习目标而生成的图像。如果不使用正则化图像,例如在 `shs 1girl` 中学习特定角色时,即使在简单的 `1girl` 提示下生成,也会越来越像该角色。这是因为 `1girl` 在训练时的标题中包含了该角色的信息。
|
||||
|
||||
通过同时学习目标图像和正则化图像,类别仍然保持不变,仅在将标识符附加到提示中时才生成目标图像。
|
||||
|
||||
如果您只想在LoRA或DreamBooth中使用特定的角色,则可以不使用正则化图像。
|
||||
|
||||
在Textual Inversion中也不需要使用(如果要学习的token string不包含在标题中,则不会学习任何内容)。
|
||||
|
||||
一般情况下,使用在训练目标模型时只使用类别名称生成的图像作为正则化图像是常见的做法(例如 `1girl`)。但是,如果生成的图像质量不佳,可以尝试修改提示或使用从网络上另外下载的图像。
|
||||
|
||||
(由于正则化图像也被训练,因此其质量会影响模型。)
|
||||
|
||||
通常,准备数百张图像是理想的(图像数量太少会导致类别图像无法推广并学习它们的特征)。
|
||||
|
||||
如果要使用生成的图像,请将其大小通常与训练分辨率(更准确地说是bucket的分辨率)相适应。
|
||||
|
||||
## step 2. 设置文件的描述
|
||||
|
||||
创建一个文本文件,并将其扩展名更改为`.toml`。例如,您可以按以下方式进行描述:
|
||||
|
||||
(以`#`开头的部分是注释,因此您可以直接复制粘贴,或者将其删除,都没有问题。)
|
||||
|
||||
```toml
|
||||
[general]
|
||||
enable_bucket = true # 是否使用Aspect Ratio Bucketing
|
||||
|
||||
[[datasets]]
|
||||
resolution = 512 # 学习分辨率
|
||||
batch_size = 4 # 批量大小
|
||||
|
||||
[[datasets.subsets]]
|
||||
image_dir = 'C:\hoge' # 指定包含训练图像的文件夹
|
||||
class_tokens = 'hoge girl' # 指定标识符类
|
||||
num_repeats = 10 # 训练图像的迭代次数
|
||||
|
||||
# 以下仅在使用正则化图像时进行描述。不使用则删除
|
||||
[[datasets.subsets]]
|
||||
is_reg = true
|
||||
image_dir = 'C:\reg' # 指定包含正则化图像的文件夹
|
||||
class_tokens = 'girl' # 指定类别
|
||||
num_repeats = 1 # 正则化图像的迭代次数,基本上1就可以了
|
||||
```
|
||||
|
||||
基本上只需更改以下位置即可进行学习。
|
||||
|
||||
1. 学习分辨率
|
||||
|
||||
指定一个数字表示正方形(如果是 `512`,则为 512x512),如果使用方括号和逗号分隔的两个数字,则表示横向×纵向(如果是`[512,768]`,则为 512x768)。在SD1.x系列中,原始学习分辨率为512。指定较大的分辨率,如 `[512,768]` 可能会减少纵向和横向图像生成时的错误。在SD2.x 768系列中,分辨率为 `768`。
|
||||
|
||||
1. 批量大小
|
||||
|
||||
指定同时学习多少个数据。这取决于GPU的VRAM大小和学习分辨率。详细信息将在后面说明。此外,fine tuning/DreamBooth/LoRA等也会影响批量大小,请查看各个脚本的说明。
|
||||
|
||||
1. 文件夹指定
|
||||
|
||||
指定用于学习的图像和正则化图像(仅在使用时)的文件夹。指定包含图像数据的文件夹。
|
||||
|
||||
1. identifier 和 class 的指定
|
||||
|
||||
如前所述,与示例相同。
|
||||
|
||||
1. 迭代次数
|
||||
|
||||
将在后面说明。
|
||||
|
||||
### 关于重复次数
|
||||
|
||||
重复次数用于调整正则化图像和训练用图像的数量。由于正则化图像的数量多于训练用图像,因此需要重复使用训练用图像来达到一对一的比例,从而实现训练。
|
||||
|
||||
请将重复次数指定为“ __训练用图像的重复次数×训练用图像的数量≥正则化图像的重复次数×正则化图像的数量__ ”。
|
||||
|
||||
(1个epoch(数据一周一次)的数据量为“训练用图像的重复次数×训练用图像的数量”。如果正则化图像的数量多于这个值,则剩余的正则化图像将不会被使用。)
|
||||
|
||||
## 步骤 3. 学习
|
||||
|
||||
请根据每个文档的参考进行学习。
|
||||
|
||||
# DreamBooth,标题方式(可使用规范化图像)
|
||||
|
||||
在此方式中,每个图像都将通过标题进行学习。
|
||||
|
||||
## 步骤 1. 准备标题文件
|
||||
|
||||
请将与图像具有相同文件名且扩展名为 `.caption`(可以在设置中更改)的文件放置在用于训练图像的文件夹中。每个文件应该只有一行。编码为 `UTF-8`。
|
||||
|
||||
## 步骤 2. 决定是否使用规范化图像,并在使用时生成规范化图像
|
||||
|
||||
与class+identifier格式相同。可以在规范化图像上附加标题,但通常不需要。
|
||||
|
||||
## 步骤 2. 编写设置文件
|
||||
|
||||
创建一个文本文件并将扩展名更改为 `.toml`。例如,可以按以下方式进行记录。
|
||||
|
||||
```toml
|
||||
[general]
|
||||
enable_bucket = true # Aspect Ratio Bucketingを使うか否か
|
||||
|
||||
[[datasets]]
|
||||
resolution = 512 # 学習解像度
|
||||
batch_size = 4 # 批量大小
|
||||
|
||||
[[datasets.subsets]]
|
||||
image_dir = 'C:\hoge' # 指定包含训练图像的文件夹
|
||||
caption_extension = '.caption' # 使用字幕文件扩展名 .txt 时重写
|
||||
num_repeats = 10 # 训练图像的迭代次数
|
||||
|
||||
# 以下仅在使用正则化图像时进行描述。不使用则删除
|
||||
[[datasets.subsets]]
|
||||
is_reg = true
|
||||
image_dir = 'C:\reg' #指定包含正则化图像的文件夹
|
||||
class_tokens = 'girl' # class を指定
|
||||
num_repeats = 1 #
|
||||
正则化图像的迭代次数,基本上1就可以了
|
||||
```
|
||||
|
||||
基本上,您可以通过仅重写以下位置来学习。除非另有说明,否则与类+标识符方法相同。
|
||||
|
||||
1. 学习分辨率
|
||||
2. 批量大小
|
||||
3. 文件夹指定
|
||||
4. 标题文件的扩展名
|
||||
|
||||
可以指定任意的扩展名。
|
||||
5. 重复次数
|
||||
|
||||
## 步骤 3. 学习
|
||||
|
||||
请参考每个文档进行学习。
|
||||
|
||||
# 微调方法
|
||||
|
||||
## 步骤 1. 准备元数据
|
||||
|
||||
将标题和标签整合到管理文件中称为元数据。它的扩展名为 `.json`,格式为json。由于创建方法较长,因此在本文档的末尾进行了描述。
|
||||
|
||||
## 步骤 2. 编写设置文件
|
||||
|
||||
创建一个文本文件,将扩展名设置为 `.toml`。例如,可以按以下方式编写:
|
||||
```toml
|
||||
[general]
|
||||
shuffle_caption = true
|
||||
keep_tokens = 1
|
||||
|
||||
[[datasets]]
|
||||
resolution = 512 # 图像分辨率
|
||||
batch_size = 4 # 批量大小
|
||||
|
||||
[[datasets.subsets]]
|
||||
image_dir = 'C:\piyo' # 指定包含训练图像的文件夹
|
||||
metadata_file = 'C:\piyo\piyo_md.json' # 元数据文件名
|
||||
```
|
||||
|
||||
基本上,您可以通过仅重写以下位置来学习。如无特别说明,与DreamBooth相同,类+标识符方式。
|
||||
|
||||
1. 学习解像度
|
||||
2. 批次大小
|
||||
3. 指定文件夹
|
||||
4. 元数据文件名
|
||||
|
||||
指定使用后面所述方法创建的元数据文件。
|
||||
|
||||
|
||||
## 第三步:学习
|
||||
|
||||
请参考各个文档进行学习。
|
||||
|
||||
# 学习中使用的术语简单解释
|
||||
|
||||
由于省略了细节并且我自己也没有完全理解,因此请自行查阅详细信息。
|
||||
|
||||
## 微调(fine tuning)
|
||||
|
||||
指训练模型并微调其性能。具体含义因用法而异,但在 Stable Diffusion 中,狭义的微调是指使用图像和标题进行训练模型。DreamBooth 可视为狭义微调的一种特殊方法。广义的微调包括 LoRA、Textual Inversion、Hypernetworks 等,包括训练模型的所有内容。
|
||||
|
||||
## 步骤(step)
|
||||
|
||||
粗略地说,每次在训练数据上进行一次计算即为一步。具体来说,“将训练数据的标题传递给当前模型,将生成的图像与训练数据的图像进行比较,稍微更改模型,以使其更接近训练数据”即为一步。
|
||||
|
||||
## 批次大小(batch size)
|
||||
|
||||
批次大小指定每个步骤要计算多少数据。批量计算可以提高速度。一般来说,批次大小越大,精度也越高。
|
||||
|
||||
“批次大小×步数”是用于训练的数据数量。因此,建议减少步数以增加批次大小。
|
||||
|
||||
(但是,例如,“批次大小为 1,步数为 1600”和“批次大小为 4,步数为 400”将不会产生相同的结果。如果使用相同的学习速率,通常后者会导致模型欠拟合。请尝试增加学习率(例如 `2e-6`),将步数设置为 500 等。)
|
||||
|
||||
批次大小越大,GPU 内存消耗就越大。如果内存不足,将导致错误,或者在边缘时将导致训练速度降低。建议在任务管理器或 `nvidia-smi` 命令中检查使用的内存量进行调整。
|
||||
|
||||
另外,批次是指“一块数据”的意思。
|
||||
|
||||
## 学习率
|
||||
|
||||
学习率指的是每个步骤中改变的程度。如果指定一个大的值,学习速度就会加快,但是可能会出现变化太大导致模型崩溃或无法达到最佳状态的情况。如果指定一个小的值,学习速度会变慢,也可能无法达到最佳状态。
|
||||
|
||||
在fine tuning、DreamBooth、LoRA等过程中,学习率会有很大的差异,并且也会受到训练数据、所需训练的模型、批量大小和步骤数等因素的影响。建议从一般的值开始,观察训练状态并逐渐调整。
|
||||
|
||||
默认情况下,整个训练过程中学习率是固定的。但是可以通过调度程序指定学习率如何变化,因此结果也会有所不同。
|
||||
|
||||
## 时代(epoch)
|
||||
|
||||
Epoch指的是训练数据被完整训练一遍(即数据一周)的情况。如果指定了重复次数,则在重复后的数据一周后,就是1个epoch。
|
||||
|
||||
1个epoch的步骤数通常为“数据量÷批量大小”,但如果使用Aspect Ratio Bucketing,则略微增加(由于不同bucket的数据不能在同一个批次中,因此步骤数会增加)。
|
||||
|
||||
## 纵横比分桶(Aspect Ratio Bucketing)
|
||||
|
||||
Stable Diffusion 的 v1 是以 512\*512 的分辨率进行训练的,但同时也可以在其他分辨率下进行训练,例如 256\*1024 和 384\*640。这样可以减少裁剪的部分,期望更准确地学习图像和标题之间的关系。
|
||||
|
||||
此外,由于可以在任意分辨率下进行训练,因此不再需要事先统一图像数据的纵横比。
|
||||
|
||||
该设置在配置中有效,可以切换,但在此之前的配置文件示例中已启用(设置为 `true`)。
|
||||
|
||||
学习分辨率将根据参数所提供的分辨率面积(即内存使用量)进行调整,以64像素为单位(默认值,可更改)在纵横方向上进行调整和创建。
|
||||
|
||||
在机器学习中,通常需要将所有输入大小统一,但实际上只要在同一批次中统一即可。 NovelAI 所说的分桶(bucketing) 指的是,预先将训练数据按照纵横比分类到每个学习分辨率下,并通过使用每个 bucket 内的图像创建批次来统一批次图像大小。
|
||||
|
||||
# 以前的指定格式(不使用 .toml 文件,而是使用命令行选项指定)
|
||||
|
||||
这是一种通过命令行选项而不是指定 .toml 文件的方法。有 DreamBooth 类+标识符方法、DreamBooth 标题方法、微调方法三种方式。
|
||||
|
||||
## DreamBooth、类+标识符方式
|
||||
|
||||
指定文件夹名称以指定迭代次数。还要使用 `train_data_dir` 和 `reg_data_dir` 选项。
|
||||
|
||||
### 第1步。准备用于训练的图像
|
||||
|
||||
创建一个用于存储训练图像的文件夹。__此外__,按以下名称创建目录。
|
||||
|
||||
```
|
||||
<迭代次数>_<标识符> <类别>
|
||||
```
|
||||
|
||||
不要忘记下划线``_``。
|
||||
|
||||
例如,如果在名为“sls frog”的提示下重复数据 20 次,则为“20_sls frog”。如下所示:
|
||||
|
||||

|
||||
|
||||
### 多个类别、多个标识符的学习
|
||||
|
||||
该方法很简单,在用于训练的图像文件夹中,需要准备多个文件夹,每个文件夹都是以“重复次数_<标识符> <类别>”命名的,同样,在正则化图像文件夹中,也需要准备多个文件夹,每个文件夹都是以“重复次数_<类别>”命名的。
|
||||
|
||||
例如,如果要同时训练“sls青蛙”和“cpc兔子”,则应按以下方式准备文件夹。
|
||||
|
||||

|
||||
|
||||
如果一个类别包含多个对象,可以只使用一个正则化图像文件夹。例如,如果在1girl类别中有角色A和角色B,则可以按照以下方式处理:
|
||||
|
||||
- train_girls
|
||||
- 10_sls 1girl
|
||||
- 10_cpc 1girl
|
||||
- reg_girls
|
||||
- 1_1girl
|
||||
|
||||
### step 2. 准备正规化图像
|
||||
|
||||
这是使用规则化图像时的过程。
|
||||
|
||||
创建一个文件夹来存储规则化的图像。 __此外,__ 创建一个名为``<repeat count>_<class>`` 的目录。
|
||||
|
||||
例如,使用提示“frog”并且不重复数据(仅一次):
|
||||

|
||||
|
||||
|
||||
步骤3. 执行学习
|
||||
|
||||
执行每个学习脚本。使用 `--train_data_dir` 选项指定包含训练数据文件夹的父文件夹(不是包含图像的文件夹),使用 `--reg_data_dir` 选项指定包含正则化图像的父文件夹(不是包含图像的文件夹)。
|
||||
|
||||
## DreamBooth,带标题方式
|
||||
|
||||
在包含训练图像和正则化图像的文件夹中,将与图像具有相同文件名的文件.caption(可以使用选项进行更改)放置在该文件夹中,然后从该文件中加载标题作为提示进行学习。
|
||||
|
||||
※文件夹名称(标识符类)不再用于这些图像的训练。
|
||||
|
||||
默认的标题文件扩展名为.caption。可以使用学习脚本的 `--caption_extension` 选项进行更改。 使用 `--shuffle_caption` 选项,同时对每个逗号分隔的部分进行学习时会对学习时的标题进行混洗。
|
||||
|
||||
## 微调方式
|
||||
|
||||
创建元数据的方式与使用配置文件相同。 使用 `in_json` 选项指定元数据文件。
|
||||
|
||||
# 学习过程中的样本输出
|
||||
|
||||
通过在训练中使用模型生成图像,可以检查学习进度。将以下选项指定为学习脚本。
|
||||
|
||||
- `--sample_every_n_steps` / `--sample_every_n_epochs`
|
||||
|
||||
指定要采样的步数或纪元数。为这些数字中的每一个输出样本。如果两者都指定,则 epoch 数优先。
|
||||
- `--sample_prompts`
|
||||
|
||||
指定示例输出的提示文件。
|
||||
|
||||
- `--sample_sampler`
|
||||
|
||||
指定用于采样输出的采样器。
|
||||
`'ddim', 'pndm', 'heun', 'dpmsolver', 'dpmsolver++', 'dpmsingle', 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'`が選べます。
|
||||
|
||||
要输出样本,您需要提前准备一个包含提示的文本文件。每行输入一个提示。
|
||||
|
||||
```txt
|
||||
# prompt 1
|
||||
masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
|
||||
|
||||
# prompt 2
|
||||
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
|
||||
```
|
||||
|
||||
以“#”开头的行是注释。您可以使用“`--` + 小写字母”为生成的图像指定选项,例如 `--n`。您可以使用:
|
||||
|
||||
- `--n` 否定提示到下一个选项。
|
||||
- `--w` 指定生成图像的宽度。
|
||||
- `--h` 指定生成图像的高度。
|
||||
- `--d` 指定生成图像的种子。
|
||||
- `--l` 指定生成图像的 CFG 比例。
|
||||
- `--s` 指定生成过程中的步骤数。
|
||||
|
||||
|
||||
# 每个脚本通用的常用选项
|
||||
|
||||
文档更新可能跟不上脚本更新。在这种情况下,请使用 `--help` 选项检查可用选项。
|
||||
## 学习模型规范
|
||||
|
||||
- `--v2` / `--v_parameterization`
|
||||
|
||||
如果使用 Hugging Face 的 stable-diffusion-2-base 或来自它的微调模型作为学习目标模型(对于在推理时指示使用 `v2-inference.yaml` 的模型),`- 当使用-v2` 选项与 stable-diffusion-2、768-v-ema.ckpt 及其微调模型(对于在推理过程中使用 `v2-inference-v.yaml` 的模型),`- 指定两个 -v2`和 `--v_parameterization` 选项。
|
||||
|
||||
以下几点在 Stable Diffusion 2.0 中发生了显着变化。
|
||||
|
||||
1. 使用分词器
|
||||
2. 使用哪个Text Encoder,使用哪个输出层(2.0使用倒数第二层)
|
||||
3. Text Encoder的输出维度(768->1024)
|
||||
4. U-Net的结构(CrossAttention的头数等)
|
||||
5. v-parameterization(采样方式好像变了)
|
||||
|
||||
其中碱基使用1-4个,非碱基使用1-5个(768-v)。使用 1-4 进行 v2 选择,使用 5 进行 v_parameterization 选择。
|
||||
-`--pretrained_model_name_or_path`
|
||||
|
||||
指定要从中执行额外训练的模型。您可以指定稳定扩散检查点文件(.ckpt 或 .safetensors)、扩散器本地磁盘上的模型目录或扩散器模型 ID(例如“stabilityai/stable-diffusion-2”)。
|
||||
## 学习设置
|
||||
|
||||
- `--output_dir`
|
||||
|
||||
指定训练后保存模型的文件夹。
|
||||
|
||||
- `--output_name`
|
||||
|
||||
指定不带扩展名的模型文件名。
|
||||
|
||||
- `--dataset_config`
|
||||
|
||||
指定描述数据集配置的 .toml 文件。
|
||||
|
||||
- `--max_train_steps` / `--max_train_epochs`
|
||||
|
||||
指定要学习的步数或纪元数。如果两者都指定,则 epoch 数优先。
|
||||
-
|
||||
- `--mixed_precision`
|
||||
|
||||
训练混合精度以节省内存。指定像`--mixed_precision = "fp16"`。与无混合精度(默认)相比,精度可能较低,但训练所需的 GPU 内存明显较少。
|
||||
|
||||
(在RTX30系列以后也可以指定`bf16`,请配合您在搭建环境时做的加速设置)。
|
||||
- `--gradient_checkpointing`
|
||||
|
||||
通过逐步计算权重而不是在训练期间一次计算所有权重来减少训练所需的 GPU 内存量。关闭它不会影响准确性,但打开它允许更大的批量大小,所以那里有影响。
|
||||
|
||||
另外,打开它通常会减慢速度,但可以增加批量大小,因此总的学习时间实际上可能会更快。
|
||||
|
||||
- `--xformers` / `--mem_eff_attn`
|
||||
|
||||
当指定 xformers 选项时,使用 xformers 的 CrossAttention。如果未安装 xformers 或发生错误(取决于环境,例如 `mixed_precision="no"`),请指定 `mem_eff_attn` 选项而不是使用 CrossAttention 的内存节省版本(xformers 比 慢)。
|
||||
- `--save_precision`
|
||||
|
||||
指定保存时的数据精度。为 save_precision 选项指定 float、fp16 或 bf16 将以该格式保存模型(在 DreamBooth 中保存 Diffusers 格式时无效,微调)。当您想缩小模型的尺寸时请使用它。
|
||||
- `--save_every_n_epochs` / `--save_state` / `--resume`
|
||||
为 save_every_n_epochs 选项指定一个数字可以在每个时期的训练期间保存模型。
|
||||
|
||||
如果同时指定save_state选项,学习状态包括优化器的状态等都会一起保存。。保存目的地将是一个文件夹。
|
||||
|
||||
学习状态输出到目标文件夹中名为“<output_name>-??????-state”(??????是纪元数)的文件夹中。长时间学习时请使用。
|
||||
|
||||
使用 resume 选项从保存的训练状态恢复训练。指定学习状态文件夹(其中的状态文件夹,而不是 `output_dir`)。
|
||||
|
||||
请注意,由于 Accelerator 规范,epoch 数和全局步数不会保存,即使恢复时它们也从 1 开始。
|
||||
- `--save_model_as` (DreamBooth, fine tuning 仅有的)
|
||||
|
||||
您可以从 `ckpt, safetensors, diffusers, diffusers_safetensors` 中选择模型保存格式。
|
||||
|
||||
- `--save_model_as=safetensors` 指定喜欢当读取稳定扩散格式(ckpt 或安全张量)并以扩散器格式保存时,缺少的信息通过从 Hugging Face 中删除 v1.5 或 v2.1 信息来补充。
|
||||
|
||||
- `--clip_skip`
|
||||
|
||||
`2` 如果指定,则使用文本编码器 (CLIP) 的倒数第二层的输出。如果省略 1 或选项,则使用最后一层。
|
||||
|
||||
*SD2.0默认使用倒数第二层,学习SD2.0时请不要指定。
|
||||
|
||||
如果被训练的模型最初被训练为使用第二层,则 2 是一个很好的值。
|
||||
|
||||
如果您使用的是最后一层,那么整个模型都会根据该假设进行训练。因此,如果再次使用第二层进行训练,可能需要一定数量的teacher数据和更长时间的学习才能得到想要的学习结果。
|
||||
- `--max_token_length`
|
||||
|
||||
默认值为 75。您可以通过指定“150”或“225”来扩展令牌长度来学习。使用长字幕学习时指定。
|
||||
|
||||
但由于学习时token展开的规范与Automatic1111的web UI(除法等规范)略有不同,如非必要建议用75学习。
|
||||
|
||||
与clip_skip一样,学习与模型学习状态不同的长度可能需要一定量的teacher数据和更长的学习时间。
|
||||
|
||||
- `--persistent_data_loader_workers`
|
||||
|
||||
在 Windows 环境中指定它可以显着减少时期之间的延迟。
|
||||
|
||||
- `--max_data_loader_n_workers`
|
||||
|
||||
指定数据加载的进程数。大量的进程会更快地加载数据并更有效地使用 GPU,但会消耗更多的主内存。默认是"`8`或者`CPU并发执行线程数 - 1`,取小者",所以如果主存没有空间或者GPU使用率大概在90%以上,就看那些数字和 `2` 或将其降低到大约 `1`。
|
||||
- `--logging_dir` / `--log_prefix`
|
||||
|
||||
保存学习日志的选项。在 logging_dir 选项中指定日志保存目标文件夹。以 TensorBoard 格式保存日志。
|
||||
|
||||
例如,如果您指定 --logging_dir=logs,将在您的工作文件夹中创建一个日志文件夹,并将日志保存在日期/时间文件夹中。
|
||||
此外,如果您指定 --log_prefix 选项,则指定的字符串将添加到日期和时间之前。使用“--logging_dir=logs --log_prefix=db_style1_”进行识别。
|
||||
|
||||
要检查 TensorBoard 中的日志,请打开另一个命令提示符并在您的工作文件夹中键入:
|
||||
```
|
||||
tensorboard --logdir=logs
|
||||
```
|
||||
|
||||
我觉得tensorboard会在环境搭建的时候安装,如果没有安装,请用`pip install tensorboard`安装。)
|
||||
|
||||
然后打开浏览器到http://localhost:6006/就可以看到了。
|
||||
- `--noise_offset`
|
||||
本文的实现:https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
|
||||
看起来它可能会为整体更暗和更亮的图像产生更好的结果。它似乎对 LoRA 学习也有效。指定一个大约 0.1 的值似乎很好。
|
||||
|
||||
- `--debug_dataset`
|
||||
|
||||
通过添加此选项,您可以在学习之前检查将学习什么样的图像数据和标题。按 Esc 退出并返回命令行。按 `S` 进入下一步(批次),按 `E` 进入下一个纪元。
|
||||
|
||||
*图片在 Linux 环境(包括 Colab)下不显示。
|
||||
|
||||
- `--vae`
|
||||
|
||||
如果您在 vae 选项中指定稳定扩散检查点、VAE 检查点文件、扩散模型或 VAE(两者都可以指定本地或拥抱面模型 ID),则该 VAE 用于学习(缓存时的潜伏)或在学习过程中获得潜伏)。
|
||||
|
||||
对于 DreamBooth 和微调,保存的模型将包含此 VAE
|
||||
|
||||
- `--cache_latents`
|
||||
|
||||
在主内存中缓存 VAE 输出以减少 VRAM 使用。除 flip_aug 之外的任何增强都将不可用。此外,整体学习速度略快。
|
||||
- `--min_snr_gamma`
|
||||
|
||||
指定最小 SNR 加权策略。细节是[这里](https://github.com/kohya-ss/sd-scripts/pull/308)请参阅。论文中推荐`5`。
|
||||
|
||||
## 优化器相关
|
||||
|
||||
- `--optimizer_type`
|
||||
-- 指定优化器类型。您可以指定
|
||||
- AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)
|
||||
- 与过去版本中未指定选项时相同
|
||||
- AdamW8bit : 同上
|
||||
- 与过去版本中指定的 --use_8bit_adam 相同
|
||||
- Lion : https://github.com/lucidrains/lion-pytorch
|
||||
- 与过去版本中指定的 --use_lion_optimizer 相同
|
||||
- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
|
||||
- SGDNesterov8bit : 引数同上
|
||||
- DAdaptation : https://github.com/facebookresearch/dadaptation
|
||||
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
|
||||
- 任何优化器
|
||||
|
||||
- `--learning_rate`
|
||||
|
||||
指定学习率。合适的学习率取决于学习脚本,所以请参考每个解释。
|
||||
- `--lr_scheduler` / `--lr_warmup_steps` / `--lr_scheduler_num_cycles` / `--lr_scheduler_power`
|
||||
|
||||
学习率的调度程序相关规范。
|
||||
|
||||
使用 lr_scheduler 选项,您可以从线性、余弦、cosine_with_restarts、多项式、常数、constant_with_warmup 或任何调度程序中选择学习率调度程序。默认值是常量。
|
||||
|
||||
使用 lr_warmup_steps,您可以指定预热调度程序的步数(逐渐改变学习率)。
|
||||
|
||||
lr_scheduler_num_cycles 是 cosine with restarts 调度器中的重启次数,lr_scheduler_power 是多项式调度器中的多项式幂。
|
||||
|
||||
有关详细信息,请自行研究。
|
||||
|
||||
要使用任何调度程序,请像使用任何优化器一样使用“--scheduler_args”指定可选参数。
|
||||
### 关于指定优化器
|
||||
|
||||
使用 --optimizer_args 选项指定优化器选项参数。可以以key=value的格式指定多个值。此外,您可以指定多个值,以逗号分隔。例如,要指定 AdamW 优化器的参数,``--optimizer_args weight_decay=0.01 betas=.9,.999``。
|
||||
|
||||
指定可选参数时,请检查每个优化器的规格。
|
||||
一些优化器有一个必需的参数,如果省略它会自动添加(例如 SGDNesterov 的动量)。检查控制台输出。
|
||||
|
||||
D-Adaptation 优化器自动调整学习率。学习率选项指定的值不是学习率本身,而是D-Adaptation决定的学习率的应用率,所以通常指定1.0。如果您希望 Text Encoder 的学习率是 U-Net 的一半,请指定 ``--text_encoder_lr=0.5 --unet_lr=1.0``。
|
||||
如果指定 relative_step=True,AdaFactor 优化器可以自动调整学习率(如果省略,将默认添加)。自动调整时,学习率调度器被迫使用 adafactor_scheduler。此外,指定 scale_parameter 和 warmup_init 似乎也不错。
|
||||
|
||||
自动调整的选项类似于``--optimizer_args "relative_step=True" "scale_parameter=True" "warmup_init=True"``。
|
||||
|
||||
如果您不想自动调整学习率,请添加可选参数 ``relative_step=False``。在那种情况下,似乎建议将 constant_with_warmup 用于学习率调度程序,而不要为梯度剪裁范数。所以参数就像``--optimizer_type=adafactor --optimizer_args "relative_step=False" --lr_scheduler="constant_with_warmup" --max_grad_norm=0.0``。
|
||||
|
||||
### 使用任何优化器
|
||||
|
||||
使用 ``torch.optim`` 优化器时,仅指定类名(例如 ``--optimizer_type=RMSprop``),使用其他模块的优化器时,指定“模块名.类名”。(例如``--optimizer_type=bitsandbytes.optim.lamb.LAMB``)。
|
||||
|
||||
(内部仅通过 importlib 未确认操作。如果需要,请安装包。)
|
||||
<!--
|
||||
## 使用任意大小的图像进行训练 --resolution
|
||||
你可以在广场外学习。请在分辨率中指定“宽度、高度”,如“448,640”。宽度和高度必须能被 64 整除。匹配训练图像和正则化图像的大小。
|
||||
|
||||
就我个人而言,我经常生成垂直长的图像,所以我有时会用“448、640”来学习。
|
||||
|
||||
## 纵横比分桶 --enable_bucket / --min_bucket_reso / --max_bucket_reso
|
||||
它通过指定 enable_bucket 选项来启用。 Stable Diffusion 在 512x512 分辨率下训练,但也在 256x768 和 384x640 等分辨率下训练。
|
||||
|
||||
如果指定此选项,则不需要将训练图像和正则化图像统一为特定分辨率。从多种分辨率(纵横比)中进行选择,并在该分辨率下学习。
|
||||
由于分辨率为 64 像素,纵横比可能与原始图像不完全相同。
|
||||
|
||||
您可以使用 min_bucket_reso 选项指定分辨率的最小大小,使用 max_bucket_reso 指定最大大小。默认值分别为 256 和 1024。
|
||||
例如,将最小尺寸指定为 384 将不会使用 256x1024 或 320x768 等分辨率。
|
||||
如果将分辨率增加到 768x768,您可能需要将 1280 指定为最大尺寸。
|
||||
|
||||
启用 Aspect Ratio Ratio Bucketing 时,最好准备具有与训练图像相似的各种分辨率的正则化图像。
|
||||
|
||||
(因为一批中的图像不偏向于训练图像和正则化图像。
|
||||
|
||||
## 扩充 --color_aug / --flip_aug
|
||||
增强是一种通过在学习过程中动态改变数据来提高模型性能的方法。在使用 color_aug 巧妙地改变色调并使用 flip_aug 左右翻转的同时学习。
|
||||
|
||||
由于数据是动态变化的,因此不能与 cache_latents 选项一起指定。
|
||||
|
||||
## 使用 fp16 梯度训练(实验特征)--full_fp16
|
||||
如果指定 full_fp16 选项,梯度从普通 float32 变为 float16 (fp16) 并学习(它似乎是 full fp16 学习而不是混合精度)。
|
||||
结果,似乎 SD1.x 512x512 大小可以在 VRAM 使用量小于 8GB 的情况下学习,而 SD2.x 512x512 大小可以在 VRAM 使用量小于 12GB 的情况下学习。
|
||||
|
||||
预先在加速配置中指定 fp16,并可选择设置 ``mixed_precision="fp16"``(bf16 不起作用)。
|
||||
|
||||
为了最大限度地减少内存使用,请使用 xformers、use_8bit_adam、cache_latents、gradient_checkpointing 选项并将 train_batch_size 设置为 1。
|
||||
|
||||
(如果你负担得起,逐步增加 train_batch_size 应该会提高一点精度。)
|
||||
|
||||
它是通过修补 PyTorch 源代码实现的(已通过 PyTorch 1.12.1 和 1.13.0 确认)。准确率会大幅下降,途中学习失败的概率也会增加。
|
||||
学习率和步数的设置似乎很严格。请注意它们并自行承担使用它们的风险。
|
||||
-->
|
||||
|
||||
# 创建元数据文件
|
||||
|
||||
## 准备教师资料
|
||||
|
||||
如上所述准备好你要学习的图像数据,放在任意文件夹中。
|
||||
|
||||
例如,存储这样的图像:
|
||||
|
||||

|
||||
|
||||
## 自动字幕
|
||||
|
||||
如果您只想学习没有标题的标签,请跳过。
|
||||
|
||||
另外,手动准备字幕时,请准备在与教师数据图像相同的目录下,文件名相同,扩展名.caption等。每个文件应该是只有一行的文本文件。
|
||||
### 使用 BLIP 添加字幕
|
||||
|
||||
最新版本不再需要 BLIP 下载、权重下载和额外的虚拟环境。按原样工作。
|
||||
|
||||
运行 finetune 文件夹中的 make_captions.py。
|
||||
|
||||
```
|
||||
python finetune\make_captions.py --batch_size <バッチサイズ> <教師データフォルダ>
|
||||
```
|
||||
|
||||
如果batch size为8,训练数据放在父文件夹train_data中,则会如下所示
|
||||
```
|
||||
python finetune\make_captions.py --batch_size 8 ..\train_data
|
||||
```
|
||||
|
||||
字幕文件创建在与教师数据图像相同的目录中,具有相同的文件名和扩展名.caption。
|
||||
|
||||
根据 GPU 的 VRAM 容量增加或减少 batch_size。越大越快(我认为 12GB 的 VRAM 可以多一点)。
|
||||
您可以使用 max_length 选项指定标题的最大长度。默认值为 75。如果使用 225 的令牌长度训练模型,它可能会更长。
|
||||
您可以使用 caption_extension 选项更改标题扩展名。默认为 .caption(.txt 与稍后描述的 DeepDanbooru 冲突)。
|
||||
如果有多个教师数据文件夹,则对每个文件夹执行。
|
||||
|
||||
请注意,推理是随机的,因此每次运行时结果都会发生变化。如果要修复它,请使用 --seed 选项指定一个随机数种子,例如 `--seed 42`。
|
||||
|
||||
其他的选项,请参考help with `--help`(好像没有文档说明参数的含义,得看源码)。
|
||||
|
||||
默认情况下,会生成扩展名为 .caption 的字幕文件。
|
||||
|
||||

|
||||
|
||||
例如,标题如下:
|
||||
|
||||

|
||||
|
||||
## 由 DeepDanbooru 标记
|
||||
|
||||
如果不想给danbooru标签本身打标签,请继续“标题和标签信息的预处理”。
|
||||
|
||||
标记是使用 DeepDanbooru 或 WD14Tagger 完成的。 WD14Tagger 似乎更准确。如果您想使用 WD14Tagger 进行标记,请跳至下一章。
|
||||
### 环境布置
|
||||
|
||||
将 DeepDanbooru https://github.com/KichangKim/DeepDanbooru 克隆到您的工作文件夹中,或下载并展开 zip。我解压缩了它。
|
||||
另外,从 DeepDanbooru 发布页面 https://github.com/KichangKim/DeepDanbooru/releases 上的“DeepDanbooru 预训练模型 v3-20211112-sgd-e28”的资产下载 deepdanbooru-v3-20211112-sgd-e28.zip 并解压到 DeepDanbooru 文件夹。
|
||||
|
||||
从下面下载。单击以打开资产并从那里下载。
|
||||
|
||||

|
||||
|
||||
做一个这样的目录结构
|
||||
|
||||

|
||||
为扩散器环境安装必要的库。进入 DeepDanbooru 文件夹并安装它(我认为它实际上只是添加了 tensorflow-io)。
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
接下来,安装 DeepDanbooru 本身。
|
||||
|
||||
```
|
||||
pip install .
|
||||
```
|
||||
|
||||
这样就完成了标注环境的准备工作。
|
||||
|
||||
### 实施标记
|
||||
转到 DeepDanbooru 的文件夹并运行 deepdanbooru 进行标记。
|
||||
```
|
||||
deepdanbooru evaluate <教师资料夹> --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt
|
||||
```
|
||||
|
||||
如果将训练数据放在父文件夹train_data中,则如下所示。
|
||||
```
|
||||
deepdanbooru evaluate ../train_data --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt
|
||||
```
|
||||
|
||||
在与教师数据图像相同的目录中创建具有相同文件名和扩展名.txt 的标记文件。它很慢,因为它是一个接一个地处理的。
|
||||
|
||||
如果有多个教师数据文件夹,则对每个文件夹执行。
|
||||
|
||||
它生成如下。
|
||||
|
||||

|
||||
|
||||
它会被这样标记(信息量很大...)。
|
||||
|
||||

|
||||
|
||||
## WD14Tagger标记为
|
||||
|
||||
此过程使用 WD14Tagger 而不是 DeepDanbooru。
|
||||
|
||||
使用 Mr. Automatic1111 的 WebUI 中使用的标记器。我参考了这个 github 页面上的信息 (https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger)。
|
||||
|
||||
初始环境维护所需的模块已经安装。权重自动从 Hugging Face 下载。
|
||||
### 实施标记
|
||||
|
||||
运行脚本以进行标记。
|
||||
```
|
||||
python tag_images_by_wd14_tagger.py --batch_size <バッチサイズ> <教師データフォルダ>
|
||||
```
|
||||
|
||||
如果将训练数据放在父文件夹train_data中,则如下所示
|
||||
```
|
||||
python tag_images_by_wd14_tagger.py --batch_size 4 ..\train_data
|
||||
```
|
||||
|
||||
模型文件将在首次启动时自动下载到 wd14_tagger_model 文件夹(文件夹可以在选项中更改)。它将如下所示。
|
||||

|
||||
|
||||
在与教师数据图像相同的目录中创建具有相同文件名和扩展名.txt 的标记文件。
|
||||

|
||||
|
||||

|
||||
|
||||
使用 thresh 选项,您可以指定确定的标签的置信度数以附加标签。默认值为 0.35,与 WD14Tagger 示例相同。较低的值给出更多的标签,但准确性较低。
|
||||
|
||||
根据 GPU 的 VRAM 容量增加或减少 batch_size。越大越快(我认为 12GB 的 VRAM 可以多一点)。您可以使用 caption_extension 选项更改标记文件扩展名。默认为 .txt。
|
||||
|
||||
您可以使用 model_dir 选项指定保存模型的文件夹。
|
||||
|
||||
此外,如果指定 force_download 选项,即使有保存目标文件夹,也会重新下载模型。
|
||||
|
||||
如果有多个教师数据文件夹,则对每个文件夹执行。
|
||||
|
||||
## 预处理字幕和标签信息
|
||||
|
||||
将字幕和标签作为元数据合并到一个文件中,以便从脚本中轻松处理。
|
||||
### 字幕预处理
|
||||
|
||||
要将字幕放入元数据,请在您的工作文件夹中运行以下命令(如果您不使用字幕进行学习,则不需要运行它)(它实际上是一行,依此类推)。指定 `--full_path` 选项以将图像文件的完整路径存储在元数据中。如果省略此选项,则会记录相对路径,但 .toml 文件中需要单独的文件夹规范。
|
||||
```
|
||||
python merge_captions_to_metadata.py --full_path <教师资料夹>
|
||||
--in_json <要读取的元数据文件名> <元数据文件名>
|
||||
```
|
||||
|
||||
元数据文件名是任意名称。
|
||||
如果训练数据为train_data,没有读取元数据文件,元数据文件为meta_cap.json,则会如下。
|
||||
```
|
||||
python merge_captions_to_metadata.py --full_path train_data meta_cap.json
|
||||
```
|
||||
|
||||
您可以使用 caption_extension 选项指定标题扩展。
|
||||
|
||||
如果有多个教师数据文件夹,请指定 full_path 参数并为每个文件夹执行。
|
||||
```
|
||||
python merge_captions_to_metadata.py --full_path
|
||||
train_data1 meta_cap1.json
|
||||
python merge_captions_to_metadata.py --full_path --in_json meta_cap1.json
|
||||
train_data2 meta_cap2.json
|
||||
```
|
||||
如果省略in_json,如果有写入目标元数据文件,将从那里读取并覆盖。
|
||||
|
||||
__* 每次重写 in_json 选项和写入目标并写入单独的元数据文件是安全的。 __
|
||||
### 标签预处理
|
||||
|
||||
同样,标签也收集在元数据中(如果标签不用于学习,则无需这样做)。
|
||||
```
|
||||
python merge_dd_tags_to_metadata.py --full_path <教师资料夹>
|
||||
--in_json <要读取的元数据文件名> <要写入的元数据文件名>
|
||||
```
|
||||
|
||||
同样的目录结构,读取meta_cap.json和写入meta_cap_dd.json时,会是这样的。
|
||||
```
|
||||
python merge_dd_tags_to_metadata.py --full_path train_data --in_json meta_cap.json meta_cap_dd.json
|
||||
```
|
||||
|
||||
如果有多个教师数据文件夹,请指定 full_path 参数并为每个文件夹执行。
|
||||
|
||||
```
|
||||
python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap2.json
|
||||
train_data1 meta_cap_dd1.json
|
||||
python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap_dd1.json
|
||||
train_data2 meta_cap_dd2.json
|
||||
```
|
||||
|
||||
如果省略in_json,如果有写入目标元数据文件,将从那里读取并覆盖。
|
||||
__※ 通过每次重写 in_json 选项和写入目标,写入单独的元数据文件是安全的。 __
|
||||
### 标题和标签清理
|
||||
|
||||
到目前为止,标题和DeepDanbooru标签已经被整理到元数据文件中。然而,自动标题生成的标题存在表达差异等微妙问题(※),而标签中可能包含下划线和评级(DeepDanbooru的情况下)。因此,最好使用编辑器的替换功能清理标题和标签。
|
||||
|
||||
※例如,如果要学习动漫中的女孩,标题可能会包含girl/girls/woman/women等不同的表达方式。另外,将"anime girl"简单地替换为"girl"可能更合适。
|
||||
|
||||
我们提供了用于清理的脚本,请根据情况编辑脚本并使用它。
|
||||
|
||||
(不需要指定教师数据文件夹。将清理元数据中的所有数据。)
|
||||
|
||||
```
|
||||
python clean_captions_and_tags.py <要读取的元数据文件名> <要写入的元数据文件名>
|
||||
```
|
||||
|
||||
--in_json 请注意,不包括在内。例如:
|
||||
|
||||
```
|
||||
python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json
|
||||
```
|
||||
|
||||
标题和标签的预处理现已完成。
|
||||
|
||||
## 预先获取 latents
|
||||
|
||||
※ 这一步骤并非必须。即使省略此步骤,也可以在训练过程中获取 latents。但是,如果在训练时执行 `random_crop` 或 `color_aug` 等操作,则无法预先获取 latents(因为每次图像都会改变)。如果不进行预先获取,则可以使用到目前为止的元数据进行训练。
|
||||
|
||||
提前获取图像的潜在表达并保存到磁盘上。这样可以加速训练过程。同时进行 bucketing(根据宽高比对训练数据进行分类)。
|
||||
|
||||
请在工作文件夹中输入以下内容。
|
||||
|
||||
```
|
||||
python prepare_buckets_latents.py --full_path <教师资料夹>
|
||||
<要读取的元数据文件名> <要写入的元数据文件名>
|
||||
<要微调的模型名称或检查点>
|
||||
--batch_size <批量大小>
|
||||
--max_resolution <分辨率宽、高>
|
||||
--mixed_precision <准确性>
|
||||
```
|
||||
|
||||
如果要从meta_clean.json中读取元数据,并将其写入meta_lat.json,使用模型model.ckpt,批处理大小为4,训练分辨率为512*512,精度为no(float32),则应如下所示。
|
||||
```
|
||||
python prepare_buckets_latents.py --full_path
|
||||
train_data meta_clean.json meta_lat.json model.ckpt
|
||||
--batch_size 4 --max_resolution 512,512 --mixed_precision no
|
||||
```
|
||||
|
||||
教师数据文件夹中,latents以numpy的npz格式保存。
|
||||
|
||||
您可以使用--min_bucket_reso选项指定最小分辨率大小,--max_bucket_reso指定最大大小。默认值分别为256和1024。例如,如果指定最小大小为384,则将不再使用分辨率为256 * 1024或320 * 768等。如果将分辨率增加到768 * 768等较大的值,则最好将最大大小指定为1280等。
|
||||
|
||||
如果指定--flip_aug选项,则进行左右翻转的数据增强。虽然这可以使数据量伪造一倍,但如果数据不是左右对称的(例如角色外观、发型等),则可能会导致训练不成功。
|
||||
|
||||
对于翻转的图像,也会获取latents,并保存名为\ *_flip.npz的文件,这是一个简单的实现。在fline_tune.py中不需要特定的选项。如果有带有\_flip的文件,则会随机加载带有和不带有flip的文件。
|
||||
|
||||
即使VRAM为12GB,批量大小也可以稍微增加。分辨率以“宽度,高度”的形式指定,必须是64的倍数。分辨率直接影响fine tuning时的内存大小。在12GB VRAM中,512,512似乎是极限(*)。如果有16GB,则可以将其提高到512,704或512,768。即使分辨率为256,256等,VRAM 8GB也很难承受(因为参数、优化器等与分辨率无关,需要一定的内存)。
|
||||
|
||||
*有报道称,在batch size为1的训练中,使用12GB VRAM和640,640的分辨率。
|
||||
|
||||
以下是bucketing结果的显示方式。
|
||||
|
||||

|
||||
|
||||
如果有多个教师数据文件夹,请指定 full_path 参数并为每个文件夹执行
|
||||
|
||||
```
|
||||
python prepare_buckets_latents.py --full_path
|
||||
train_data1 meta_clean.json meta_lat1.json model.ckpt
|
||||
--batch_size 4 --max_resolution 512,512 --mixed_precision no
|
||||
|
||||
python prepare_buckets_latents.py --full_path
|
||||
train_data2 meta_lat1.json meta_lat2.json model.ckpt
|
||||
--batch_size 4 --max_resolution 512,512 --mixed_precision no
|
||||
|
||||
```
|
||||
可以将读取源和写入目标设为相同,但分开设定更为安全。
|
||||
|
||||
__※建议每次更改参数并将其写入另一个元数据文件,以确保安全性。__
|
||||
86
train_db.py
86
train_db.py
@@ -23,7 +23,7 @@ from library.config_util import (
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
|
||||
|
||||
|
||||
def train(args):
|
||||
@@ -118,12 +118,14 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size)
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 学習を準備する:モデルを適切な状態にする
|
||||
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
|
||||
unet.requires_grad_(True) # 念のため追加
|
||||
@@ -202,9 +204,7 @@ def train(args):
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
# resumeする
|
||||
if args.resume is not None:
|
||||
print(f"resume training from state: {args.resume}")
|
||||
accelerator.load_state(args.resume)
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
@@ -232,7 +232,7 @@ def train(args):
|
||||
)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("dreambooth")
|
||||
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name)
|
||||
|
||||
loss_list = []
|
||||
loss_total = 0.0
|
||||
@@ -273,10 +273,20 @@ def train(args):
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(
|
||||
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
|
||||
)
|
||||
if args.weighted_captions:
|
||||
encoder_hidden_states = get_weighted_text_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
batch["captions"],
|
||||
accelerator.device,
|
||||
args.max_token_length // 75 if args.max_token_length else 1,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
else:
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(
|
||||
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
|
||||
)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||
@@ -327,6 +337,27 @@ def train(args):
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
False,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
unwrap_model(text_encoder),
|
||||
unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
@@ -356,21 +387,24 @@ def train(args):
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_epoch_end(
|
||||
args,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
unwrap_model(text_encoder),
|
||||
unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
if accelerator.is_main_process:
|
||||
# checking for saving is in util
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
True,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
unwrap_model(text_encoder),
|
||||
unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
@@ -381,7 +415,7 @@ def train(args):
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state:
|
||||
if args.save_state and is_main_process:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
162
train_db_README-zh.md
Normal file
162
train_db_README-zh.md
Normal file
@@ -0,0 +1,162 @@
|
||||
这是DreamBooth的指南。
|
||||
|
||||
请同时查看[关于学习的通用文档](./train_README-zh.md)。
|
||||
|
||||
# 概要
|
||||
|
||||
DreamBooth是一种将特定主题添加到图像生成模型中进行学习,并使用特定识别子生成它的技术。论文链接。
|
||||
|
||||
具体来说,它可以将角色和绘画风格等添加到Stable Diffusion模型中进行学习,并使用特定的单词(例如`shs`)来调用(呈现在生成的图像中)。
|
||||
|
||||
脚本基于Diffusers的DreamBooth,但添加了以下功能(一些功能已在原始脚本中得到支持)。
|
||||
|
||||
脚本的主要功能如下:
|
||||
|
||||
- 使用8位Adam优化器和潜在变量的缓存来节省内存(与Shivam Shrirao版相似)。
|
||||
- 使用xformers来节省内存。
|
||||
- 不仅支持512x512,还支持任意尺寸的训练。
|
||||
- 通过数据增强来提高质量。
|
||||
- 支持DreamBooth和Text Encoder + U-Net的微调。
|
||||
- 支持以Stable Diffusion格式读写模型。
|
||||
- 支持Aspect Ratio Bucketing。
|
||||
- 支持Stable Diffusion v2.0。
|
||||
|
||||
# 训练步骤
|
||||
|
||||
请先参阅此存储库的README以进行环境设置。
|
||||
|
||||
## 准备数据
|
||||
|
||||
请参阅[有关准备训练数据的说明](./train_README-zh.md)。
|
||||
|
||||
## 运行训练
|
||||
|
||||
运行脚本。以下是最大程度地节省内存的命令(实际上,这将在一行中输入)。请根据需要修改每行。它似乎需要约12GB的VRAM才能运行。
|
||||
```
|
||||
accelerate launch --num_cpu_threads_per_process 1 train_db.py
|
||||
--pretrained_model_name_or_path=<.ckpt或.safetensord或Diffusers版模型的目录>
|
||||
--dataset_config=<数据准备时创建的.toml文件>
|
||||
--output_dir=<训练模型的输出目录>
|
||||
--output_name=<训练模型输出时的文件名>
|
||||
--save_model_as=safetensors
|
||||
--prior_loss_weight=1.0
|
||||
--max_train_steps=1600
|
||||
--learning_rate=1e-6
|
||||
--optimizer_type="AdamW8bit"
|
||||
--xformers
|
||||
--mixed_precision="fp16"
|
||||
--cache_latents
|
||||
--gradient_checkpointing
|
||||
```
|
||||
`num_cpu_threads_per_process` 通常应该设置为1。
|
||||
|
||||
`pretrained_model_name_or_path` 指定要进行追加训练的基础模型。可以指定 Stable Diffusion 的 checkpoint 文件(.ckpt 或 .safetensors)、Diffusers 的本地模型目录或模型 ID(如 "stabilityai/stable-diffusion-2")。
|
||||
|
||||
`output_dir` 指定保存训练后模型的文件夹。在 `output_name` 中指定模型文件名,不包括扩展名。使用 `save_model_as` 指定以 safetensors 格式保存。
|
||||
|
||||
在 `dataset_config` 中指定 `.toml` 文件。初始批处理大小应为 `1`,以减少内存消耗。
|
||||
|
||||
`prior_loss_weight` 是正则化图像损失的权重。通常设为1.0。
|
||||
|
||||
将要训练的步数 `max_train_steps` 设置为1600。在这里,学习率 `learning_rate` 被设置为1e-6。
|
||||
|
||||
为了节省内存,设置 `mixed_precision="fp16"`(在 RTX30 系列及更高版本中也可以设置为 `bf16`)。同时指定 `gradient_checkpointing`。
|
||||
|
||||
为了使用内存消耗较少的 8bit AdamW 优化器(将模型优化为适合于训练数据的状态),指定 `optimizer_type="AdamW8bit"`。
|
||||
|
||||
指定 `xformers` 选项,并使用 xformers 的 CrossAttention。如果未安装 xformers 或出现错误(具体情况取决于环境,例如使用 `mixed_precision="no"`),则可以指定 `mem_eff_attn` 选项以使用省内存版的 CrossAttention(速度会变慢)。
|
||||
|
||||
为了节省内存,指定 `cache_latents` 选项以缓存 VAE 的输出。
|
||||
|
||||
如果有足够的内存,请编辑 `.toml` 文件将批处理大小增加到大约 `4`(可能会提高速度和精度)。此外,取消 `cache_latents` 选项可以进行数据增强。
|
||||
|
||||
### 常用选项
|
||||
|
||||
对于以下情况,请参阅“常用选项”部分。
|
||||
|
||||
- 学习 Stable Diffusion 2.x 或其衍生模型。
|
||||
- 学习基于 clip skip 大于等于2的模型。
|
||||
- 学习超过75个令牌的标题。
|
||||
|
||||
### 关于DreamBooth中的步数
|
||||
|
||||
为了实现省内存化,该脚本中每个步骤的学习次数减半(因为学习和正则化的图像在训练时被分为不同的批次)。
|
||||
|
||||
要进行与原始Diffusers版或XavierXiao的Stable Diffusion版几乎相同的学习,请将步骤数加倍。
|
||||
|
||||
(虽然在将学习图像和正则化图像整合后再打乱顺序,但我认为对学习没有太大影响。)
|
||||
|
||||
关于DreamBooth的批量大小
|
||||
|
||||
与像LoRA这样的学习相比,为了训练整个模型,内存消耗量会更大(与微调相同)。
|
||||
|
||||
关于学习率
|
||||
|
||||
在Diffusers版中,学习率为5e-6,而在Stable Diffusion版中为1e-6,因此在上面的示例中指定了1e-6。
|
||||
|
||||
当使用旧格式的数据集指定命令行时
|
||||
|
||||
使用选项指定分辨率和批量大小。命令行示例如下。
|
||||
```
|
||||
accelerate launch --num_cpu_threads_per_process 1 train_db.py
|
||||
--pretrained_model_name_or_path=<.ckpt或.safetensord或Diffusers版模型的目录>
|
||||
--train_data_dir=<训练数据的目录>
|
||||
--reg_data_dir=<正则化图像的目录>
|
||||
--output_dir=<训练后模型的输出目录>
|
||||
--output_name=<训练后模型输出文件的名称>
|
||||
--prior_loss_weight=1.0
|
||||
--resolution=512
|
||||
--train_batch_size=1
|
||||
--learning_rate=1e-6
|
||||
--max_train_steps=1600
|
||||
--use_8bit_adam
|
||||
--xformers
|
||||
--mixed_precision="bf16"
|
||||
--cache_latents
|
||||
--gradient_checkpointing
|
||||
```
|
||||
|
||||
## 使用训练好的模型生成图像
|
||||
|
||||
训练完成后,将在指定的文件夹中以指定的名称输出safetensors文件。
|
||||
|
||||
对于v1.4/1.5和其他派生模型,可以在此模型中使用Automatic1111先生的WebUI进行推断。请将其放置在models\Stable-diffusion文件夹中。
|
||||
|
||||
对于使用v2.x模型在WebUI中生成图像的情况,需要单独的.yaml文件来描述模型的规格。对于v2.x base,需要v2-inference.yaml,对于768/v,则需要v2-inference-v.yaml。请将它们放置在相同的文件夹中,并将文件扩展名之前的部分命名为与模型相同的名称。
|
||||

|
||||
|
||||
每个yaml文件都在[Stability AI的SD2.0存储库](https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion)……之中。
|
||||
|
||||
# DreamBooth的其他主要选项
|
||||
|
||||
有关所有选项的详细信息,请参阅另一份文档。
|
||||
|
||||
## 不在中途开始对文本编码器进行训练 --stop_text_encoder_training
|
||||
|
||||
如果在stop_text_encoder_training选项中指定一个数字,则在该步骤之后,将不再对文本编码器进行训练,只会对U-Net进行训练。在某些情况下,可能会期望提高精度。
|
||||
|
||||
(我们推测可能会有时候仅仅文本编码器会过度学习,而这样做可以避免这种情况,但详细影响尚不清楚。)
|
||||
|
||||
## 不进行分词器的填充 --no_token_padding
|
||||
|
||||
如果指定no_token_padding选项,则不会对分词器的输出进行填充(与Diffusers版本的旧DreamBooth相同)。
|
||||
|
||||
<!--
|
||||
如果使用分桶(bucketing)和数据增强(augmentation),则使用示例如下:
|
||||
```
|
||||
accelerate launch --num_cpu_threads_per_process 8 train_db.py
|
||||
--pretrained_model_name_or_path=<.ckpt或.safetensord或Diffusers版模型的目录>
|
||||
--train_data_dir=<训练数据的目录>
|
||||
--reg_data_dir=<正则化图像的目录>
|
||||
--output_dir=<训练后模型的输出目录>
|
||||
--resolution=768,512
|
||||
--train_batch_size=20 --learning_rate=5e-6 --max_train_steps=800
|
||||
--use_8bit_adam --xformers --mixed_precision="bf16"
|
||||
--save_every_n_epochs=1 --save_state --save_precision="bf16"
|
||||
--logging_dir=logs
|
||||
--enable_bucket --min_bucket_reso=384 --max_bucket_reso=1280
|
||||
--color_aug --flip_aug --gradient_checkpointing --seed 42
|
||||
```
|
||||
|
||||
|
||||
-->
|
||||
189
train_network.py
189
train_network.py
@@ -24,24 +24,40 @@ from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
|
||||
|
||||
|
||||
# TODO 他のスクリプトと共通化する
|
||||
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
||||
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
||||
|
||||
if args.network_train_unet_only:
|
||||
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
|
||||
elif args.network_train_text_encoder_only:
|
||||
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
|
||||
else:
|
||||
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
|
||||
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
|
||||
lrs = lr_scheduler.get_last_lr()
|
||||
|
||||
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
|
||||
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
||||
if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block)
|
||||
if args.network_train_unet_only:
|
||||
logs["lr/unet"] = float(lrs[0])
|
||||
elif args.network_train_text_encoder_only:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
else:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder
|
||||
|
||||
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
|
||||
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
||||
else:
|
||||
idx = 0
|
||||
if not args.network_train_unet_only:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
idx = 1
|
||||
|
||||
for i in range(idx, len(lrs)):
|
||||
logs[f"lr/group{i}"] = float(lrs[i])
|
||||
if args.optimizer_type.lower() == "DAdaptation".lower():
|
||||
logs[f"lr/d*lr/group{i}"] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
||||
)
|
||||
|
||||
return logs
|
||||
|
||||
@@ -56,8 +72,9 @@ def train(args):
|
||||
use_dreambooth_method = args.in_json is None
|
||||
use_user_config = args.dataset_config is not None
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
if args.seed is None:
|
||||
args.seed = random.randint(0, 2**32)
|
||||
set_seed(args.seed)
|
||||
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
|
||||
@@ -99,10 +116,10 @@ def train(args):
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value('i',0)
|
||||
current_step = Value('i',0)
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
|
||||
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
@@ -146,7 +163,6 @@ def train(args):
|
||||
torch.cuda.empty_cache()
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
|
||||
@@ -156,12 +172,14 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size)
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# prepare network
|
||||
import sys
|
||||
|
||||
@@ -180,14 +198,17 @@ def train(args):
|
||||
if network is None:
|
||||
return
|
||||
|
||||
if args.network_weights is not None:
|
||||
print("load network weights from:", args.network_weights)
|
||||
network.load_weights(args.network_weights)
|
||||
if hasattr(network, "prepare_network"):
|
||||
network.prepare_network(args)
|
||||
|
||||
train_unet = not args.network_train_text_encoder_only
|
||||
train_text_encoder = not args.network_train_unet_only
|
||||
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
|
||||
|
||||
if args.network_weights is not None:
|
||||
info = network.load_weights(args.network_weights)
|
||||
print(f"load network weights from {args.network_weights}: {info}")
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
@@ -196,7 +217,15 @@ def train(args):
|
||||
# 学習に必要なクラスを準備する
|
||||
print("prepare optimizer, data loader etc.")
|
||||
|
||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
||||
# 後方互換性を確保するよ
|
||||
try:
|
||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
|
||||
except TypeError:
|
||||
print(
|
||||
"Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
|
||||
)
|
||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
||||
|
||||
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
@@ -214,7 +243,9 @@ def train(args):
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
if is_main_process:
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
@@ -283,9 +314,7 @@ def train(args):
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
# resumeする
|
||||
if args.resume is not None:
|
||||
print(f"resume training from state: {args.resume}")
|
||||
accelerator.load_state(args.resume)
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
@@ -346,6 +375,7 @@ def train(args):
|
||||
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
|
||||
"ss_face_crop_aug_range": args.face_crop_aug_range,
|
||||
"ss_prior_loss_weight": args.prior_loss_weight,
|
||||
"ss_min_snr_gamma": args.min_snr_gamma,
|
||||
}
|
||||
|
||||
if use_user_config:
|
||||
@@ -474,8 +504,6 @@ def train(args):
|
||||
# add extra args
|
||||
if args.network_args:
|
||||
metadata["ss_network_args"] = json.dumps(net_kwargs)
|
||||
# for key, value in net_kwargs.items():
|
||||
# metadata["ss_arg_" + key] = value
|
||||
|
||||
# model name and hash
|
||||
if args.pretrained_model_name_or_path is not None:
|
||||
@@ -510,15 +538,42 @@ def train(args):
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||
)
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("network_train")
|
||||
accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name)
|
||||
|
||||
loss_list = []
|
||||
loss_total = 0.0
|
||||
del train_dataset_group
|
||||
|
||||
# if hasattr(network, "on_step_start"):
|
||||
# on_step_start = network.on_step_start
|
||||
# else:
|
||||
# on_step_start = lambda *args, **kwargs: None
|
||||
|
||||
# function for saving/removing
|
||||
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
metadata["ss_training_finished_at"] = str(time.time())
|
||||
metadata["ss_steps"] = str(steps)
|
||||
metadata["ss_epoch"] = str(epoch_no)
|
||||
|
||||
unwrapped_nw.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
||||
|
||||
def remove_model(old_ckpt_name):
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
if is_main_process:
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch+1
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
metadata["ss_epoch"] = str(epoch + 1)
|
||||
|
||||
@@ -527,6 +582,8 @@ def train(args):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(network):
|
||||
# on_step_start(text_encoder, unet)
|
||||
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
@@ -538,9 +595,18 @@ def train(args):
|
||||
|
||||
with torch.set_grad_enabled(train_text_encoder):
|
||||
# Get the text embedding for conditioning
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
|
||||
|
||||
if args.weighted_captions:
|
||||
encoder_hidden_states = get_weighted_text_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
batch["captions"],
|
||||
accelerator.device,
|
||||
args.max_token_length // 75 if args.max_token_length else 1,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
else:
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
@@ -593,6 +659,21 @@ def train(args):
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
save_model(ckpt_name, unwrap_model(network), global_step, epoch)
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
||||
|
||||
remove_step_no = train_util.get_remove_step_no(args, global_step)
|
||||
if remove_step_no is not None:
|
||||
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if epoch == 0:
|
||||
loss_list.append(current_loss)
|
||||
@@ -617,33 +698,26 @@ def train(args):
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 指定エポックごとにモデルを保存
|
||||
if args.save_every_n_epochs is not None:
|
||||
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||
if is_main_process and saving:
|
||||
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
||||
save_model(ckpt_name, unwrap_model(network), global_step, epoch + 1)
|
||||
|
||||
def save_func():
|
||||
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
metadata["ss_training_finished_at"] = str(time.time())
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
||||
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
||||
if remove_epoch_no is not None:
|
||||
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
def remove_old_func(old_epoch_no):
|
||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
if is_main_process:
|
||||
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
||||
if saving and args.save_state:
|
||||
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# end of epoch
|
||||
|
||||
metadata["ss_epoch"] = str(num_train_epochs)
|
||||
# metadata["ss_epoch"] = str(num_train_epochs)
|
||||
metadata["ss_training_finished_at"] = str(time.time())
|
||||
|
||||
if is_main_process:
|
||||
@@ -651,20 +725,15 @@ def train(args):
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state:
|
||||
if is_main_process and args.save_state:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
if is_main_process:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||
ckpt_name = model_name + "." + args.save_model_as
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
print(f"save trained model to {ckpt_file}")
|
||||
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
|
||||
|
||||
print("model saved.")
|
||||
|
||||
|
||||
|
||||
@@ -12,11 +12,31 @@ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora)
|
||||
|
||||
[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
|
||||
|
||||
# 学習できるLoRAの種類
|
||||
|
||||
以下の二種類をサポートします。以下は当リポジトリ内の独自の名称です。
|
||||
|
||||
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
|
||||
|
||||
Linear およびカーネルサイズ 1x1 の Conv2d に適用されるLoRA
|
||||
|
||||
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
|
||||
|
||||
1.に加え、カーネルサイズ 3x3 の Conv2d に適用されるLoRA
|
||||
|
||||
LoRA-LierLaに比べ、LoRA-C3Liarは適用される層が増える分、高い精度が期待できるかもしれません。
|
||||
|
||||
また学習時は __DyLoRA__ を使用することもできます(後述します)。
|
||||
|
||||
## 学習したモデルに関する注意
|
||||
|
||||
cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。
|
||||
LoRA-LierLa は、AUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。
|
||||
|
||||
WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
|
||||
LoRA-C3Liarを使いWeb UIで生成するには、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
|
||||
|
||||
いずれも学習したLoRAのモデルを、Stable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージすることもできます。
|
||||
|
||||
cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。
|
||||
|
||||
# 学習の手順
|
||||
|
||||
@@ -31,9 +51,9 @@ WebUI等で画像生成する場合には、学習したLoRAのモデルを学
|
||||
|
||||
`train_network.py`を用います。
|
||||
|
||||
`train_network.py`では `--network_module` オプションに、学習対象のモジュール名を指定します。LoRAに対応するのはnetwork.loraとなりますので、それを指定してください。
|
||||
`train_network.py`では `--network_module` オプションに、学習対象のモジュール名を指定します。LoRAに対応するのは`network.lora`となりますので、それを指定してください。
|
||||
|
||||
なお学習率は通常のDreamBoothやfine tuningよりも高めの、1e-4程度を指定するとよいようです。
|
||||
なお学習率は通常のDreamBoothやfine tuningよりも高めの、`1e-4`~`1e-3`程度を指定するとよいようです。
|
||||
|
||||
以下はコマンドラインの例です。
|
||||
|
||||
@@ -56,6 +76,8 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
|
||||
--network_module=networks.lora
|
||||
```
|
||||
|
||||
このコマンドラインでは LoRA-LierLa が学習されます。
|
||||
|
||||
`--output_dir` オプションで指定したフォルダに、LoRAのモデルが保存されます。他のオプション、オプティマイザ等については [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」も参照してください。
|
||||
|
||||
その他、以下のオプションが指定できます。
|
||||
@@ -83,22 +105,143 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
|
||||
|
||||
`--network_train_unet_only` と `--network_train_text_encoder_only` の両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。
|
||||
|
||||
## LoRA を Conv2d に拡大して適用する
|
||||
# その他の学習方法
|
||||
|
||||
通常のLoRAは Linear およぴカーネルサイズ 1x1 の Conv2d にのみ適用されますが、カーネルサイズ 3x3 のConv2dに適用を拡大することもできます。
|
||||
## LoRA-C3Lier を学習する
|
||||
|
||||
`--network_args` に以下のように指定してください。`conv_dim` で Conv2d (3x3) の rank を、`conv_alpha` で alpha を指定してください。
|
||||
|
||||
```
|
||||
--network_args "conv_dim=1" "conv_alpha=1"
|
||||
--network_args "conv_dim=4" "conv_alpha=1"
|
||||
```
|
||||
|
||||
以下のように alpha 省略時は1になります。
|
||||
|
||||
```
|
||||
--network_args "conv_dim=1"
|
||||
--network_args "conv_dim=4"
|
||||
```
|
||||
|
||||
## DyLoRA
|
||||
|
||||
DyLoRAはこちらの論文で提案されたものです。[DyLoRA: Parameter Efficient Tuning of Pre-trained Models using Dynamic Search-Free Low-Rank Adaptation](https://arxiv.org/abs/2210.07558) 公式実装は[こちら](https://github.com/huawei-noah/KD-NLP/tree/main/DyLoRA)です。
|
||||
|
||||
論文によると、LoRAのrankは必ずしも高いほうが良いわけではなく、対象のモデル、データセット、タスクなどにより適切なrankを探す必要があるようです。DyLoRAを使うと、指定したdim(rank)以下のさまざまなrankで同時にLoRAを学習します。これにより最適なrankをそれぞれ学習して探す手間を省くことができます。
|
||||
|
||||
当リポジトリの実装は公式実装をベースに独自の拡張を加えています(そのため不具合などあるかもしれません)。
|
||||
|
||||
### 当リポジトリのDyLoRAの特徴
|
||||
|
||||
学習後のDyLoRAのモデルファイルはLoRAと互換性があります。また、モデルファイルから指定したdim(rank)以下の複数のdimのLoRAを抽出できます。
|
||||
|
||||
DyLoRA-LierLa、DyLoRA-C3Lierのどちらも学習できます。
|
||||
|
||||
### DyLoRAで学習する
|
||||
|
||||
`--network_module=networks.dylora` のように、DyLoRAに対応する`network.dylora`を指定してください。
|
||||
|
||||
また `--network_args` に、たとえば`--network_args "unit=4"`のように`unit`を指定します。`unit`はrankを分割する単位です。たとえば`--network_dim=16 --network_args "unit=4"` のように指定します。`unit`は`network_dim`を割り切れる値(`network_dim`は`unit`の倍数)としてください。
|
||||
|
||||
`unit`を指定しない場合は、`unit=1`として扱われます。
|
||||
|
||||
記述例は以下です。
|
||||
|
||||
```
|
||||
--network_module=networks.dylora --network_dim=16 --network_args "unit=4"
|
||||
|
||||
--network_module=networks.dylora --network_dim=32 --network_alpha=16 --network_args "unit=4"
|
||||
```
|
||||
|
||||
DyLoRA-C3Lierの場合は、`--network_args` に`"conv_dim=4"`のように`conv_dim`を指定します。通常のLoRAと異なり、`conv_dim`は`network_dim`と同じ値である必要があります。記述例は以下です。
|
||||
|
||||
```
|
||||
--network_module=networks.dylora --network_dim=16 --network_args "conv_dim=16" "unit=4"
|
||||
|
||||
--network_module=networks.dylora --network_dim=32 --network_alpha=16 --network_args "conv_dim=32" "conv_alpha=16" "unit=8"
|
||||
```
|
||||
|
||||
たとえばdim=16、unit=4(後述)で学習すると、4、8、12、16の4つのrankのLoRAを学習、抽出できます。抽出した各モデルで画像を生成し、比較することで、最適なrankのLoRAを選択できます。
|
||||
|
||||
その他のオプションは通常のLoRAと同じです。
|
||||
|
||||
※ `unit`は当リポジトリの独自拡張で、DyLoRAでは同dim(rank)の通常LoRAに比べると学習時間が長くなることが予想されるため、分割単位を大きくしたものです。
|
||||
|
||||
### DyLoRAのモデルからLoRAモデルを抽出する
|
||||
|
||||
`networks`フォルダ内の `extract_lora_from_dylora.py`を使用します。指定した`unit`単位で、DyLoRAのモデルからLoRAのモデルを抽出します。
|
||||
|
||||
コマンドラインはたとえば以下のようになります。
|
||||
|
||||
```powershell
|
||||
python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.safetensors" --save_to "foldername/dylora-model-split.safetensors" --unit 4
|
||||
```
|
||||
|
||||
`--model` にはDyLoRAのモデルファイルを指定します。`--save_to` には抽出したモデルを保存するファイル名を指定します(rankの数値がファイル名に付加されます)。`--unit` にはDyLoRAの学習時の`unit`を指定します。
|
||||
|
||||
## 階層別学習率
|
||||
|
||||
詳細は[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) をご覧ください。
|
||||
|
||||
フルモデルの25個のブロックの重みを指定できます。最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。
|
||||
|
||||
`--network_args` で以下の引数を指定してください。
|
||||
|
||||
- `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。
|
||||
- ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。
|
||||
- プリセットからの指定 : `"down_lr_weight=sine"` のように指定します(サインカーブで重みを指定します)。sine, cosine, linear, reverse_linear, zeros が指定可能です。また `"down_lr_weight=cosine+.25"` のように `+数値` を追加すると、指定した数値を加算します(0.25~1.25になります)。
|
||||
- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します。
|
||||
- `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。
|
||||
- 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。
|
||||
- `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。
|
||||
|
||||
### 階層別学習率コマンドライン指定例:
|
||||
|
||||
```powershell
|
||||
--network_args "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5" "mid_lr_weight=2.0" "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5"
|
||||
|
||||
--network_args "block_lr_zero_threshold=0.1" "down_lr_weight=sine+.5" "mid_lr_weight=1.5" "up_lr_weight=cosine+.5"
|
||||
```
|
||||
|
||||
### 階層別学習率tomlファイル指定例:
|
||||
|
||||
```toml
|
||||
network_args = [ "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5", "mid_lr_weight=2.0", "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5",]
|
||||
|
||||
network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_lr_weight=1.5", "up_lr_weight=cosine+.5", ]
|
||||
```
|
||||
|
||||
## 階層別dim (rank)
|
||||
|
||||
フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。
|
||||
|
||||
`--network_args` で以下の引数を指定してください。
|
||||
|
||||
- `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。
|
||||
- `block_alphas` : 各ブロックのalphaを指定します。block_dimsと同様に25個の数値を指定します。省略時はnetwork_alphaの値が使用されます。
|
||||
- `conv_block_dims` : LoRAをConv2d 3x3に拡張し、各ブロックのdim (rank)を指定します。
|
||||
- `conv_block_alphas` : LoRAをConv2d 3x3に拡張したときの各ブロックのalphaを指定します。省略時はconv_alphaの値が使用されます。
|
||||
|
||||
### 階層別dim (rank)コマンドライン指定例:
|
||||
|
||||
```powershell
|
||||
--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2"
|
||||
|
||||
--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "conv_block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"
|
||||
|
||||
--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"
|
||||
```
|
||||
|
||||
### 階層別dim (rank)tomlファイル指定例:
|
||||
|
||||
```toml
|
||||
network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2",]
|
||||
|
||||
network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2", "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2",]
|
||||
```
|
||||
|
||||
# その他のスクリプト
|
||||
|
||||
マージ等LoRAに関連するスクリプト群です。
|
||||
|
||||
## マージスクリプトについて
|
||||
|
||||
merge_lora.pyでStable DiffusionのモデルにLoRAの学習結果をマージしたり、複数のLoRAモデルをマージしたりできます。
|
||||
@@ -188,6 +331,73 @@ gen_img_diffusers.pyに、--network_module、--network_weightsの各オプショ
|
||||
|
||||
--network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。
|
||||
|
||||
## Diffusersのpipelineで生成する
|
||||
|
||||
以下の例を参考にしてください。必要なファイルはnetworks/lora.pyのみです。Diffusersのバージョンは0.10.2以外では動作しない可能性があります。
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from networks.lora import LoRAModule, create_network_from_weights
|
||||
from safetensors.torch import load_file
|
||||
|
||||
# if the ckpt is CompVis based, convert it to Diffusers beforehand with tools/convert_diffusers20_original_sd.py. See --help for more details.
|
||||
|
||||
model_id_or_dir = r"model_id_on_hugging_face_or_dir"
|
||||
device = "cuda"
|
||||
|
||||
# create pipe
|
||||
print(f"creating pipe from {model_id_or_dir}...")
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id_or_dir, revision="fp16", torch_dtype=torch.float16)
|
||||
pipe = pipe.to(device)
|
||||
vae = pipe.vae
|
||||
text_encoder = pipe.text_encoder
|
||||
unet = pipe.unet
|
||||
|
||||
# load lora networks
|
||||
print(f"loading lora networks...")
|
||||
|
||||
lora_path1 = r"lora1.safetensors"
|
||||
sd = load_file(lora_path1) # If the file is .ckpt, use torch.load instead.
|
||||
network1, sd = create_network_from_weights(0.5, None, vae, text_encoder,unet, sd)
|
||||
network1.apply_to(text_encoder, unet)
|
||||
network1.load_state_dict(sd)
|
||||
network1.to(device, dtype=torch.float16)
|
||||
|
||||
# # You can merge weights instead of apply_to+load_state_dict. network.set_multiplier does not work
|
||||
# network.merge_to(text_encoder, unet, sd)
|
||||
|
||||
lora_path2 = r"lora2.safetensors"
|
||||
sd = load_file(lora_path2)
|
||||
network2, sd = create_network_from_weights(0.7, None, vae, text_encoder,unet, sd)
|
||||
network2.apply_to(text_encoder, unet)
|
||||
network2.load_state_dict(sd)
|
||||
network2.to(device, dtype=torch.float16)
|
||||
|
||||
lora_path3 = r"lora3.safetensors"
|
||||
sd = load_file(lora_path3)
|
||||
network3, sd = create_network_from_weights(0.5, None, vae, text_encoder,unet, sd)
|
||||
network3.apply_to(text_encoder, unet)
|
||||
network3.load_state_dict(sd)
|
||||
network3.to(device, dtype=torch.float16)
|
||||
|
||||
# prompts
|
||||
prompt = "masterpiece, best quality, 1girl, in white shirt, looking at viewer"
|
||||
negative_prompt = "bad quality, worst quality, bad anatomy, bad hands"
|
||||
|
||||
# exec pipe
|
||||
print("generating image...")
|
||||
with torch.autocast("cuda"):
|
||||
image = pipe(prompt, guidance_scale=7.5, negative_prompt=negative_prompt).images[0]
|
||||
|
||||
# if not merged, you can use set_multiplier
|
||||
# network1.set_multiplier(0.8)
|
||||
# and generate image again...
|
||||
|
||||
# save image
|
||||
image.save(r"by_diffusers..png")
|
||||
```
|
||||
|
||||
## 二つのモデルの差分からLoRAモデルを作成する
|
||||
|
||||
[こちらのディスカッション](https://github.com/cloneofsimo/lora/discussions/56)を参考に実装したものです。数式はそのまま使わせていただきました(よく理解していませんが近似には特異値分解を用いるようです)。
|
||||
@@ -256,14 +466,14 @@ python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256
|
||||
- 縮小時の補完方法を指定します。``area, cubic, lanczos4``から選択可能で、デフォルトは``area``です。
|
||||
|
||||
|
||||
## 追加情報
|
||||
# 追加情報
|
||||
|
||||
### cloneofsimo氏のリポジトリとの違い
|
||||
## cloneofsimo氏のリポジトリとの違い
|
||||
|
||||
2022/12/25時点では、当リポジトリはLoRAの適用個所をText EncoderのMLP、U-NetのFFN、Transformerのin/out projectionに拡大し、表現力が増しています。ただその代わりメモリ使用量は増え、8GBぎりぎりになりました。
|
||||
|
||||
またモジュール入れ替え機構は全く異なります。
|
||||
|
||||
### 将来拡張について
|
||||
## 将来拡張について
|
||||
|
||||
LoRAだけでなく他の拡張にも対応可能ですので、それらも追加予定です。
|
||||
|
||||
466
train_network_README-zh.md
Normal file
466
train_network_README-zh.md
Normal file
@@ -0,0 +1,466 @@
|
||||
# 关于LoRA的学习。
|
||||
|
||||
[LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)(arxiv)、[LoRA](https://github.com/microsoft/LoRA)(github)这是应用于Stable Diffusion“稳定扩散”的内容。
|
||||
|
||||
[cloneofsimo先生的代码仓库](https://github.com/cloneofsimo/lora) 我们非常感謝您提供的参考。非常感謝。
|
||||
|
||||
通常情況下,LoRA只适用于Linear和Kernel大小为1x1的Conv2d,但也可以將其擴展到Kernel大小为3x3的Conv2d。
|
||||
|
||||
Conv2d 3x3的扩展最初是由 [cloneofsimo先生的代码仓库](https://github.com/cloneofsimo/lora)
|
||||
而KohakuBlueleaf先生在[LoCon](https://github.com/KohakuBlueleaf/LoCon)中揭示了其有效性。我们深深地感谢KohakuBlueleaf先生。
|
||||
|
||||
看起来即使在8GB VRAM上也可以勉强运行。
|
||||
|
||||
请同时查看关于[学习的通用文档](./train_README-zh.md)。
|
||||
# 可学习的LoRA 类型
|
||||
|
||||
支持以下两种类型。以下是本仓库中自定义的名称。
|
||||
|
||||
1. __LoRA-LierLa__:(用于 __Li__ n __e__ a __r__ __La__ yers 的 LoRA,读作 "Liela")
|
||||
|
||||
适用于 Linear 和卷积层 Conv2d 的 1x1 Kernel 的 LoRA
|
||||
|
||||
2. __LoRA-C3Lier__:(用于具有 3x3 Kernel 的卷积层和 __Li__ n __e__ a __r__ 层的 LoRA,读作 "Seria")
|
||||
|
||||
除了第一种类型外,还适用于 3x3 Kernel 的 Conv2d 的 LoRA
|
||||
|
||||
与 LoRA-LierLa 相比,LoRA-C3Lier 可能会获得更高的准确性,因为它适用于更多的层。
|
||||
|
||||
在训练时,也可以使用 __DyLoRA__(将在后面介绍)。
|
||||
|
||||
## 请注意与所学模型相关的事项。
|
||||
|
||||
LoRA-LierLa可以用于AUTOMATIC1111先生的Web UI LoRA功能。
|
||||
|
||||
要使用LoRA-C3Liar并在Web UI中生成,请使用此处的[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)。
|
||||
|
||||
在此存储库的脚本中,您还可以预先将经过训练的LoRA模型合并到Stable Diffusion模型中。
|
||||
|
||||
请注意,与cloneofsimo先生的存储库以及d8ahazard先生的[Stable-Diffusion-WebUI的Dreambooth扩展](https://github.com/d8ahazard/sd_dreambooth_extension)不兼容,因为它们进行了一些功能扩展(如下文所述)。
|
||||
|
||||
# 学习步骤
|
||||
|
||||
请先参考此存储库的README文件并进行环境设置。
|
||||
|
||||
## 准备数据
|
||||
|
||||
请参考 [关于准备学习数据](./train_README-zh.md)。
|
||||
|
||||
## 网络训练
|
||||
|
||||
使用`train_network.py`。
|
||||
|
||||
在`train_network.py`中,使用`--network_module`选项指定要训练的模块名称。对于LoRA模块,它应该是`network.lora`,请指定它。
|
||||
|
||||
请注意,学习率应该比通常的DreamBooth或fine tuning要高,建议指定为`1e-4`至`1e-3`左右。
|
||||
|
||||
以下是命令行示例。
|
||||
|
||||
```
|
||||
accelerate launch --num_cpu_threads_per_process 1 train_network.py
|
||||
--pretrained_model_name_or_path=<.ckpt或.safetensord或Diffusers版模型目录>
|
||||
--dataset_config=<数据集配置的.toml文件>
|
||||
--output_dir=<训练过程中的模型输出文件夹>
|
||||
--output_name=<训练模型输出时的文件名>
|
||||
--save_model_as=safetensors
|
||||
--prior_loss_weight=1.0
|
||||
--max_train_steps=400
|
||||
--learning_rate=1e-4
|
||||
--optimizer_type="AdamW8bit"
|
||||
--xformers
|
||||
--mixed_precision="fp16"
|
||||
--cache_latents
|
||||
--gradient_checkpointing
|
||||
--save_every_n_epochs=1
|
||||
--network_module=networks.lora
|
||||
```
|
||||
|
||||
在这个命令行中,LoRA-LierLa将会被训练。
|
||||
|
||||
LoRA的模型将会被保存在通过`--output_dir`选项指定的文件夹中。关于其他选项和优化器等,请参阅[学习的通用文档](./train_README-zh.md)中的“常用选项”。
|
||||
|
||||
此外,还可以指定以下选项:
|
||||
|
||||
* `--network_dim`
|
||||
* 指定LoRA的RANK(例如:`--network_dim=4`)。默认值为4。数值越大表示表现力越强,但需要更多的内存和时间来训练。而且不要盲目增加此数值。
|
||||
* `--network_alpha`
|
||||
* 指定用于防止下溢并稳定训练的alpha值。默认值为1。如果与`network_dim`指定相同的值,则将获得与以前版本相同的行为。
|
||||
* `--persistent_data_loader_workers`
|
||||
* 在Windows环境中指定可大幅缩短epoch之间的等待时间。
|
||||
* `--max_data_loader_n_workers`
|
||||
* 指定数据读取进程的数量。进程数越多,数据读取速度越快,可以更有效地利用GPU,但会占用主存。默认值为“`8`或`CPU同步执行线程数-1`的最小值”,因此如果主存不足或GPU使用率超过90%,则应将这些数字降低到约`2`或`1`。
|
||||
* `--network_weights`
|
||||
* 在训练之前读取预训练的LoRA权重,并在此基础上进行进一步的训练。
|
||||
* `--network_train_unet_only`
|
||||
* 仅启用与U-Net相关的LoRA模块。在类似fine tuning的学习中指定此选项可能会很有用。
|
||||
* `--network_train_text_encoder_only`
|
||||
* 仅启用与Text Encoder相关的LoRA模块。可能会期望Textual Inversion效果。
|
||||
* `--unet_lr`
|
||||
* 当在U-Net相关的LoRA模块中使用与常规学习率(由`--learning_rate`选项指定)不同的学习率时,应指定此选项。
|
||||
* `--text_encoder_lr`
|
||||
* 当在Text Encoder相关的LoRA模块中使用与常规学习率(由`--learning_rate`选项指定)不同的学习率时,应指定此选项。可能最好将Text Encoder的学习率稍微降低(例如5e-5)。
|
||||
* `--network_args`
|
||||
* 可以指定多个参数。将在下面详细说明。
|
||||
|
||||
当未指定`--network_train_unet_only`和`--network_train_text_encoder_only`时(默认情况),将启用Text Encoder和U-Net的两个LoRA模块。
|
||||
|
||||
# 其他的学习方法
|
||||
|
||||
## 学习 LoRA-C3Lier
|
||||
|
||||
请使用以下方式
|
||||
|
||||
```
|
||||
--network_args "conv_dim=4"
|
||||
```
|
||||
|
||||
DyLoRA是在这篇论文中提出的[DyLoRA: Parameter Efficient Tuning of Pre-trained Models using Dynamic Search-Free Low-Rank Adaptation](https://arxiv.org/abs/2210.07558),
|
||||
[其官方实现可在这里找到](https://github.com/huawei-noah/KD-NLP/tree/main/DyLoRA)。
|
||||
|
||||
根据论文,LoRA的rank并不是越高越好,而是需要根据模型、数据集、任务等因素来寻找合适的rank。使用DyLoRA,可以同时在指定的维度(rank)下学习多种rank的LoRA,从而省去了寻找最佳rank的麻烦。
|
||||
|
||||
本存储库的实现基于官方实现进行了自定义扩展(因此可能存在缺陷)。
|
||||
|
||||
### 本存储库DyLoRA的特点
|
||||
|
||||
DyLoRA训练后的模型文件与LoRA兼容。此外,可以从模型文件中提取多个低于指定维度(rank)的LoRA。
|
||||
|
||||
DyLoRA-LierLa和DyLoRA-C3Lier均可训练。
|
||||
|
||||
### 使用DyLoRA进行训练
|
||||
|
||||
请指定与DyLoRA相对应的`network.dylora`,例如 `--network_module=networks.dylora`。
|
||||
|
||||
此外,通过 `--network_args` 指定例如`--network_args "unit=4"`的参数。`unit`是划分rank的单位。例如,可以指定为`--network_dim=16 --network_args "unit=4"`。请将`unit`视为可以被`network_dim`整除的值(`network_dim`是`unit`的倍数)。
|
||||
|
||||
如果未指定`unit`,则默认为`unit=1`。
|
||||
|
||||
以下是示例说明。
|
||||
|
||||
```
|
||||
--network_module=networks.dylora --network_dim=16 --network_args "unit=4"
|
||||
|
||||
--network_module=networks.dylora --network_dim=32 --network_alpha=16 --network_args "unit=4"
|
||||
```
|
||||
|
||||
对于DyLoRA-C3Lier,需要在 `--network_args` 中指定 `conv_dim`,例如 `conv_dim=4`。与普通的LoRA不同,`conv_dim`必须与`network_dim`具有相同的值。以下是一个示例描述:
|
||||
|
||||
```
|
||||
--network_module=networks.dylora --network_dim=16 --network_args "conv_dim=16" "unit=4"
|
||||
|
||||
--network_module=networks.dylora --network_dim=32 --network_alpha=16 --network_args "conv_dim=32" "conv_alpha=16" "unit=8"
|
||||
```
|
||||
|
||||
例如,当使用dim=16、unit=4(如下所述)进行学习时,可以学习和提取4个rank的LoRA,即4、8、12和16。通过在每个提取的模型中生成图像并进行比较,可以选择最佳rank的LoRA。
|
||||
|
||||
其他选项与普通的LoRA相同。
|
||||
|
||||
*`unit`是本存储库的独有扩展,在DyLoRA中,由于预计相比同维度(rank)的普通LoRA,学习时间更长,因此将分割单位增加。
|
||||
|
||||
### 从DyLoRA模型中提取LoRA模型
|
||||
|
||||
请使用`networks`文件夹中的`extract_lora_from_dylora.py`。指定`unit`单位后,从DyLoRA模型中提取LoRA模型。
|
||||
|
||||
例如,命令行如下:
|
||||
|
||||
```powershell
|
||||
python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.safetensors" --save_to "foldername/dylora-model-split.safetensors" --unit 4
|
||||
```
|
||||
|
||||
`--model` 参数用于指定DyLoRA模型文件。`--save_to` 参数用于指定要保存提取的模型的文件名(rank值将附加到文件名中)。`--unit` 参数用于指定DyLoRA训练时的`unit`。
|
||||
|
||||
## 分层学习率
|
||||
|
||||
请参阅PR#355了解详细信息。
|
||||
|
||||
您可以指定完整模型的25个块的权重。虽然第一个块没有对应的LoRA,但为了与分层LoRA应用等的兼容性,将其设为25个。此外,如果不扩展到conv2d3x3,则某些块中可能不存在LoRA,但为了统一描述,请始终指定25个值。
|
||||
|
||||
请在 `--network_args` 中指定以下参数。
|
||||
|
||||
- `down_lr_weight`:指定U-Net down blocks的学习率权重。可以指定以下内容:
|
||||
- 每个块的权重:指定12个数字,例如`"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"`
|
||||
- 从预设中指定:例如`"down_lr_weight=sine"`(使用正弦曲线指定权重)。可以指定sine、cosine、linear、reverse_linear、zeros。另外,添加 `+数字` 时,可以将指定的数字加上(变为0.25〜1.25)。
|
||||
- `mid_lr_weight`:指定U-Net mid block的学习率权重。只需指定一个数字,例如 `"mid_lr_weight=0.5"`。
|
||||
- `up_lr_weight`:指定U-Net up blocks的学习率权重。与down_lr_weight相同。
|
||||
- 省略指定的部分将被视为1.0。另外,如果将权重设为0,则不会创建该块的LoRA模块。
|
||||
- `block_lr_zero_threshold`:如果权重小于此值,则不会创建LoRA模块。默认值为0。
|
||||
|
||||
### 分层学习率命令行指定示例:
|
||||
|
||||
|
||||
```powershell
|
||||
--network_args "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5" "mid_lr_weight=2.0" "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5"
|
||||
|
||||
--network_args "block_lr_zero_threshold=0.1" "down_lr_weight=sine+.5" "mid_lr_weight=1.5" "up_lr_weight=cosine+.5"
|
||||
```
|
||||
|
||||
### Hierarchical Learning Rate指定的toml文件示例:
|
||||
|
||||
```toml
|
||||
network_args = [ "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5", "mid_lr_weight=2.0", "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5",]
|
||||
|
||||
network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_lr_weight=1.5", "up_lr_weight=cosine+.5", ]
|
||||
```
|
||||
|
||||
## 层次结构维度(rank)
|
||||
|
||||
您可以指定完整模型的25个块的维度(rank)。与分层学习率一样,某些块可能不存在LoRA,但请始终指定25个值。
|
||||
|
||||
请在 `--network_args` 中指定以下参数:
|
||||
|
||||
- `block_dims`:指定每个块的维度(rank)。指定25个数字,例如 `"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"`。
|
||||
- `block_alphas`:指定每个块的alpha。与block_dims一样,指定25个数字。如果省略,将使用network_alpha的值。
|
||||
- `conv_block_dims`:将LoRA扩展到Conv2d 3x3,并指定每个块的维度(rank)。
|
||||
- `conv_block_alphas`:在将LoRA扩展到Conv2d 3x3时指定每个块的alpha。如果省略,将使用conv_alpha的值。
|
||||
|
||||
### 层次结构维度(rank)命令行指定示例:
|
||||
|
||||
|
||||
```powershell
|
||||
--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2"
|
||||
|
||||
--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "conv_block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"
|
||||
|
||||
--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"
|
||||
```
|
||||
|
||||
### 层级别dim(rank) toml文件指定示例:
|
||||
|
||||
```toml
|
||||
network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2",]
|
||||
|
||||
network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2", "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2",]
|
||||
```
|
||||
|
||||
# Other scripts
|
||||
这些是与LoRA相关的脚本,如合并脚本等。
|
||||
|
||||
关于合并脚本
|
||||
您可以使用merge_lora.py脚本将LoRA的训练结果合并到稳定扩散模型中,也可以将多个LoRA模型合并。
|
||||
|
||||
合并到稳定扩散模型中的LoRA模型
|
||||
合并后的模型可以像常规的稳定扩散ckpt一样使用。例如,以下是一个命令行示例:
|
||||
|
||||
```
|
||||
python networks\merge_lora.py --sd_model ..\model\model.ckpt
|
||||
--save_to ..\lora_train1\model-char1-merged.safetensors
|
||||
--models ..\lora_train1\last.safetensors --ratios 0.8
|
||||
```
|
||||
|
||||
请使用 Stable Diffusion v2.x 模型进行训练并进行合并时,需要指定--v2选项。
|
||||
|
||||
使用--sd_model选项指定要合并的 Stable Diffusion 模型文件(仅支持 .ckpt 或 .safetensors 格式,目前不支持 Diffusers)。
|
||||
|
||||
使用--save_to选项指定合并后模型的保存路径(根据扩展名自动判断为 .ckpt 或 .safetensors)。
|
||||
|
||||
使用--models选项指定已训练的 LoRA 模型文件,也可以指定多个,然后按顺序进行合并。
|
||||
|
||||
使用--ratios选项以0~1.0的数字指定每个模型的应用率(将多大比例的权重反映到原始模型中)。例如,在接近过度拟合的情况下,降低应用率可能会使结果更好。请指定与模型数量相同的比率。
|
||||
|
||||
当指定多个模型时,格式如下:
|
||||
|
||||
|
||||
```
|
||||
python networks\merge_lora.py --sd_model ..\model\model.ckpt
|
||||
--save_to ..\lora_train1\model-char1-merged.safetensors
|
||||
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.8 0.5
|
||||
```
|
||||
|
||||
### 将多个LoRA模型合并
|
||||
|
||||
将多个LoRA模型逐个应用于SD模型与将多个LoRA模型合并后再应用于SD模型之间,由于计算顺序的不同,会得到微妙不同的结果。
|
||||
|
||||
例如,下面是一个命令行示例:
|
||||
|
||||
```
|
||||
python networks\merge_lora.py
|
||||
--save_to ..\lora_train1\model-char1-style1-merged.safetensors
|
||||
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.6 0.4
|
||||
```
|
||||
|
||||
--sd_model选项不需要指定。
|
||||
|
||||
通过--save_to选项指定合并后的LoRA模型的保存位置(.ckpt或.safetensors,根据扩展名自动识别)。
|
||||
|
||||
通过--models选项指定学习的LoRA模型文件。可以指定三个或更多。
|
||||
|
||||
通过--ratios选项以0~1.0的数字指定每个模型的比率(反映多少权重来自原始模型)。如果将两个模型一对一合并,则比率将是“0.5 0.5”。如果比率为“1.0 1.0”,则总重量将过大,可能会产生不理想的结果。
|
||||
|
||||
在v1和v2中学习的LoRA,以及rank(维数)或“alpha”不同的LoRA不能合并。仅包含U-Net的LoRA和包含U-Net+文本编码器的LoRA可以合并,但结果未知。
|
||||
|
||||
### 其他选项
|
||||
|
||||
* 精度
|
||||
* 可以从float、fp16或bf16中选择合并计算时的精度。默认为float以保证精度。如果想减少内存使用量,请指定fp16/bf16。
|
||||
* save_precision
|
||||
* 可以从float、fp16或bf16中选择在保存模型时的精度。默认与精度相同。
|
||||
|
||||
## 合并多个维度不同的LoRA模型
|
||||
|
||||
将多个LoRA近似为一个LoRA(无法完全复制)。使用'svd_merge_lora.py'。例如,以下是命令行的示例。
|
||||
```
|
||||
python networks\svd_merge_lora.py
|
||||
--save_to ..\lora_train1\model-char1-style1-merged.safetensors
|
||||
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors
|
||||
--ratios 0.6 0.4 --new_rank 32 --device cuda
|
||||
```
|
||||
`merge_lora.py`和主要选项相同。以下选项已添加:
|
||||
|
||||
- `--new_rank`
|
||||
- 指定要创建的LoRA rank。
|
||||
- `--new_conv_rank`
|
||||
- 指定要创建的Conv2d 3x3 LoRA的rank。如果省略,则与`new_rank`相同。
|
||||
- `--device`
|
||||
- 如果指定为`--device cuda`,则在GPU上执行计算。处理速度将更快。
|
||||
|
||||
## 在此存储库中生成图像的脚本中
|
||||
|
||||
请在`gen_img_diffusers.py`中添加`--network_module`和`--network_weights`选项。其含义与训练时相同。
|
||||
|
||||
通过`--network_mul`选项,可以指定0~1.0的数字来改变LoRA的应用率。
|
||||
|
||||
## 请参考以下示例,在Diffusers的pipeline中生成。
|
||||
|
||||
所需文件仅为networks/lora.py。请注意,该示例只能在Diffusers版本0.10.2中正常运行。
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from networks.lora import LoRAModule, create_network_from_weights
|
||||
from safetensors.torch import load_file
|
||||
|
||||
# if the ckpt is CompVis based, convert it to Diffusers beforehand with tools/convert_diffusers20_original_sd.py. See --help for more details.
|
||||
|
||||
model_id_or_dir = r"model_id_on_hugging_face_or_dir"
|
||||
device = "cuda"
|
||||
|
||||
# create pipe
|
||||
print(f"creating pipe from {model_id_or_dir}...")
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id_or_dir, revision="fp16", torch_dtype=torch.float16)
|
||||
pipe = pipe.to(device)
|
||||
vae = pipe.vae
|
||||
text_encoder = pipe.text_encoder
|
||||
unet = pipe.unet
|
||||
|
||||
# load lora networks
|
||||
print(f"loading lora networks...")
|
||||
|
||||
lora_path1 = r"lora1.safetensors"
|
||||
sd = load_file(lora_path1) # If the file is .ckpt, use torch.load instead.
|
||||
network1, sd = create_network_from_weights(0.5, None, vae, text_encoder,unet, sd)
|
||||
network1.apply_to(text_encoder, unet)
|
||||
network1.load_state_dict(sd)
|
||||
network1.to(device, dtype=torch.float16)
|
||||
|
||||
# # You can merge weights instead of apply_to+load_state_dict. network.set_multiplier does not work
|
||||
# network.merge_to(text_encoder, unet, sd)
|
||||
|
||||
lora_path2 = r"lora2.safetensors"
|
||||
sd = load_file(lora_path2)
|
||||
network2, sd = create_network_from_weights(0.7, None, vae, text_encoder,unet, sd)
|
||||
network2.apply_to(text_encoder, unet)
|
||||
network2.load_state_dict(sd)
|
||||
network2.to(device, dtype=torch.float16)
|
||||
|
||||
lora_path3 = r"lora3.safetensors"
|
||||
sd = load_file(lora_path3)
|
||||
network3, sd = create_network_from_weights(0.5, None, vae, text_encoder,unet, sd)
|
||||
network3.apply_to(text_encoder, unet)
|
||||
network3.load_state_dict(sd)
|
||||
network3.to(device, dtype=torch.float16)
|
||||
|
||||
# prompts
|
||||
prompt = "masterpiece, best quality, 1girl, in white shirt, looking at viewer"
|
||||
negative_prompt = "bad quality, worst quality, bad anatomy, bad hands"
|
||||
|
||||
# exec pipe
|
||||
print("generating image...")
|
||||
with torch.autocast("cuda"):
|
||||
image = pipe(prompt, guidance_scale=7.5, negative_prompt=negative_prompt).images[0]
|
||||
|
||||
# if not merged, you can use set_multiplier
|
||||
# network1.set_multiplier(0.8)
|
||||
# and generate image again...
|
||||
|
||||
# save image
|
||||
image.save(r"by_diffusers..png")
|
||||
```
|
||||
|
||||
## 从两个模型的差异中创建LoRA模型。
|
||||
|
||||
[参考讨论链接](https://github.com/cloneofsimo/lora/discussions/56)這是參考實現的結果。數學公式沒有改變(我並不完全理解,但似乎使用奇異值分解進行了近似)。
|
||||
|
||||
将两个模型(例如微调原始模型和微调后的模型)的差异近似为LoRA。
|
||||
|
||||
### 脚本执行方法
|
||||
|
||||
请按以下方式指定。
|
||||
|
||||
```
|
||||
python networks\extract_lora_from_models.py --model_org base-model.ckpt
|
||||
--model_tuned fine-tuned-model.ckpt
|
||||
--save_to lora-weights.safetensors --dim 4
|
||||
```
|
||||
|
||||
--model_org 选项指定原始的Stable Diffusion模型。如果要应用创建的LoRA模型,则需要指定该模型并将其应用。可以指定.ckpt或.safetensors文件。
|
||||
|
||||
--model_tuned 选项指定要提取差分的目标Stable Diffusion模型。例如,可以指定经过Fine Tuning或DreamBooth后的模型。可以指定.ckpt或.safetensors文件。
|
||||
|
||||
--save_to 指定LoRA模型的保存路径。--dim指定LoRA的维数。
|
||||
|
||||
生成的LoRA模型可以像已训练的LoRA模型一样使用。
|
||||
|
||||
当两个模型的文本编码器相同时,LoRA将成为仅包含U-Net的LoRA。
|
||||
|
||||
### 其他选项
|
||||
|
||||
- `--v2`
|
||||
- 如果使用v2.x的稳定扩散模型,请指定此选项。
|
||||
- `--device`
|
||||
- 指定为 ``--device cuda`` 可在GPU上执行计算。这会使处理速度更快(即使在CPU上也不会太慢,大约快几倍)。
|
||||
- `--save_precision`
|
||||
- 指定LoRA的保存格式为“float”、“fp16”、“bf16”。如果省略,将使用float。
|
||||
- `--conv_dim`
|
||||
- 指定后,将扩展LoRA的应用范围到Conv2d 3x3。指定Conv2d 3x3的rank。
|
||||
-
|
||||
## 图像大小调整脚本
|
||||
|
||||
(稍后将整理文件,但现在先在这里写下说明。)
|
||||
|
||||
在 Aspect Ratio Bucketing 的功能扩展中,现在可以将小图像直接用作教师数据,而无需进行放大。我收到了一个用于前处理的脚本,其中包括将原始教师图像缩小的图像添加到教师数据中可以提高准确性的报告。我整理了这个脚本并加入了感谢 bmaltais 先生。
|
||||
|
||||
### 执行脚本的方法如下。
|
||||
原始图像以及调整大小后的图像将保存到转换目标文件夹中。调整大小后的图像将在文件名中添加“+512x512”之类的调整后的分辨率(与图像大小不同)。小于调整大小后分辨率的图像将不会被放大。
|
||||
|
||||
```
|
||||
python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256x256 --save_as_png
|
||||
--copy_associated_files 源图像文件夹目标文件夹
|
||||
```
|
||||
|
||||
在元画像文件夹中的图像文件将被调整大小以达到指定的分辨率(可以指定多个),并保存到目标文件夹中。除图像外的文件将被保留为原样。
|
||||
|
||||
请使用“--max_resolution”选项指定调整大小后的大小,使其达到指定的面积大小。如果指定多个,则会在每个分辨率上进行调整大小。例如,“512x512,384x384,256x256”将使目标文件夹中的图像变为原始大小和调整大小后的大小×3共计4张图像。
|
||||
|
||||
如果使用“--save_as_png”选项,则会以PNG格式保存。如果省略,则默认以JPEG格式(quality=100)保存。
|
||||
|
||||
如果使用“--copy_associated_files”选项,则会将与图像相同的文件名(例如标题等)的文件复制到调整大小后的图像文件的文件名相同的位置,但不包括扩展名。
|
||||
|
||||
### 其他选项
|
||||
|
||||
- divisible_by
|
||||
- 将图像中心裁剪到能够被该值整除的大小(分别是垂直和水平的大小),以便调整大小后的图像大小可以被该值整除。
|
||||
- interpolation
|
||||
- 指定缩小时的插值方法。可从``area、cubic、lanczos4``中选择,默认为``area``。
|
||||
|
||||
|
||||
# 追加信息
|
||||
|
||||
## 与cloneofsimo的代码库的区别
|
||||
|
||||
截至2022年12月25日,本代码库将LoRA应用扩展到了Text Encoder的MLP、U-Net的FFN以及Transformer的输入/输出投影中,从而增强了表现力。但是,内存使用量增加了,接近了8GB的限制。
|
||||
|
||||
此外,模块交换机制也完全不同。
|
||||
|
||||
## 关于未来的扩展
|
||||
|
||||
除了LoRA之外,我们还计划添加其他扩展,以支持更多的功能。
|
||||
@@ -13,6 +13,7 @@ import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
@@ -184,10 +185,10 @@ def train(args):
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value('i',0)
|
||||
current_step = Value('i',0)
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
|
||||
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
||||
|
||||
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
||||
if use_template:
|
||||
@@ -232,12 +233,14 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size)
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
@@ -261,7 +264,9 @@ def train(args):
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
@@ -304,9 +309,7 @@ def train(args):
|
||||
text_encoder.to(weight_dtype)
|
||||
|
||||
# resumeする
|
||||
if args.resume is not None:
|
||||
print(f"resume training from state: {args.resume}")
|
||||
accelerator.load_state(args.resume)
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
@@ -334,11 +337,28 @@ def train(args):
|
||||
)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("textual_inversion")
|
||||
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
|
||||
|
||||
# function for saving/removing
|
||||
def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
save_weights(ckpt_file, embs, save_dtype)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
||||
|
||||
def remove_model(old_ckpt_name):
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch+1
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
text_encoder.train()
|
||||
|
||||
@@ -358,7 +378,7 @@ def train(args):
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
# weight_dtype) use float instead of fp16/bf16 because text encoder is float
|
||||
# use float instead of fp16/bf16 because text encoder is float
|
||||
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
@@ -376,7 +396,8 @@ def train(args):
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
with accelerator.autocast():
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
@@ -386,9 +407,9 @@ def train(args):
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
@@ -419,6 +440,23 @@ def train(args):
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
save_model(ckpt_name, updated_embs, global_step, epoch)
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
||||
|
||||
remove_step_no = train_util.get_remove_step_no(args, global_step)
|
||||
if remove_step_no is not None:
|
||||
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
@@ -445,24 +483,18 @@ def train(args):
|
||||
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||
if accelerator.is_main_process and saving:
|
||||
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
||||
save_model(ckpt_name, updated_embs, epoch + 1, global_step)
|
||||
|
||||
def save_func():
|
||||
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
||||
if remove_epoch_no is not None:
|
||||
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
def remove_old_func(old_epoch_no):
|
||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
||||
if saving and args.save_state:
|
||||
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||
|
||||
train_util.sample_images(
|
||||
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||
@@ -476,7 +508,7 @@ def train(args):
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state:
|
||||
if args.save_state and is_main_process:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
@@ -484,14 +516,9 @@ def train(args):
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
if is_main_process:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
|
||||
|
||||
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||
ckpt_name = model_name + "." + args.save_model_as
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
print(f"save trained model to {ckpt_file}")
|
||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||
print("model saved.")
|
||||
|
||||
|
||||
@@ -546,7 +573,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser, False)
|
||||
|
||||
parser.add_argument(
|
||||
"--save_model_as",
|
||||
|
||||
@@ -13,6 +13,7 @@ import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
@@ -266,12 +267,14 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size)
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
@@ -340,9 +343,7 @@ def train(args):
|
||||
text_encoder.to(weight_dtype)
|
||||
|
||||
# resumeする
|
||||
if args.resume is not None:
|
||||
print(f"resume training from state: {args.resume}")
|
||||
accelerator.load_state(args.resume)
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
@@ -370,8 +371,25 @@ def train(args):
|
||||
)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("textual_inversion")
|
||||
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
|
||||
|
||||
# function for saving/removing
|
||||
def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
save_weights(ckpt_file, embs, save_dtype)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
||||
|
||||
def remove_model(old_ckpt_name):
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
@@ -417,7 +435,8 @@ def train(args):
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
|
||||
with accelerator.autocast():
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
@@ -460,6 +479,23 @@ def train(args):
|
||||
# accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||
# )
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
|
||||
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
save_model(ckpt_name, updated_embs, global_step, epoch)
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
||||
|
||||
remove_step_no = train_util.get_remove_step_no(args, global_step)
|
||||
if remove_step_no is not None:
|
||||
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
@@ -486,24 +522,18 @@ def train(args):
|
||||
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||
if accelerator.is_main_process and saving:
|
||||
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
||||
save_model(ckpt_name, updated_embs, epoch + 1, global_step)
|
||||
|
||||
def save_func():
|
||||
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
||||
if remove_epoch_no is not None:
|
||||
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
def remove_old_func(old_epoch_no):
|
||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
||||
if saving and args.save_state:
|
||||
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||
|
||||
# TODO: fix sample_images
|
||||
# train_util.sample_images(
|
||||
@@ -518,7 +548,7 @@ def train(args):
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state:
|
||||
if args.save_state and is_main_process:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
|
||||
@@ -526,14 +556,9 @@ def train(args):
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
if is_main_process:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
|
||||
|
||||
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||
ckpt_name = model_name + "." + args.save_model_as
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
print(f"save trained model to {ckpt_file}")
|
||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||
print("model saved.")
|
||||
|
||||
|
||||
@@ -600,7 +625,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser, False)
|
||||
|
||||
parser.add_argument(
|
||||
"--save_model_as",
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
実装に当たっては https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion を大いに参考にしました。
|
||||
|
||||
学習したモデルはWeb UIでもそのまま使えます。なお恐らくSD2.xにも対応していますが現時点では未テストです。
|
||||
学習したモデルはWeb UIでもそのまま使えます。
|
||||
|
||||
# 学習の手順
|
||||
|
||||
|
||||
Reference in New Issue
Block a user