mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Compare commits
25 Commits
v0.8.8
...
stable-cas
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
235a1ea2c6 | ||
|
|
cb648a2bf8 | ||
|
|
3a2a48c15d | ||
|
|
40f2c688db | ||
|
|
e4f8736c60 | ||
|
|
13f49d1e4a | ||
|
|
df7648245e | ||
|
|
3368fb1af7 | ||
|
|
417f14d245 | ||
|
|
86503cb945 | ||
|
|
d91b1d3793 | ||
|
|
70917077a6 | ||
|
|
69dbc50912 | ||
|
|
985761ca43 | ||
|
|
71e03559e2 | ||
|
|
806a6237fb | ||
|
|
9b0e532942 | ||
|
|
c26f01241f | ||
|
|
ac71168939 | ||
|
|
4e37d950d2 | ||
|
|
4b5784eb44 | ||
|
|
856df07f49 | ||
|
|
80ef59c115 | ||
|
|
319bbf8057 | ||
|
|
fa440208b7 |
171
README.md
171
README.md
@@ -1,3 +1,174 @@
|
||||
# Training Stable Cascade Stage C
|
||||
|
||||
This is an experimental feature. There may be bugs.
|
||||
|
||||
__Feb 25, 2024 Update:__ Fixed a bug that the LoRA weights trained can be loaded in ComfyUI. If you still have a problem, please let me know.
|
||||
|
||||
__Feb 25, 2024 Update:__ Fixed a bug that Stage C training with mixed precision behaves the same as `--full_bf16` (fp16) regardless of `--full_bf16` (fp16) specified.
|
||||
|
||||
This is because the Stage C weights were loaded in bf16/fp16. With this fix, the memory usage without `--full_bf16` (fp16) specified will increase, so you may need to specify `--full_bf16` (fp16) as needed.
|
||||
|
||||
__Feb 22, 2024 Update:__ Fixed a bug that LoRA is not applied to some modules (to_q/k/v and to_out) in Attention. Also, the model structure of Stage C has been changed, and you can choose xformers and SDPA (SDPA was used before). Please specify `--sdpa` or `--xformers` option.
|
||||
|
||||
__Feb 20, 2024 Update:__ There was a problem with the preprocessing of the EfficientNetEncoder, and the latents became invalid (the saturation of the training results decreases). If you have cached `_sc_latents.npz` files with `--cache_latents_to_disk`, please delete them before training.
|
||||
|
||||
## Usage
|
||||
|
||||
Training is run with `stable_cascade_train_stage_c.py`.
|
||||
|
||||
The main options are the same as `sdxl_train.py`. The following options have been added.
|
||||
|
||||
- `--effnet_checkpoint_path`: Specifies the path to the EfficientNetEncoder weights.
|
||||
- `--stage_c_checkpoint_path`: Specifies the path to the Stage C weights.
|
||||
- `--text_model_checkpoint_path`: Specifies the path to the Text Encoder weights. If omitted, the model from Hugging Face will be used.
|
||||
- `--save_text_model`: Saves the model downloaded from Hugging Face to `--text_model_checkpoint_path`.
|
||||
- `--previewer_checkpoint_path`: Specifies the path to the Previewer weights. Used to generate sample images during training.
|
||||
- `--adaptive_loss_weight`: Uses [Adaptive Loss Weight](https://github.com/Stability-AI/StableCascade/blob/master/gdf/loss_weights.py) . If omitted, P2LossWeight is used. The official settings use Adaptive Loss Weight.
|
||||
|
||||
The learning rate is set to 1e-4 in the official settings.
|
||||
|
||||
The first time, specify `--text_model_checkpoint_path` and `--save_text_model` to save the Text Encoder weights. From the next time, specify `--text_model_checkpoint_path` to load the saved weights.
|
||||
|
||||
Sample image generation during training is done with Perviewer. Perviewer is a simple decoder that converts EfficientNetEncoder latents to images.
|
||||
|
||||
Some of the options for SDXL are simply ignored or cause an error (especially noise-related options such as `--noise_offset`). `--vae_batch_size` and `--no_half_vae` are applied directly to the EfficientNetEncoder (when `bf16` is specified for mixed precision, `--no_half_vae` is not necessary).
|
||||
|
||||
Options for latents and Text Encoder output caches can be used as is, but since the EfficientNetEncoder is much lighter than the VAE, you may not need to use the cache unless memory is particularly tight.
|
||||
|
||||
`--gradient_checkpointing`, `--full_bf16`, and `--full_fp16` (untested) to reduce memory consumption can be used as is.
|
||||
|
||||
A scale of about 4 is suitable for sample image generation.
|
||||
|
||||
Since the official settings use `bf16` for training, training with `fp16` may be unstable.
|
||||
|
||||
The code for training the Text Encoder is also written, but it is untested.
|
||||
|
||||
### Command line sample
|
||||
|
||||
```batch
|
||||
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 stable_cascade_train_stage_c.py --mixed_precision bf16 --save_precision bf16 --max_data_loader_n_workers 2 --persistent_data_loader_workers --gradient_checkpointing --learning_rate 1e-4 --optimizer_type adafactor --optimizer_args "scale_parameter=False" "relative_step=False" "warmup_init=False" --max_train_epochs 10 --save_every_n_epochs 1 --save_precision bf16 --output_dir ../output --output_name sc_test - --stage_c_checkpoint_path ../models/stage_c_bf16.safetensors --effnet_checkpoint_path ../models/effnet_encoder.safetensors --previewer_checkpoint_path ../models/previewer.safetensors --dataset_config ../dataset/config_bs1.toml --sample_every_n_epochs 1 --sample_prompts ../dataset/prompts.txt --adaptive_loss_weight
|
||||
```
|
||||
|
||||
### About the dataset for fine tuning
|
||||
|
||||
If the latents cache files for SD/SDXL exist (extension `*.npz`), it will be read and an error will occur during training. Please move them to another location in advance.
|
||||
|
||||
After that, run `finetune/prepare_buckets_latents.py` with the `--stable_cascade` option to create latents cache files for Stable Cascade (suffix `_sc_latents.npz` is added).
|
||||
|
||||
## LoRA training
|
||||
|
||||
`stable_cascade_train_c_network.py` is used for LoRA training. The main options are the same as `train_network.py`, and the same options as `stable_cascade_train_stage_c.py` have been added.
|
||||
|
||||
__This is an experimental feature, so the format of the saved weights may change in the future and become incompatible.__
|
||||
|
||||
There is no compatibility with the official LoRA, and the implementation of Text Encoder embedding training (Pivotal Tuning) in the official implementation is not implemented here.
|
||||
|
||||
Text Encoder LoRA training is implemented, but untested.
|
||||
|
||||
## Image generation
|
||||
|
||||
Basic image generation functionality is available in `stable_cascade_gen_img.py`. See `--help` for usage.
|
||||
|
||||
When using LoRA, specify `--network_module networks.lora --network_mul 1 --network_weights lora_weights.safetensors`.
|
||||
|
||||
The following prompt options are available.
|
||||
|
||||
* `--n` Negative prompt up to the next option.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
* `--t` Specifies the t_start of the generation.
|
||||
* `--f` Specifies the shift of the generation.
|
||||
|
||||
# Stable Cascade Stage C の学習
|
||||
|
||||
実験的機能です。不具合があるかもしれません。
|
||||
|
||||
__2024/2/25 追記:__ 学習される LoRA の重みが ComfyUI で読み込めるよう修正しました。依然として不具合がある場合にはご連絡ください。
|
||||
|
||||
__2024/2/25 追記:__ Mixed precision 時のStage C の学習が、 `--full_bf16` (fp16) の指定に関わらず `--full_bf16` (fp16) 指定時と同じ動作となる(と思われる)不具合を修正しました。
|
||||
|
||||
Stage C の重みを bf16/fp16 で読み込んでいたためです。この修正により `--full_bf16` (fp16) 未指定時のメモリ使用量が増えますので、必要に応じて `--full_bf16` (fp16) を指定してください。
|
||||
|
||||
__2024/2/22 追記:__ LoRA が一部のモジュール(Attention の to_q/k/v および to_out)に適用されない不具合を修正しました。また Stage C のモデル構造を変更し xformers と SDPA を選べるようになりました(今までは SDPA が使用されていました)。`--sdpa` または `--xformers` オプションを指定してください。
|
||||
|
||||
__2024/2/20 追記:__ EfficientNetEncoder の前処理に不具合があり、latents が不正になっていました(学習結果の彩度が低下する現象が起きます)。`--cache_latents_to_disk` でキャッシュした `_sc_latents.npz` がある場合、いったん削除してから学習してください。
|
||||
|
||||
## 使い方
|
||||
|
||||
学習は `stable_cascade_train_stage_c.py` で行います。
|
||||
|
||||
主なオプションは `sdxl_train.py` と同様です。以下のオプションが追加されています。
|
||||
|
||||
- `--effnet_checkpoint_path` : EfficientNetEncoder の重みのパスを指定します。
|
||||
- `--stage_c_checkpoint_path` : Stage C の重みのパスを指定します。
|
||||
- `--text_model_checkpoint_path` : Text Encoder の重みのパスを指定します。省略時は Hugging Face のモデルを使用します。
|
||||
- `--save_text_model` : `--text_model_checkpoint_path` にHugging Face からダウンロードしたモデルを保存します。
|
||||
- `--previewer_checkpoint_path` : Previewer の重みのパスを指定します。学習中のサンプル画像生成に使用します。
|
||||
- `--adaptive_loss_weight` : [Adaptive Loss Weight](https://github.com/Stability-AI/StableCascade/blob/master/gdf/loss_weights.py) を用います。省略時は P2LossWeight が使用されます。公式では Adaptive Loss Weight が使用されているようです。
|
||||
|
||||
学習率は、公式の設定では 1e-4 のようです。
|
||||
|
||||
初回は `--text_model_checkpoint_path` と `--save_text_model` を指定して、Text Encoder の重みを保存すると良いでしょう。次からは `--text_model_checkpoint_path` を指定して、保存した重みを読み込むことができます。
|
||||
|
||||
学習中のサンプル画像生成は Perviewer で行われます。Previewer は EfficientNetEncoder の latents を画像に変換する簡易的な decoder です。
|
||||
|
||||
SDXL の向けの一部のオプションは単に無視されるか、エラーになります(特に `--noise_offset` などのノイズ関係)。`--vae_batch_size` および `--no_half_vae` はそのまま EfficientNetEncoder に適用されます(mixed precision に `bf16` 指定時は `--no_half_vae` は不要のようです)。
|
||||
|
||||
latents および Text Encoder 出力キャッシュのためのオプションはそのまま使用できますが、EfficientNetEncoder は VAE よりもかなり軽量のため、メモリが特に厳しい場合以外はキャッシュを使用する必要はないかもしれません。
|
||||
|
||||
メモリ消費を抑えるための `--gradient_checkpointing` 、`--full_bf16`、`--full_fp16`(未テスト)はそのまま使用できます。
|
||||
|
||||
サンプル画像生成時の Scale には 4 程度が適しているようです。
|
||||
|
||||
公式の設定では学習に `bf16` を用いているため、`fp16` での学習は不安定かもしれません。
|
||||
|
||||
Text Encoder 学習のコードも書いてありますが、未テストです。
|
||||
|
||||
### コマンドラインのサンプル
|
||||
|
||||
[Command-line-sample](#command-line-sample)を参照してください。
|
||||
|
||||
|
||||
### fine tuning方式のデータセットについて
|
||||
|
||||
SD/SDXL 向けの latents キャッシュファイル(拡張子 `*.npz`)が存在するとそれを読み込んでしまい学習時にエラーになります。あらかじめ他の場所に退避しておいてください。
|
||||
|
||||
その後、`finetune/prepare_buckets_latents.py` をオプション `--stable_cascade` を指定して実行すると、Stable Cascade 向けの latents キャッシュファイル(接尾辞 `_sc_latents.npz` が付きます)が作成されます。
|
||||
|
||||
|
||||
## LoRA 等の学習
|
||||
|
||||
LoRA の学習は `stable_cascade_train_c_network.py` で行います。主なオプションは `train_network.py` と同様で、`stable_cascade_train_stage_c.py` と同様のオプションが追加されています。
|
||||
|
||||
__実験的機能のため、保存される重みのフォーマットは将来的に変更され、互換性がなくなる可能性があります。__
|
||||
|
||||
公式の LoRA と重みの互換性はありません。また公式で実装されている Text Encoder の embedding 学習(Pivotal Tuning)も実装されていません。
|
||||
|
||||
Text Encoder の LoRA 学習は実装してありますが、未テストです。
|
||||
|
||||
## 画像生成
|
||||
|
||||
最低限の画像生成機能が `stable_cascade_gen_img.py` にあります。使用法は `--help` を参照してください。
|
||||
|
||||
LoRA 使用時は `--network_module networks.lora --network_mul 1 --network_weights lora_weights.safetensors` のように指定します。
|
||||
|
||||
プロンプトオプションとして以下が使用できます。
|
||||
|
||||
* `--n` Negative prompt up to the next option.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
* `--t` Specifies the t_start of the generation.
|
||||
* `--f` Specifies the shift of the generation.
|
||||
|
||||
|
||||
---
|
||||
|
||||
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
|
||||
|
||||
This repository contains training, generation and utility scripts for Stable Diffusion.
|
||||
|
||||
@@ -11,15 +11,19 @@ import cv2
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, get_preferred_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torchvision import transforms
|
||||
|
||||
import library.model_util as model_util
|
||||
import library.stable_cascade_utils as sc_utils
|
||||
import library.train_util as train_util
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEVICE = get_preferred_device()
|
||||
@@ -42,7 +46,7 @@ def collate_fn_remove_corrupted(batch):
|
||||
return batch
|
||||
|
||||
|
||||
def get_npz_filename(data_dir, image_key, is_full_path, recursive):
|
||||
def get_npz_filename(data_dir, image_key, is_full_path, recursive, stable_cascade):
|
||||
if is_full_path:
|
||||
base_name = os.path.splitext(os.path.basename(image_key))[0]
|
||||
relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
|
||||
@@ -50,10 +54,11 @@ def get_npz_filename(data_dir, image_key, is_full_path, recursive):
|
||||
base_name = image_key
|
||||
relative_path = ""
|
||||
|
||||
ext = ".npz" if not stable_cascade else train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX
|
||||
if recursive and relative_path:
|
||||
return os.path.join(data_dir, relative_path, base_name) + ".npz"
|
||||
return os.path.join(data_dir, relative_path, base_name) + ext
|
||||
else:
|
||||
return os.path.join(data_dir, base_name) + ".npz"
|
||||
return os.path.join(data_dir, base_name) + ext
|
||||
|
||||
|
||||
def main(args):
|
||||
@@ -83,13 +88,20 @@ def main(args):
|
||||
elif args.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
|
||||
if not args.stable_cascade:
|
||||
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
|
||||
divisor = 8
|
||||
else:
|
||||
vae = sc_utils.load_effnet(args.model_name_or_path, DEVICE)
|
||||
divisor = 32
|
||||
vae.eval()
|
||||
vae.to(DEVICE, dtype=weight_dtype)
|
||||
|
||||
# bucketのサイズを計算する
|
||||
max_reso = tuple([int(t) for t in args.max_resolution.split(",")])
|
||||
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
|
||||
assert (
|
||||
len(max_reso) == 2
|
||||
), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
|
||||
|
||||
bucket_manager = train_util.BucketManager(
|
||||
args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps
|
||||
@@ -154,6 +166,10 @@ def main(args):
|
||||
# メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
|
||||
metadata[image_key]["train_resolution"] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
|
||||
|
||||
# 追加情報を記録
|
||||
metadata[image_key]["original_size"] = (image.width, image.height)
|
||||
metadata[image_key]["train_resized_size"] = resized_size
|
||||
|
||||
if not args.bucket_no_upscale:
|
||||
# upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
|
||||
assert (
|
||||
@@ -168,9 +184,9 @@ def main(args):
|
||||
), f"internal error resized size is small: {resized_size}, {reso}"
|
||||
|
||||
# 既に存在するファイルがあればshape等を確認して同じならskipする
|
||||
npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive)
|
||||
npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive, args.stable_cascade)
|
||||
if args.skip_existing:
|
||||
if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug):
|
||||
if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug, divisor):
|
||||
continue
|
||||
|
||||
# バッチへ追加
|
||||
@@ -208,7 +224,14 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
|
||||
parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)")
|
||||
parser.add_argument(
|
||||
"--stable_cascade",
|
||||
action="store_true",
|
||||
help="prepare EffNet latents for stable cascade / stable cascade用のEffNetのlatentsを準備する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)"
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--max_data_loader_n_workers",
|
||||
@@ -231,10 +254,16 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
|
||||
"--bucket_no_upscale",
|
||||
action="store_true",
|
||||
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="no",
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help="use mixed precision / 混合精度を使う場合、その精度",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full_path",
|
||||
@@ -242,7 +271,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
|
||||
"--flip_aug",
|
||||
action="store_true",
|
||||
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_existing",
|
||||
|
||||
@@ -6,8 +6,10 @@ import os
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import safetensors
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
r"""
|
||||
@@ -55,11 +57,13 @@ ARCH_SD_V1 = "stable-diffusion-v1"
|
||||
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
|
||||
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
|
||||
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
|
||||
ARCH_STABLE_CASCADE = "stable-cascade"
|
||||
|
||||
ADAPTER_LORA = "lora"
|
||||
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
|
||||
|
||||
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
|
||||
IMPL_STABILITY_AI_STABLE_CASCADE = "https://github.com/Stability-AI/StableCascade"
|
||||
IMPL_DIFFUSERS = "diffusers"
|
||||
|
||||
PRED_TYPE_EPSILON = "epsilon"
|
||||
@@ -113,6 +117,7 @@ def build_metadata(
|
||||
merged_from: Optional[str] = None,
|
||||
timesteps: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
stable_cascade: Optional[bool] = None,
|
||||
):
|
||||
# if state_dict is None, hash is not calculated
|
||||
|
||||
@@ -124,7 +129,9 @@ def build_metadata(
|
||||
# hash = precalculate_safetensors_hashes(state_dict)
|
||||
# metadata["modelspec.hash_sha256"] = hash
|
||||
|
||||
if sdxl:
|
||||
if stable_cascade:
|
||||
arch = ARCH_STABLE_CASCADE
|
||||
elif sdxl:
|
||||
arch = ARCH_SD_XL_V1_BASE
|
||||
elif v2:
|
||||
if v_parameterization:
|
||||
@@ -142,9 +149,11 @@ def build_metadata(
|
||||
metadata["modelspec.architecture"] = arch
|
||||
|
||||
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
|
||||
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
|
||||
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
|
||||
|
||||
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
||||
if stable_cascade:
|
||||
impl = IMPL_STABILITY_AI_STABLE_CASCADE
|
||||
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
||||
# Stable Diffusion ckpt, TI, SDXL LoRA
|
||||
impl = IMPL_STABILITY_AI
|
||||
else:
|
||||
@@ -236,7 +245,7 @@ def build_metadata(
|
||||
# assert all([v is not None for v in metadata.values()]), metadata
|
||||
if not all([v is not None for v in metadata.values()]):
|
||||
logger.error(f"Internal error: some metadata values are None: {metadata}")
|
||||
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
@@ -250,7 +259,7 @@ def get_title(metadata: dict) -> Optional[str]:
|
||||
def load_metadata_from_safetensors(model: str) -> dict:
|
||||
if not model.endswith(".safetensors"):
|
||||
return {}
|
||||
|
||||
|
||||
with safetensors.safe_open(model, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is None:
|
||||
|
||||
1654
library/stable_cascade.py
Normal file
1654
library/stable_cascade.py
Normal file
File diff suppressed because it is too large
Load Diff
668
library/stable_cascade_utils.py
Normal file
668
library/stable_cascade_utils.py
Normal file
@@ -0,0 +1,668 @@
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from typing import List
|
||||
import numpy as np
|
||||
import toml
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer, CLIPTextModelWithProjection, CLIPTextConfig
|
||||
from accelerate import init_empty_weights, Accelerator, PartialState
|
||||
from PIL import Image
|
||||
|
||||
from library import stable_cascade as sc
|
||||
|
||||
from library.sdxl_model_util import _load_state_dict_on_device
|
||||
from library.device_utils import clean_memory_on_device
|
||||
from library.train_util import (
|
||||
save_sd_model_on_epoch_end_or_stepwise_common,
|
||||
save_sd_model_on_train_end_common,
|
||||
line_to_prompt_dict,
|
||||
get_hidden_states_stable_cascade,
|
||||
)
|
||||
from library import sai_model_spec
|
||||
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CLIP_TEXT_MODEL_NAME: str = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_sc_te_outputs.npz"
|
||||
|
||||
|
||||
def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0):
|
||||
resolution_multiple = 42.67
|
||||
latent_height = math.ceil(height / compression_factor_b)
|
||||
latent_width = math.ceil(width / compression_factor_b)
|
||||
stage_c_latent_shape = (batch_size, 16, latent_height, latent_width)
|
||||
|
||||
latent_height = math.ceil(height / compression_factor_a)
|
||||
latent_width = math.ceil(width / compression_factor_a)
|
||||
stage_b_latent_shape = (batch_size, 4, latent_height, latent_width)
|
||||
|
||||
return stage_c_latent_shape, stage_b_latent_shape
|
||||
|
||||
|
||||
# region load and save
|
||||
|
||||
|
||||
def load_effnet(effnet_checkpoint_path, loading_device="cpu") -> sc.EfficientNetEncoder:
|
||||
logger.info(f"Loading EfficientNet encoder from {effnet_checkpoint_path}")
|
||||
effnet = sc.EfficientNetEncoder()
|
||||
effnet_checkpoint = load_file(effnet_checkpoint_path)
|
||||
info = effnet.load_state_dict(effnet_checkpoint if "state_dict" not in effnet_checkpoint else effnet_checkpoint["state_dict"])
|
||||
logger.info(info)
|
||||
del effnet_checkpoint
|
||||
return effnet
|
||||
|
||||
|
||||
def load_tokenizer(args: argparse.Namespace):
|
||||
# TODO commonize with sdxl_train_util.load_tokenizers
|
||||
logger.info("prepare tokenizers")
|
||||
|
||||
original_paths = [CLIP_TEXT_MODEL_NAME]
|
||||
tokenizers = []
|
||||
for i, original_path in enumerate(original_paths):
|
||||
tokenizer: CLIPTokenizer = None
|
||||
if args.tokenizer_cache_dir:
|
||||
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
|
||||
if os.path.exists(local_tokenizer_path):
|
||||
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
|
||||
|
||||
if tokenizer is None:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(original_path)
|
||||
|
||||
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
||||
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
|
||||
tokenizer.save_pretrained(local_tokenizer_path)
|
||||
|
||||
tokenizers.append(tokenizer)
|
||||
|
||||
if hasattr(args, "max_token_length") and args.max_token_length is not None:
|
||||
logger.info(f"update token length: {args.max_token_length}")
|
||||
|
||||
return tokenizers[0]
|
||||
|
||||
|
||||
def load_stage_c_model(stage_c_checkpoint_path, dtype=None, device="cpu") -> sc.StageC:
|
||||
# Generator
|
||||
logger.info(f"Instantiating Stage C generator")
|
||||
with init_empty_weights():
|
||||
generator_c = sc.StageC()
|
||||
logger.info(f"Loading Stage C generator from {stage_c_checkpoint_path}")
|
||||
stage_c_checkpoint = load_file(stage_c_checkpoint_path)
|
||||
|
||||
stage_c_checkpoint = convert_state_dict_mha_to_normal_attn(stage_c_checkpoint)
|
||||
|
||||
logger.info(f"Loading state dict")
|
||||
info = _load_state_dict_on_device(generator_c, stage_c_checkpoint, device, dtype=dtype)
|
||||
logger.info(info)
|
||||
return generator_c
|
||||
|
||||
|
||||
def load_stage_b_model(stage_b_checkpoint_path, dtype=None, device="cpu") -> sc.StageB:
|
||||
logger.info(f"Instantiating Stage B generator")
|
||||
with init_empty_weights():
|
||||
generator_b = sc.StageB()
|
||||
logger.info(f"Loading Stage B generator from {stage_b_checkpoint_path}")
|
||||
stage_b_checkpoint = load_file(stage_b_checkpoint_path)
|
||||
|
||||
stage_b_checkpoint = convert_state_dict_mha_to_normal_attn(stage_b_checkpoint)
|
||||
|
||||
logger.info(f"Loading state dict")
|
||||
info = _load_state_dict_on_device(generator_b, stage_b_checkpoint, device, dtype=dtype)
|
||||
logger.info(info)
|
||||
return generator_b
|
||||
|
||||
|
||||
def load_clip_text_model(text_model_checkpoint_path, dtype=None, device="cpu", save_text_model=False):
|
||||
# CLIP encoders
|
||||
logger.info(f"Loading CLIP text model")
|
||||
if save_text_model or text_model_checkpoint_path is None:
|
||||
logger.info(f"Loading CLIP text model from {CLIP_TEXT_MODEL_NAME}")
|
||||
text_model = CLIPTextModelWithProjection.from_pretrained(CLIP_TEXT_MODEL_NAME)
|
||||
|
||||
if save_text_model:
|
||||
sd = text_model.state_dict()
|
||||
logger.info(f"Saving CLIP text model to {text_model_checkpoint_path}")
|
||||
save_file(sd, text_model_checkpoint_path)
|
||||
else:
|
||||
logger.info(f"Loading CLIP text model from {text_model_checkpoint_path}")
|
||||
|
||||
# copy from sdxl_model_util.py
|
||||
text_model2_cfg = CLIPTextConfig(
|
||||
vocab_size=49408,
|
||||
hidden_size=1280,
|
||||
intermediate_size=5120,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=20,
|
||||
max_position_embeddings=77,
|
||||
hidden_act="gelu",
|
||||
layer_norm_eps=1e-05,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
initializer_factor=1.0,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
model_type="clip_text_model",
|
||||
projection_dim=1280,
|
||||
# torch_dtype="float32",
|
||||
# transformers_version="4.25.0.dev0",
|
||||
)
|
||||
with init_empty_weights():
|
||||
text_model = CLIPTextModelWithProjection(text_model2_cfg)
|
||||
|
||||
text_model_checkpoint = load_file(text_model_checkpoint_path)
|
||||
info = _load_state_dict_on_device(text_model, text_model_checkpoint, device, dtype=dtype)
|
||||
logger.info(info)
|
||||
|
||||
return text_model
|
||||
|
||||
|
||||
def load_stage_a_model(stage_a_checkpoint_path, dtype=None, device="cpu") -> sc.StageA:
|
||||
logger.info(f"Loading Stage A vqGAN from {stage_a_checkpoint_path}")
|
||||
stage_a = sc.StageA().to(device)
|
||||
stage_a_checkpoint = load_file(stage_a_checkpoint_path)
|
||||
info = stage_a.load_state_dict(
|
||||
stage_a_checkpoint if "state_dict" not in stage_a_checkpoint else stage_a_checkpoint["state_dict"]
|
||||
)
|
||||
logger.info(info)
|
||||
return stage_a
|
||||
|
||||
|
||||
def load_previewer_model(previewer_checkpoint_path, dtype=None, device="cpu") -> sc.Previewer:
|
||||
logger.info(f"Loading Previewer from {previewer_checkpoint_path}")
|
||||
previewer = sc.Previewer().to(device)
|
||||
previewer_checkpoint = load_file(previewer_checkpoint_path)
|
||||
info = previewer.load_state_dict(
|
||||
previewer_checkpoint if "state_dict" not in previewer_checkpoint else previewer_checkpoint["state_dict"]
|
||||
)
|
||||
logger.info(info)
|
||||
return previewer
|
||||
|
||||
|
||||
def convert_state_dict_mha_to_normal_attn(state_dict):
|
||||
# convert nn.MultiheadAttention to to_q/k/v and out_proj
|
||||
print("convert_state_dict_mha_to_normal_attn")
|
||||
for key in list(state_dict.keys()):
|
||||
if "attention.attn." in key:
|
||||
if "in_proj_bias" in key:
|
||||
value = state_dict.pop(key)
|
||||
qkv = torch.chunk(value, 3, dim=0)
|
||||
state_dict[key.replace("in_proj_bias", "to_q.bias")] = qkv[0]
|
||||
state_dict[key.replace("in_proj_bias", "to_k.bias")] = qkv[1]
|
||||
state_dict[key.replace("in_proj_bias", "to_v.bias")] = qkv[2]
|
||||
elif "in_proj_weight" in key:
|
||||
value = state_dict.pop(key)
|
||||
qkv = torch.chunk(value, 3, dim=0)
|
||||
state_dict[key.replace("in_proj_weight", "to_q.weight")] = qkv[0]
|
||||
state_dict[key.replace("in_proj_weight", "to_k.weight")] = qkv[1]
|
||||
state_dict[key.replace("in_proj_weight", "to_v.weight")] = qkv[2]
|
||||
elif "out_proj.bias" in key:
|
||||
value = state_dict.pop(key)
|
||||
state_dict[key.replace("out_proj.bias", "out_proj.bias")] = value
|
||||
elif "out_proj.weight" in key:
|
||||
value = state_dict.pop(key)
|
||||
state_dict[key.replace("out_proj.weight", "out_proj.weight")] = value
|
||||
return state_dict
|
||||
|
||||
|
||||
def convert_state_dict_normal_attn_to_mha(state_dict):
|
||||
# convert to_q/k/v and out_proj to nn.MultiheadAttention
|
||||
for key in list(state_dict.keys()):
|
||||
if "attention.attn." in key:
|
||||
if "to_q.bias" in key:
|
||||
q = state_dict.pop(key)
|
||||
k = state_dict.pop(key.replace("to_q.bias", "to_k.bias"))
|
||||
v = state_dict.pop(key.replace("to_q.bias", "to_v.bias"))
|
||||
state_dict[key.replace("to_q.bias", "in_proj_bias")] = torch.cat([q, k, v])
|
||||
elif "to_q.weight" in key:
|
||||
q = state_dict.pop(key)
|
||||
k = state_dict.pop(key.replace("to_q.weight", "to_k.weight"))
|
||||
v = state_dict.pop(key.replace("to_q.weight", "to_v.weight"))
|
||||
state_dict[key.replace("to_q.weight", "in_proj_weight")] = torch.cat([q, k, v])
|
||||
elif "out_proj.bias" in key:
|
||||
v = state_dict.pop(key)
|
||||
state_dict[key.replace("out_proj.bias", "out_proj.bias")] = v
|
||||
elif "out_proj.weight" in key:
|
||||
v = state_dict.pop(key)
|
||||
state_dict[key.replace("out_proj.weight", "out_proj.weight")] = v
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_sai_model_spec(args, lora=False):
|
||||
timestamp = time.time()
|
||||
|
||||
reso = args.resolution
|
||||
|
||||
title = args.metadata_title if args.metadata_title is not None else args.output_name
|
||||
|
||||
if args.min_timestep is not None or args.max_timestep is not None:
|
||||
min_time_step = args.min_timestep if args.min_timestep is not None else 0
|
||||
max_time_step = args.max_timestep if args.max_timestep is not None else 1000
|
||||
timesteps = (min_time_step, max_time_step)
|
||||
else:
|
||||
timesteps = None
|
||||
|
||||
metadata = sai_model_spec.build_metadata(
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
lora,
|
||||
False,
|
||||
timestamp,
|
||||
title=title,
|
||||
reso=reso,
|
||||
is_stable_diffusion_ckpt=False,
|
||||
author=args.metadata_author,
|
||||
description=args.metadata_description,
|
||||
license=args.metadata_license,
|
||||
tags=args.metadata_tags,
|
||||
timesteps=timesteps,
|
||||
clip_skip=args.clip_skip, # None or int
|
||||
stable_cascade=True,
|
||||
)
|
||||
return metadata
|
||||
|
||||
|
||||
def stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata):
|
||||
state_dict = stage_c.state_dict()
|
||||
if save_dtype is not None:
|
||||
state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
|
||||
|
||||
state_dict = convert_state_dict_normal_attn_to_mha(state_dict)
|
||||
|
||||
save_file(state_dict, ckpt_file, metadata=sai_metadata)
|
||||
|
||||
# save text model
|
||||
if text_model is not None:
|
||||
text_model_sd = text_model.state_dict()
|
||||
|
||||
if save_dtype is not None:
|
||||
text_model_sd = {k: v.to(save_dtype) for k, v in text_model_sd.items()}
|
||||
|
||||
text_model_ckpt_file = os.path.splitext(ckpt_file)[0] + "_text_model.safetensors"
|
||||
save_file(text_model_sd, text_model_ckpt_file)
|
||||
|
||||
|
||||
def save_stage_c_model_on_epoch_end_or_stepwise(
|
||||
args: argparse.Namespace,
|
||||
on_epoch_end: bool,
|
||||
accelerator,
|
||||
save_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
num_train_epochs: int,
|
||||
global_step: int,
|
||||
stage_c,
|
||||
text_model,
|
||||
):
|
||||
def stage_c_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = get_sai_model_spec(args)
|
||||
stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata)
|
||||
|
||||
save_sd_model_on_epoch_end_or_stepwise_common(
|
||||
args, on_epoch_end, accelerator, True, True, epoch, num_train_epochs, global_step, stage_c_saver, None
|
||||
)
|
||||
|
||||
|
||||
def save_stage_c_model_on_end(
|
||||
args: argparse.Namespace,
|
||||
save_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
global_step: int,
|
||||
stage_c,
|
||||
text_model,
|
||||
):
|
||||
def stage_c_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = get_sai_model_spec(args)
|
||||
stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata)
|
||||
|
||||
save_sd_model_on_train_end_common(args, True, True, epoch, global_step, stage_c_saver, None)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region sample generation
|
||||
|
||||
|
||||
def sample_images(
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
epoch,
|
||||
steps,
|
||||
previewer,
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
stage_c,
|
||||
gdf,
|
||||
prompt_replacement=None,
|
||||
):
|
||||
if steps == 0:
|
||||
if not args.sample_at_first:
|
||||
return
|
||||
else:
|
||||
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
||||
return
|
||||
if args.sample_every_n_epochs is not None:
|
||||
# sample_every_n_steps は無視する
|
||||
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
||||
return
|
||||
else:
|
||||
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
||||
return
|
||||
|
||||
logger.info("")
|
||||
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
||||
if not os.path.isfile(args.sample_prompts):
|
||||
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
||||
return
|
||||
|
||||
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
|
||||
|
||||
# unwrap unet and text_encoder(s)
|
||||
stage_c = accelerator.unwrap_model(stage_c)
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
# read prompts
|
||||
if args.sample_prompts.endswith(".txt"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
||||
elif args.sample_prompts.endswith(".toml"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
data = toml.load(f)
|
||||
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
|
||||
elif args.sample_prompts.endswith(".json"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
prompts = json.load(f)
|
||||
|
||||
save_dir = args.output_dir + "/sample"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# preprocess prompts
|
||||
for i in range(len(prompts)):
|
||||
prompt_dict = prompts[i]
|
||||
if isinstance(prompt_dict, str):
|
||||
prompt_dict = line_to_prompt_dict(prompt_dict)
|
||||
prompts[i] = prompt_dict
|
||||
assert isinstance(prompt_dict, dict)
|
||||
|
||||
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
|
||||
prompt_dict["enum"] = i
|
||||
prompt_dict.pop("subset", None)
|
||||
|
||||
# save random state to restore later
|
||||
rng_state = torch.get_rng_state()
|
||||
cuda_rng_state = None
|
||||
try:
|
||||
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if distributed_state.num_processes <= 1:
|
||||
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
||||
with torch.no_grad():
|
||||
for prompt_dict in prompts:
|
||||
sample_image_inference(
|
||||
accelerator,
|
||||
args,
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
stage_c,
|
||||
previewer,
|
||||
gdf,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
steps,
|
||||
prompt_replacement,
|
||||
)
|
||||
else:
|
||||
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
||||
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
|
||||
per_process_prompts = [] # list of lists
|
||||
for i in range(distributed_state.num_processes):
|
||||
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
|
||||
|
||||
with torch.no_grad():
|
||||
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
|
||||
for prompt_dict in prompt_dict_lists[0]:
|
||||
sample_image_inference(
|
||||
accelerator,
|
||||
args,
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
stage_c,
|
||||
previewer,
|
||||
gdf,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
steps,
|
||||
prompt_replacement,
|
||||
)
|
||||
|
||||
# I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here.
|
||||
# with torch.cuda.device(torch.cuda.current_device()):
|
||||
# torch.cuda.empty_cache()
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
torch.set_rng_state(rng_state)
|
||||
if cuda_rng_state is not None:
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
|
||||
|
||||
def sample_image_inference(
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
tokenizer,
|
||||
text_model,
|
||||
stage_c,
|
||||
previewer,
|
||||
gdf,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
steps,
|
||||
prompt_replacement,
|
||||
):
|
||||
assert isinstance(prompt_dict, dict)
|
||||
negative_prompt = prompt_dict.get("negative_prompt")
|
||||
sample_steps = prompt_dict.get("sample_steps", 20)
|
||||
width = prompt_dict.get("width", 1024)
|
||||
height = prompt_dict.get("height", 1024)
|
||||
scale = prompt_dict.get("scale", 4)
|
||||
seed = prompt_dict.get("seed")
|
||||
# controlnet_image = prompt_dict.get("controlnet_image")
|
||||
prompt: str = prompt_dict.get("prompt", "")
|
||||
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
||||
|
||||
if prompt_replacement is not None:
|
||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
if negative_prompt is not None:
|
||||
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
else:
|
||||
# True random sample image generation
|
||||
torch.seed()
|
||||
torch.cuda.seed()
|
||||
|
||||
height = max(64, height - height % 8) # round to divisible by 8
|
||||
width = max(64, width - width % 8) # round to divisible by 8
|
||||
logger.info(f"prompt: {prompt}")
|
||||
logger.info(f"negative_prompt: {negative_prompt}")
|
||||
logger.info(f"height: {height}")
|
||||
logger.info(f"width: {width}")
|
||||
logger.info(f"sample_steps: {sample_steps}")
|
||||
logger.info(f"scale: {scale}")
|
||||
# logger.info(f"sample_sampler: {sampler_name}")
|
||||
if seed is not None:
|
||||
logger.info(f"seed: {seed}")
|
||||
|
||||
negative_prompt = "" if negative_prompt is None else negative_prompt
|
||||
cfg = scale
|
||||
timesteps = sample_steps
|
||||
shift = 2
|
||||
t_start = 1.0
|
||||
|
||||
stage_c_latent_shape, _ = calculate_latent_sizes(height, width, batch_size=1)
|
||||
|
||||
# PREPARE CONDITIONS
|
||||
input_ids = tokenizer(
|
||||
[prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
|
||||
)["input_ids"].to(text_model.device)
|
||||
cond_text, cond_pooled = get_hidden_states_stable_cascade(tokenizer.model_max_length, input_ids, tokenizer, text_model)
|
||||
|
||||
input_ids = tokenizer(
|
||||
[negative_prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
|
||||
)["input_ids"].to(text_model.device)
|
||||
uncond_text, uncond_pooled = get_hidden_states_stable_cascade(tokenizer.model_max_length, input_ids, tokenizer, text_model)
|
||||
|
||||
device = accelerator.device
|
||||
dtype = stage_c.dtype
|
||||
cond_text = cond_text.to(device, dtype=dtype)
|
||||
cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype)
|
||||
|
||||
uncond_text = uncond_text.to(device, dtype=dtype)
|
||||
uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype)
|
||||
|
||||
zero_img_emb = torch.zeros(1, 768, device=device)
|
||||
|
||||
# 辞書にしたくないけど GDF から先の変更が面倒だからとりあえず辞書にしておく
|
||||
conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
|
||||
unconditions = {"clip_text_pooled": uncond_pooled, "clip": uncond_pooled, "clip_text": uncond_text, "clip_img": zero_img_emb}
|
||||
|
||||
with torch.no_grad(): # , torch.cuda.amp.autocast(dtype=dtype):
|
||||
sampling_c = gdf.sample(
|
||||
stage_c,
|
||||
conditions,
|
||||
stage_c_latent_shape,
|
||||
unconditions,
|
||||
device=device,
|
||||
cfg=cfg,
|
||||
shift=shift,
|
||||
timesteps=timesteps,
|
||||
t_start=t_start,
|
||||
)
|
||||
for sampled_c, _, _ in tqdm(sampling_c, total=timesteps):
|
||||
sampled_c = sampled_c
|
||||
|
||||
sampled_c = sampled_c.to(previewer.device, dtype=previewer.dtype)
|
||||
image = previewer(sampled_c)[0]
|
||||
image = torch.clamp(image, 0, 1)
|
||||
image = image.cpu().numpy().transpose(1, 2, 0)
|
||||
image = image * 255
|
||||
image = image.astype(np.uint8)
|
||||
image = Image.fromarray(image)
|
||||
|
||||
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
|
||||
# but adding 'enum' to the filename should be enough
|
||||
|
||||
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
||||
seed_suffix = "" if seed is None else f"_{seed}"
|
||||
i: int = prompt_dict["enum"]
|
||||
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
|
||||
image.save(os.path.join(save_dir, img_filename))
|
||||
|
||||
# wandb有効時のみログを送信
|
||||
try:
|
||||
wandb_tracker = accelerator.get_tracker("wandb")
|
||||
try:
|
||||
import wandb
|
||||
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
|
||||
raise ImportError("No wandb / wandb がインストールされていないようです")
|
||||
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
|
||||
except: # wandb 無効時
|
||||
pass
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
def add_effnet_arguments(parser):
|
||||
parser.add_argument(
|
||||
"--effnet_checkpoint_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="path to EfficientNet checkpoint / EfficientNetのチェックポイントのパス",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def add_text_model_arguments(parser):
|
||||
parser.add_argument(
|
||||
"--text_model_checkpoint_path",
|
||||
type=str,
|
||||
help="path to CLIP text model checkpoint / CLIPテキストモデルのチェックポイントのパス",
|
||||
)
|
||||
parser.add_argument("--save_text_model", action="store_true", help="if specified, save text model to corresponding path")
|
||||
return parser
|
||||
|
||||
|
||||
def add_stage_a_arguments(parser):
|
||||
parser.add_argument(
|
||||
"--stage_a_checkpoint_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="path to Stage A checkpoint / Stage Aのチェックポイントのパス",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def add_stage_b_arguments(parser):
|
||||
parser.add_argument(
|
||||
"--stage_b_checkpoint_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="path to Stage B checkpoint / Stage Bのチェックポイントのパス",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def add_stage_c_arguments(parser):
|
||||
parser.add_argument(
|
||||
"--stage_c_checkpoint_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="path to Stage C checkpoint / Stage Cのチェックポイントのパス",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def add_previewer_arguments(parser):
|
||||
parser.add_argument(
|
||||
"--previewer_checkpoint_path",
|
||||
type=str,
|
||||
required=False,
|
||||
help="path to previewer checkpoint / previewerのチェックポイントのパス",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def add_training_arguments(parser):
|
||||
parser.add_argument(
|
||||
"--adaptive_loss_weight",
|
||||
action="store_true",
|
||||
help="if specified, use adaptive loss weight. if not, use P2 loss weight"
|
||||
+ " / Adaptive Loss Weightを使用する。指定しない場合はP2 Loss Weightを使用する",
|
||||
)
|
||||
@@ -133,6 +133,7 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
||||
)
|
||||
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
||||
STABLE_CASCADE_LATENTS_CACHE_SUFFIX = "_sc_latents.npz"
|
||||
|
||||
|
||||
class ImageInfo:
|
||||
@@ -856,7 +857,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
logger.info(f"mean ar error (without repeats): {mean_img_ar_error}")
|
||||
|
||||
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
|
||||
self.buckets_indices: List(BucketBatchIndex) = []
|
||||
self.buckets_indices: List[BucketBatchIndex] = []
|
||||
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
|
||||
batch_count = int(math.ceil(len(bucket) / self.batch_size))
|
||||
for batch_index in range(batch_count):
|
||||
@@ -910,7 +911,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
]
|
||||
)
|
||||
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
||||
def cache_latents(self, vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor):
|
||||
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||
logger.info("caching latents.")
|
||||
|
||||
@@ -931,11 +932,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# check disk cache exists and size of latents
|
||||
if cache_to_disk:
|
||||
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
|
||||
info.latents_npz = os.path.splitext(info.absolute_path)[0] + cache_file_suffix
|
||||
if not is_main_process: # store to info only
|
||||
continue
|
||||
|
||||
cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug)
|
||||
cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug, divisor)
|
||||
|
||||
if cache_available: # do not add to batch
|
||||
continue
|
||||
@@ -967,9 +968,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する
|
||||
# SD1/2に対応するにはv2のフラグを持つ必要があるので後回し
|
||||
def cache_text_encoder_outputs(
|
||||
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
|
||||
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process, cache_file_suffix
|
||||
):
|
||||
assert len(tokenizers) == 2, "only support SDXL"
|
||||
"""
|
||||
最後の Text Encoder の pool がキャッシュされる。
|
||||
The last Text Encoder's pool is cached.
|
||||
"""
|
||||
# assert len(tokenizers) == 2, "only support SDXL"
|
||||
|
||||
# latentsのキャッシュと同様に、ディスクへのキャッシュに対応する
|
||||
# またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||
@@ -981,7 +986,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
for info in tqdm(image_infos):
|
||||
# subset = self.image_to_subset[info.image_key]
|
||||
if cache_to_disk:
|
||||
te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
|
||||
te_out_npz = os.path.splitext(info.absolute_path)[0] + cache_file_suffix
|
||||
info.text_encoder_outputs_npz = te_out_npz
|
||||
|
||||
if not is_main_process: # store to info only
|
||||
@@ -1006,7 +1011,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
batches = []
|
||||
for info in image_infos_to_cache:
|
||||
input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
|
||||
input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
|
||||
input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) if len(tokenizers) > 1 else None
|
||||
batch.append((info, input_ids1, input_ids2))
|
||||
|
||||
if len(batch) >= self.batch_size:
|
||||
@@ -1021,7 +1026,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
for batch in tqdm(batches):
|
||||
infos, input_ids1, input_ids2 = zip(*batch)
|
||||
input_ids1 = torch.stack(input_ids1, dim=0)
|
||||
input_ids2 = torch.stack(input_ids2, dim=0)
|
||||
input_ids2 = torch.stack(input_ids2, dim=0) if input_ids2[0] is not None else None
|
||||
cache_batch_text_encoder_outputs(
|
||||
infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype
|
||||
)
|
||||
@@ -1270,7 +1275,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions])
|
||||
# example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions])
|
||||
example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list)
|
||||
example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
|
||||
example["text_encoder_outputs2_list"] = (
|
||||
torch.stack(text_encoder_outputs2_list) if text_encoder_outputs2_list[0] is not None else None
|
||||
)
|
||||
example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)
|
||||
|
||||
if images[0] is not None:
|
||||
@@ -1327,7 +1334,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
if self.caching_mode == "text":
|
||||
input_ids1 = self.get_input_ids(caption, self.tokenizers[0])
|
||||
input_ids2 = self.get_input_ids(caption, self.tokenizers[1])
|
||||
input_ids2 = self.get_input_ids(caption, self.tokenizers[1]) if len(self.tokenizers) > 1 else None
|
||||
else:
|
||||
input_ids1 = None
|
||||
input_ids2 = None
|
||||
@@ -1599,12 +1606,15 @@ class FineTuningDataset(BaseDataset):
|
||||
|
||||
# なければnpzを探す
|
||||
if abs_path is None:
|
||||
if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
|
||||
abs_path = os.path.splitext(image_key)[0] + ".npz"
|
||||
else:
|
||||
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
if os.path.exists(npz_path):
|
||||
abs_path = npz_path
|
||||
abs_path = os.path.splitext(image_key)[0] + ".npz"
|
||||
if not os.path.exists(abs_path):
|
||||
abs_path = os.path.splitext(image_key)[0] + STABLE_CASCADE_LATENTS_CACHE_SUFFIX
|
||||
if not os.path.exists(abs_path):
|
||||
abs_path = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
if not os.path.exists(abs_path):
|
||||
abs_path = os.path.join(subset.image_dir, image_key + STABLE_CASCADE_LATENTS_CACHE_SUFFIX)
|
||||
if not os.path.exists(abs_path):
|
||||
abs_path = None
|
||||
|
||||
assert abs_path is not None, f"no image / 画像がありません: {image_key}"
|
||||
|
||||
@@ -1624,7 +1634,7 @@ class FineTuningDataset(BaseDataset):
|
||||
|
||||
if not subset.color_aug and not subset.random_crop:
|
||||
# if npz exists, use them
|
||||
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
|
||||
image_info.latents_npz = self.image_key_to_npz_file(subset, image_key)
|
||||
|
||||
self.register_image(image_info, subset)
|
||||
|
||||
@@ -1638,7 +1648,7 @@ class FineTuningDataset(BaseDataset):
|
||||
# check existence of all npz files
|
||||
use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
|
||||
if use_npz_latents:
|
||||
flip_aug_in_subset = False
|
||||
# flip_aug_in_subset = False
|
||||
npz_any = False
|
||||
npz_all = True
|
||||
|
||||
@@ -1648,9 +1658,12 @@ class FineTuningDataset(BaseDataset):
|
||||
has_npz = image_info.latents_npz is not None
|
||||
npz_any = npz_any or has_npz
|
||||
|
||||
if subset.flip_aug:
|
||||
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
||||
flip_aug_in_subset = True
|
||||
# flip は同一の .npz 内に格納するようにした:
|
||||
# そのためここでチェック漏れがあり実行時にエラーになる可能性があるので要検討
|
||||
# if subset.flip_aug:
|
||||
# has_npz = has_npz and image_info.latents_npz_flipped is not None
|
||||
# flip_aug_in_subset = True
|
||||
|
||||
npz_all = npz_all and has_npz
|
||||
|
||||
if npz_any and not npz_all:
|
||||
@@ -1664,8 +1677,8 @@ class FineTuningDataset(BaseDataset):
|
||||
logger.warning(
|
||||
f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します"
|
||||
)
|
||||
if flip_aug_in_subset:
|
||||
logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
||||
# if flip_aug_in_subset:
|
||||
# logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
||||
# else:
|
||||
# logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
||||
|
||||
@@ -1714,34 +1727,29 @@ class FineTuningDataset(BaseDataset):
|
||||
# npz情報をきれいにしておく
|
||||
if not use_npz_latents:
|
||||
for image_info in self.image_data.values():
|
||||
image_info.latents_npz = image_info.latents_npz_flipped = None
|
||||
image_info.latents_npz = None # image_info.latents_npz_flipped =
|
||||
|
||||
def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
|
||||
base_name = os.path.splitext(image_key)[0]
|
||||
npz_file_norm = base_name + ".npz"
|
||||
|
||||
npz_file_norm = base_name + ".npz"
|
||||
if not os.path.exists(npz_file_norm):
|
||||
npz_file_norm = base_name + STABLE_CASCADE_LATENTS_CACHE_SUFFIX
|
||||
if os.path.exists(npz_file_norm):
|
||||
# image_key is full path
|
||||
npz_file_flip = base_name + "_flip.npz"
|
||||
if not os.path.exists(npz_file_flip):
|
||||
npz_file_flip = None
|
||||
return npz_file_norm, npz_file_flip
|
||||
return npz_file_norm
|
||||
|
||||
# if not full path, check image_dir. if image_dir is None, return None
|
||||
if subset.image_dir is None:
|
||||
return None, None
|
||||
return None
|
||||
|
||||
# image_key is relative path
|
||||
npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
|
||||
|
||||
if not os.path.exists(npz_file_norm):
|
||||
npz_file_norm = None
|
||||
npz_file_flip = None
|
||||
elif not os.path.exists(npz_file_flip):
|
||||
npz_file_flip = None
|
||||
npz_file_norm = os.path.join(subset.image_dir, base_name + STABLE_CASCADE_LATENTS_CACHE_SUFFIX)
|
||||
if os.path.exists(npz_file_norm):
|
||||
return npz_file_norm
|
||||
|
||||
return npz_file_norm, npz_file_flip
|
||||
return None
|
||||
|
||||
|
||||
class ControlNetDataset(BaseDataset):
|
||||
@@ -1943,17 +1951,26 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
for dataset in self.datasets:
|
||||
dataset.enable_XTI(*args, **kwargs)
|
||||
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, cache_file_suffix=".npz", divisor=8):
|
||||
for i, dataset in enumerate(self.datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
|
||||
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor)
|
||||
|
||||
def cache_text_encoder_outputs(
|
||||
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
|
||||
self,
|
||||
tokenizers,
|
||||
text_encoders,
|
||||
device,
|
||||
weight_dtype,
|
||||
cache_to_disk=False,
|
||||
is_main_process=True,
|
||||
cache_file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
|
||||
):
|
||||
for i, dataset in enumerate(self.datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)
|
||||
dataset.cache_text_encoder_outputs(
|
||||
tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process, cache_file_suffix
|
||||
)
|
||||
|
||||
def set_caching_mode(self, caching_mode):
|
||||
for dataset in self.datasets:
|
||||
@@ -1986,8 +2003,8 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
dataset.disable_token_padding()
|
||||
|
||||
|
||||
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
|
||||
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
|
||||
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, divisor: int = 8) -> bool:
|
||||
expected_latents_size = (reso[1] // divisor, reso[0] // divisor) # bucket_resoはWxHなので注意
|
||||
|
||||
if not os.path.exists(npz_path):
|
||||
return False
|
||||
@@ -2079,7 +2096,7 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
|
||||
if show_input_ids:
|
||||
logger.info(f"input ids: {iid}")
|
||||
if "input_ids2" in example:
|
||||
if "input_ids2" in example and example["input_ids2"] is not None:
|
||||
logger.info(f"input ids2: {example['input_ids2'][j]}")
|
||||
if example["images"] is not None:
|
||||
im = example["images"][j]
|
||||
@@ -2256,7 +2273,7 @@ def trim_and_resize_if_required(
|
||||
|
||||
|
||||
def cache_batch_latents(
|
||||
vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
|
||||
vae: Union[AutoencoderKL, torch.nn.Module], cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
|
||||
) -> None:
|
||||
r"""
|
||||
requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
|
||||
@@ -2311,23 +2328,36 @@ def cache_batch_text_encoder_outputs(
|
||||
image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
|
||||
):
|
||||
input_ids1 = input_ids1.to(text_encoders[0].device)
|
||||
input_ids2 = input_ids2.to(text_encoders[1].device)
|
||||
input_ids2 = input_ids2.to(text_encoders[1].device) if input_ids2 is not None else None
|
||||
|
||||
with torch.no_grad():
|
||||
b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
|
||||
max_token_length,
|
||||
input_ids1,
|
||||
input_ids2,
|
||||
tokenizers[0],
|
||||
tokenizers[1],
|
||||
text_encoders[0],
|
||||
text_encoders[1],
|
||||
dtype,
|
||||
)
|
||||
# TODO SDXL と Stable Cascade で統一する
|
||||
if len(tokenizers) == 1:
|
||||
# Stable Cascade
|
||||
b_hidden_state1, b_pool2 = get_hidden_states_stable_cascade(
|
||||
max_token_length, input_ids1, tokenizers[0], text_encoders[0], dtype
|
||||
)
|
||||
|
||||
b_hidden_state1 = b_hidden_state1.detach().to("cpu") # b,n*75+2,768
|
||||
b_pool2 = b_pool2.detach().to("cpu") # b,1280
|
||||
|
||||
b_hidden_state2 = [None] * input_ids1.shape[0]
|
||||
else:
|
||||
# SDXL
|
||||
b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
|
||||
max_token_length,
|
||||
input_ids1,
|
||||
input_ids2,
|
||||
tokenizers[0],
|
||||
tokenizers[1],
|
||||
text_encoders[0],
|
||||
text_encoders[1],
|
||||
dtype,
|
||||
)
|
||||
|
||||
# ここでcpuに移動しておかないと、上書きされてしまう
|
||||
b_hidden_state1 = b_hidden_state1.detach().to("cpu") # b,n*75+2,768
|
||||
b_hidden_state2 = b_hidden_state2.detach().to("cpu") # b,n*75+2,1280
|
||||
b_hidden_state2 = b_hidden_state2.detach().to("cpu") if b_hidden_state2[0] is not None else b_hidden_state2 # b,n*75+2,1280
|
||||
b_pool2 = b_pool2.detach().to("cpu") # b,1280
|
||||
|
||||
for info, hidden_state1, hidden_state2, pool2 in zip(image_infos, b_hidden_state1, b_hidden_state2, b_pool2):
|
||||
@@ -2340,18 +2370,25 @@ def cache_batch_text_encoder_outputs(
|
||||
|
||||
|
||||
def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2):
|
||||
np.savez(
|
||||
npz_path,
|
||||
hidden_state1=hidden_state1.cpu().float().numpy(),
|
||||
hidden_state2=hidden_state2.cpu().float().numpy(),
|
||||
pool2=pool2.cpu().float().numpy(),
|
||||
)
|
||||
save_kwargs = {
|
||||
"hidden_state1": hidden_state1.cpu().float().numpy(),
|
||||
"pool2": pool2.cpu().float().numpy(),
|
||||
}
|
||||
if hidden_state2 is not None:
|
||||
save_kwargs["hidden_state2"] = hidden_state2.cpu().float().numpy()
|
||||
np.savez(npz_path, **save_kwargs)
|
||||
# np.savez(
|
||||
# npz_path,
|
||||
# hidden_state1=hidden_state1.cpu().float().numpy(),
|
||||
# hidden_state2=hidden_state2.cpu().float().numpy() if hidden_state2 is not None else None,
|
||||
# pool2=pool2.cpu().float().numpy(),
|
||||
# )
|
||||
|
||||
|
||||
def load_text_encoder_outputs_from_disk(npz_path):
|
||||
with np.load(npz_path) as f:
|
||||
hidden_state1 = torch.from_numpy(f["hidden_state1"])
|
||||
hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f else None
|
||||
hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f and f["hidden_state2"] is not None else None
|
||||
pool2 = torch.from_numpy(f["pool2"]) if "pool2" in f else None
|
||||
return hidden_state1, hidden_state2, pool2
|
||||
|
||||
@@ -2698,6 +2735,15 @@ def get_sai_model_spec(
|
||||
return metadata
|
||||
|
||||
|
||||
def add_tokenizer_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--tokenizer_cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)",
|
||||
)
|
||||
|
||||
|
||||
def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
||||
# for pretrained models
|
||||
parser.add_argument(
|
||||
@@ -2712,12 +2758,7 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
||||
default=None,
|
||||
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)",
|
||||
)
|
||||
add_tokenizer_arguments(parser)
|
||||
|
||||
|
||||
def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||
@@ -3205,6 +3246,16 @@ def verify_training_args(args: argparse.Namespace):
|
||||
global HIGH_VRAM
|
||||
HIGH_VRAM = True
|
||||
|
||||
if args.cache_latents_to_disk and not args.cache_latents:
|
||||
args.cache_latents = True
|
||||
logger.warning(
|
||||
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
|
||||
)
|
||||
|
||||
if not hasattr(args, "v_parameterization"):
|
||||
# Stable Cascade: skip following checks
|
||||
return
|
||||
|
||||
if args.v_parameterization and not args.v2:
|
||||
logger.warning(
|
||||
"v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません"
|
||||
@@ -3212,12 +3263,6 @@ def verify_training_args(args: argparse.Namespace):
|
||||
if args.v2 and args.clip_skip is not None:
|
||||
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
||||
|
||||
if args.cache_latents_to_disk and not args.cache_latents:
|
||||
args.cache_latents = True
|
||||
logger.warning(
|
||||
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
|
||||
)
|
||||
|
||||
# noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time
|
||||
# # Listを使って数えてもいいけど並べてしまえ
|
||||
# if args.noise_offset is not None and args.multires_noise_iterations is not None:
|
||||
@@ -4297,6 +4342,54 @@ def get_hidden_states_sdxl(
|
||||
return hidden_states1, hidden_states2, pool2
|
||||
|
||||
|
||||
def get_hidden_states_stable_cascade(
|
||||
max_token_length: int,
|
||||
input_ids2: torch.Tensor,
|
||||
tokenizer2: CLIPTokenizer,
|
||||
text_encoder2: CLIPTextModel,
|
||||
weight_dtype: Optional[str] = None,
|
||||
accelerator: Optional[Accelerator] = None,
|
||||
):
|
||||
# ここに Stable Cascade 用のコードがあるのはとても気持ち悪いが、変に整理するよりわかりやすいので、とりあえずこのまま
|
||||
# It's very awkward to have Stable Cascade code here, but it's easier to understand than to organize it in a strange way, so for now it's as it is.
|
||||
|
||||
# input_ids: b,n,77 -> b*n, 77
|
||||
b_size = input_ids2.size()[0]
|
||||
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
|
||||
|
||||
# text_encoder2
|
||||
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
|
||||
hidden_states2 = enc_out["hidden_states"][-1] # ** last layer **
|
||||
|
||||
# pool2 = enc_out["text_embeds"]
|
||||
unwrapped_text_encoder2 = text_encoder2 if accelerator is None else accelerator.unwrap_model(text_encoder2)
|
||||
pool2 = pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
|
||||
|
||||
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
|
||||
n_size = 1 if max_token_length is None else max_token_length // 75
|
||||
hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
|
||||
|
||||
if max_token_length is not None:
|
||||
# bs*3, 77, 768 or 1024
|
||||
|
||||
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
|
||||
states_list = [hidden_states2[:, 0].unsqueeze(1)] # <BOS>
|
||||
for i in range(1, max_token_length, tokenizer2.model_max_length):
|
||||
chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # <BOS> の後から 最後の前まで
|
||||
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
|
||||
states_list.append(hidden_states2[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
|
||||
hidden_states2 = torch.cat(states_list, dim=1)
|
||||
|
||||
# pool はnの最初のものを使う
|
||||
pool2 = pool2[::n_size]
|
||||
|
||||
if weight_dtype is not None:
|
||||
# this is required for additional network training
|
||||
hidden_states2 = hidden_states2.to(weight_dtype)
|
||||
|
||||
return hidden_states2, pool2
|
||||
|
||||
|
||||
def default_if_none(value, default):
|
||||
return default if value is None else value
|
||||
|
||||
|
||||
@@ -841,9 +841,14 @@ class LoRANetwork(torch.nn.Module):
|
||||
is_linear = child_module.__class__.__name__ == "Linear"
|
||||
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
||||
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||
is_group_conv2d = is_conv2d and child_module.groups > 1
|
||||
|
||||
if is_linear or is_conv2d:
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
# if is_group_conv2d:
|
||||
# logger.info(f"skip group conv2d: {name}.{child_name}")
|
||||
# continue
|
||||
|
||||
if is_linear or (is_conv2d and not is_group_conv2d):
|
||||
lora_name = prefix + "." + name + ("." + child_name if child_name else "")
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
dim = None
|
||||
@@ -915,6 +920,11 @@ class LoRANetwork(torch.nn.Module):
|
||||
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
|
||||
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
# XXX temporary solution for Stable Cascade Stage C: replace all modules
|
||||
if "StageC" in unet.__class__.__name__:
|
||||
logger.info("replace all modules for Stable Cascade Stage C")
|
||||
target_modules = ["Linear", "Conv2d"]
|
||||
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
|
||||
|
||||
367
stable_cascade_gen_img.py
Normal file
367
stable_cascade_gen_img.py
Normal file
@@ -0,0 +1,367 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from safetensors.torch import load_file, save_file
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTextConfig
|
||||
from PIL import Image
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
import library.stable_cascade as sc
|
||||
import library.stable_cascade_utils as sc_utils
|
||||
import library.device_utils as device_utils
|
||||
from library import train_util
|
||||
from library.sdxl_model_util import _load_state_dict_on_device
|
||||
|
||||
|
||||
def main(args):
|
||||
device = device_utils.get_preferred_device()
|
||||
|
||||
loading_device = device if not args.lowvram else "cpu"
|
||||
text_model_device = "cpu"
|
||||
|
||||
dtype = torch.float32
|
||||
if args.bf16:
|
||||
dtype = torch.bfloat16
|
||||
elif args.fp16:
|
||||
dtype = torch.float16
|
||||
|
||||
text_model_dtype = torch.float32
|
||||
|
||||
# EfficientNet encoder
|
||||
effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device)
|
||||
effnet.eval().requires_grad_(False).to(loading_device)
|
||||
|
||||
generator_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=dtype, device=loading_device)
|
||||
generator_c.eval().requires_grad_(False).to(loading_device)
|
||||
# if args.xformers or args.sdpa:
|
||||
print(f"Stage C: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
|
||||
generator_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
|
||||
|
||||
generator_b = sc_utils.load_stage_b_model(args.stage_b_checkpoint_path, dtype=dtype, device=loading_device)
|
||||
generator_b.eval().requires_grad_(False).to(loading_device)
|
||||
# if args.xformers or args.sdpa:
|
||||
print(f"Stage B: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
|
||||
generator_b.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
|
||||
|
||||
# CLIP encoders
|
||||
tokenizer = sc_utils.load_tokenizer(args)
|
||||
|
||||
text_model = sc_utils.load_clip_text_model(
|
||||
args.text_model_checkpoint_path, text_model_dtype, text_model_device, args.save_text_model
|
||||
)
|
||||
text_model = text_model.requires_grad_(False).to(text_model_dtype).to(text_model_device)
|
||||
|
||||
# image_model = (
|
||||
# CLIPVisionModelWithProjection.from_pretrained(clip_image_model_name).requires_grad_(False).to(dtype).to(device)
|
||||
# )
|
||||
|
||||
# vqGAN
|
||||
stage_a = sc_utils.load_stage_a_model(args.stage_a_checkpoint_path, dtype=dtype, device=loading_device)
|
||||
stage_a.eval().requires_grad_(False)
|
||||
|
||||
# previewer
|
||||
if args.previewer_checkpoint_path is not None:
|
||||
previewer = sc_utils.load_previewer_model(args.previewer_checkpoint_path, dtype=dtype, device=loading_device)
|
||||
previewer.eval().requires_grad_(False)
|
||||
else:
|
||||
previewer = None
|
||||
|
||||
# LoRA
|
||||
if args.network_module:
|
||||
for i, network_module in enumerate(args.network_module):
|
||||
print("import network module:", network_module)
|
||||
imported_module = importlib.import_module(network_module)
|
||||
|
||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||
|
||||
net_kwargs = {}
|
||||
if args.network_args and i < len(args.network_args):
|
||||
network_args = args.network_args[i]
|
||||
# TODO escape special chars
|
||||
network_args = network_args.split(";")
|
||||
for net_arg in network_args:
|
||||
key, value = net_arg.split("=")
|
||||
net_kwargs[key] = value
|
||||
|
||||
if args.network_weights is None or len(args.network_weights) <= i:
|
||||
raise ValueError("No weight. Weight is required.")
|
||||
|
||||
network_weight = args.network_weights[i]
|
||||
print("load network weights from:", network_weight)
|
||||
|
||||
network, weights_sd = imported_module.create_network_from_weights(
|
||||
network_mul, network_weight, effnet, text_model, generator_c, for_inference=True, **net_kwargs
|
||||
)
|
||||
if network is None:
|
||||
return
|
||||
|
||||
mergeable = network.is_mergeable()
|
||||
assert mergeable, "not-mergeable network is not supported yet."
|
||||
|
||||
network.merge_to(text_model, generator_c, weights_sd, dtype, device)
|
||||
|
||||
# 謎のクラス gdf
|
||||
gdf_c = sc.GDF(
|
||||
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
|
||||
input_scaler=sc.VPScaler(),
|
||||
target=sc.EpsilonTarget(),
|
||||
noise_cond=sc.CosineTNoiseCond(),
|
||||
loss_weight=None,
|
||||
)
|
||||
gdf_b = sc.GDF(
|
||||
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
|
||||
input_scaler=sc.VPScaler(),
|
||||
target=sc.EpsilonTarget(),
|
||||
noise_cond=sc.CosineTNoiseCond(),
|
||||
loss_weight=None,
|
||||
)
|
||||
|
||||
# Stage C Parameters
|
||||
|
||||
# extras.sampling_configs["cfg"] = 4
|
||||
# extras.sampling_configs["shift"] = 2
|
||||
# extras.sampling_configs["timesteps"] = 20
|
||||
# extras.sampling_configs["t_start"] = 1.0
|
||||
|
||||
# # Stage B Parameters
|
||||
# extras_b.sampling_configs["cfg"] = 1.1
|
||||
# extras_b.sampling_configs["shift"] = 1
|
||||
# extras_b.sampling_configs["timesteps"] = 10
|
||||
# extras_b.sampling_configs["t_start"] = 1.0
|
||||
b_cfg = 1.1
|
||||
b_shift = 1
|
||||
b_timesteps = 10
|
||||
b_t_start = 1.0
|
||||
|
||||
# caption = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee"
|
||||
# height, width = 1024, 1024
|
||||
|
||||
while True:
|
||||
print("type caption:")
|
||||
# if Ctrl+Z is pressed, it will raise EOFError
|
||||
try:
|
||||
caption = input()
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
caption = caption.strip()
|
||||
if caption == "":
|
||||
continue
|
||||
|
||||
# parse options: '--w' and '--h' for size, '--l' for cfg, '--s' for timesteps, '--f' for shift. if not specified, use default values
|
||||
# e.g. "caption --w 4 --h 4 --l 20 --s 20 --f 1.0"
|
||||
|
||||
tokens = caption.split()
|
||||
width = height = 1024
|
||||
cfg = 4
|
||||
timesteps = 20
|
||||
shift = 2
|
||||
t_start = 1.0
|
||||
negative_prompt = ""
|
||||
seed = None
|
||||
|
||||
caption_tokens = []
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
if i == len(tokens) - 1:
|
||||
caption_tokens.append(token)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if token == "--w":
|
||||
width = int(tokens[i + 1])
|
||||
elif token == "--h":
|
||||
height = int(tokens[i + 1])
|
||||
elif token == "--l":
|
||||
cfg = float(tokens[i + 1])
|
||||
elif token == "--s":
|
||||
timesteps = int(tokens[i + 1])
|
||||
elif token == "--f":
|
||||
shift = float(tokens[i + 1])
|
||||
elif token == "--t":
|
||||
t_start = float(tokens[i + 1])
|
||||
elif token == "--n":
|
||||
negative_prompt = tokens[i + 1]
|
||||
elif token == "--d":
|
||||
seed = int(tokens[i + 1])
|
||||
else:
|
||||
caption_tokens.append(token)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
i += 2
|
||||
|
||||
caption = " ".join(caption_tokens)
|
||||
|
||||
stage_c_latent_shape, stage_b_latent_shape = sc_utils.calculate_latent_sizes(height, width, batch_size=1)
|
||||
|
||||
# PREPARE CONDITIONS
|
||||
# cond_text, cond_pooled = sc.get_clip_conditions([caption], None, tokenizer, text_model)
|
||||
input_ids = tokenizer(
|
||||
[caption], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
|
||||
)["input_ids"].to(text_model.device)
|
||||
cond_text, cond_pooled = train_util.get_hidden_states_stable_cascade(
|
||||
tokenizer.model_max_length, input_ids, tokenizer, text_model
|
||||
)
|
||||
cond_text = cond_text.to(device, dtype=dtype)
|
||||
cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype)
|
||||
|
||||
# uncond_text, uncond_pooled = sc.get_clip_conditions([""], None, tokenizer, text_model)
|
||||
input_ids = tokenizer(
|
||||
[negative_prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
|
||||
)["input_ids"].to(text_model.device)
|
||||
uncond_text, uncond_pooled = train_util.get_hidden_states_stable_cascade(
|
||||
tokenizer.model_max_length, input_ids, tokenizer, text_model
|
||||
)
|
||||
uncond_text = uncond_text.to(device, dtype=dtype)
|
||||
uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype)
|
||||
|
||||
zero_img_emb = torch.zeros(1, 768, device=device)
|
||||
|
||||
# 辞書にしたくないけど GDF から先の変更が面倒だからとりあえず辞書にしておく
|
||||
conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
|
||||
unconditions = {
|
||||
"clip_text_pooled": uncond_pooled,
|
||||
"clip": uncond_pooled,
|
||||
"clip_text": uncond_text,
|
||||
"clip_img": zero_img_emb,
|
||||
}
|
||||
conditions_b = {}
|
||||
conditions_b.update(conditions)
|
||||
unconditions_b = {}
|
||||
unconditions_b.update(unconditions)
|
||||
|
||||
# seed everything
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
# torch.backends.cudnn.benchmark = False
|
||||
|
||||
if args.lowvram:
|
||||
generator_c = generator_c.to(device)
|
||||
|
||||
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
|
||||
sampling_c = gdf_c.sample(
|
||||
generator_c,
|
||||
conditions,
|
||||
stage_c_latent_shape,
|
||||
unconditions,
|
||||
device=device,
|
||||
cfg=cfg,
|
||||
shift=shift,
|
||||
timesteps=timesteps,
|
||||
t_start=t_start,
|
||||
)
|
||||
for sampled_c, _, _ in tqdm(sampling_c, total=timesteps):
|
||||
sampled_c = sampled_c
|
||||
|
||||
conditions_b["effnet"] = sampled_c
|
||||
unconditions_b["effnet"] = torch.zeros_like(sampled_c)
|
||||
|
||||
if previewer is not None:
|
||||
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
|
||||
preview = previewer(sampled_c)
|
||||
preview = preview.clamp(0, 1)
|
||||
preview = preview.permute(0, 2, 3, 1).squeeze(0)
|
||||
preview = preview.detach().float().cpu().numpy()
|
||||
preview = Image.fromarray((preview * 255).astype(np.uint8))
|
||||
|
||||
timestamp_str = time.strftime("%Y%m%d_%H%M%S")
|
||||
os.makedirs(args.outdir, exist_ok=True)
|
||||
preview.save(os.path.join(args.outdir, f"preview_{timestamp_str}.png"))
|
||||
|
||||
if args.lowvram:
|
||||
generator_c = generator_c.to(loading_device)
|
||||
device_utils.clean_memory_on_device(device)
|
||||
generator_b = generator_b.to(device)
|
||||
|
||||
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
|
||||
sampling_b = gdf_b.sample(
|
||||
generator_b,
|
||||
conditions_b,
|
||||
stage_b_latent_shape,
|
||||
unconditions_b,
|
||||
device=device,
|
||||
cfg=b_cfg,
|
||||
shift=b_shift,
|
||||
timesteps=b_timesteps,
|
||||
t_start=b_t_start,
|
||||
)
|
||||
for sampled_b, _, _ in tqdm(sampling_b, total=b_t_start):
|
||||
sampled_b = sampled_b
|
||||
|
||||
if args.lowvram:
|
||||
generator_b = generator_b.to(loading_device)
|
||||
device_utils.clean_memory_on_device(device)
|
||||
stage_a = stage_a.to(device)
|
||||
|
||||
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
|
||||
sampled = stage_a.decode(sampled_b).float()
|
||||
# print(sampled.shape, sampled.min(), sampled.max())
|
||||
|
||||
if args.lowvram:
|
||||
stage_a = stage_a.to(loading_device)
|
||||
device_utils.clean_memory_on_device(device)
|
||||
|
||||
# float 0-1 to PIL Image
|
||||
sampled = sampled.clamp(0, 1)
|
||||
sampled = sampled.mul(255).to(dtype=torch.uint8)
|
||||
sampled = sampled.permute(0, 2, 3, 1)
|
||||
sampled = sampled.cpu().numpy()
|
||||
sampled = Image.fromarray(sampled[0])
|
||||
|
||||
timestamp_str = time.strftime("%Y%m%d_%H%M%S")
|
||||
os.makedirs(args.outdir, exist_ok=True)
|
||||
sampled.save(os.path.join(args.outdir, f"sampled_{timestamp_str}.png"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
sc_utils.add_effnet_arguments(parser)
|
||||
train_util.add_tokenizer_arguments(parser)
|
||||
sc_utils.add_stage_a_arguments(parser)
|
||||
sc_utils.add_stage_b_arguments(parser)
|
||||
sc_utils.add_stage_c_arguments(parser)
|
||||
sc_utils.add_previewer_arguments(parser)
|
||||
sc_utils.add_text_model_arguments(parser)
|
||||
parser.add_argument("--bf16", action="store_true")
|
||||
parser.add_argument("--fp16", action="store_true")
|
||||
parser.add_argument("--xformers", action="store_true")
|
||||
parser.add_argument("--sdpa", action="store_true")
|
||||
parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先")
|
||||
parser.add_argument("--lowvram", action="store_true", help="if specified, use low VRAM mode")
|
||||
parser.add_argument(
|
||||
"--network_module",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="additional network module to use / 追加ネットワークを使う時そのモジュール名",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_args",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="additional arguments for network (key=value) / ネットワークへの追加の引数",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
1091
stable_cascade_train_c_network.py
Normal file
1091
stable_cascade_train_c_network.py
Normal file
File diff suppressed because it is too large
Load Diff
564
stable_cascade_train_stage_c.py
Normal file
564
stable_cascade_train_stage_c.py
Normal file
@@ -0,0 +1,564 @@
|
||||
# training with captions
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
from typing import List
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
from library.sdxl_train_util import add_sdxl_training_arguments
|
||||
import library.stable_cascade_utils as sc_utils
|
||||
import library.stable_cascade as sc
|
||||
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
|
||||
|
||||
def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
setup_logging(args, reset=True)
|
||||
|
||||
# assert (
|
||||
# not args.weighted_captions
|
||||
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
|
||||
|
||||
# TODO add assertions for other unsupported options
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_dreambooth_method = args.in_json is None
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed) # 乱数系列を初期化する
|
||||
|
||||
tokenizer = sc_utils.load_tokenizer(args)
|
||||
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
logger.info("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
||||
args.train_data_dir, args.reg_data_dir
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
logger.info("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer])
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer])
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group, True)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
logger.error(
|
||||
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
||||
)
|
||||
return
|
||||
|
||||
if cache_latents:
|
||||
assert (
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert (
|
||||
train_dataset_group.is_text_encoder_output_cacheable()
|
||||
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
effnet_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||
|
||||
# モデルを読み込む
|
||||
loading_device = accelerator.device if args.lowram else "cpu"
|
||||
effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device)
|
||||
stage_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, device=loading_device) # dtype is as it is
|
||||
text_encoder1 = sc_utils.load_clip_text_model(args.text_model_checkpoint_path, dtype=weight_dtype, device=loading_device)
|
||||
|
||||
if args.sample_at_first or args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
|
||||
# Previewer is small enough to be loaded on CPU
|
||||
previewer = sc_utils.load_previewer_model(args.previewer_checkpoint_path, dtype=torch.float32, device="cpu")
|
||||
previewer.eval()
|
||||
else:
|
||||
previewer = None
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
stage_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
effnet.to(accelerator.device, dtype=effnet_dtype)
|
||||
effnet.requires_grad_(False)
|
||||
effnet.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(
|
||||
effnet,
|
||||
args.vae_batch_size,
|
||||
args.cache_latents_to_disk,
|
||||
accelerator.is_main_process,
|
||||
train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX,
|
||||
32,
|
||||
)
|
||||
effnet.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 学習を準備する:モデルを適切な状態にする
|
||||
if args.gradient_checkpointing:
|
||||
accelerator.print("enable gradient checkpointing")
|
||||
stage_c.set_gradient_checkpointing(True)
|
||||
|
||||
train_stage_c = args.learning_rate > 0
|
||||
train_text_encoder1 = False
|
||||
|
||||
if args.train_text_encoder:
|
||||
accelerator.print("enable text encoder training")
|
||||
if args.gradient_checkpointing:
|
||||
text_encoder1.gradient_checkpointing_enable()
|
||||
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
|
||||
train_text_encoder1 = lr_te1 > 0
|
||||
assert (
|
||||
train_text_encoder1
|
||||
), "text_encoder1 learning rate is 0. Please set a positive value / text_encoder1の学習率が0です。正の値を設定してください。"
|
||||
|
||||
if not train_text_encoder1:
|
||||
text_encoder1.to(weight_dtype)
|
||||
text_encoder1.requires_grad_(train_text_encoder1)
|
||||
text_encoder1.train(train_text_encoder1)
|
||||
else:
|
||||
text_encoder1.to(weight_dtype)
|
||||
text_encoder1.requires_grad_(False)
|
||||
text_encoder1.eval()
|
||||
|
||||
# TextEncoderの出力をキャッシュする
|
||||
if args.cache_text_encoder_outputs:
|
||||
# Text Encodes are eval and no grad
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
train_dataset_group.cache_text_encoder_outputs(
|
||||
(tokenizer,),
|
||||
(text_encoder1,),
|
||||
accelerator.device,
|
||||
None,
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
accelerator.is_main_process,
|
||||
sc_utils.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if not cache_latents:
|
||||
effnet.requires_grad_(False)
|
||||
effnet.eval()
|
||||
effnet.to(accelerator.device, dtype=effnet_dtype)
|
||||
|
||||
stage_c.requires_grad_(True)
|
||||
if not train_stage_c:
|
||||
stage_c.to(accelerator.device, dtype=weight_dtype) # because of stage_c will not be prepared
|
||||
|
||||
training_models = []
|
||||
params_to_optimize = []
|
||||
if train_stage_c:
|
||||
training_models.append(stage_c)
|
||||
params_to_optimize.append({"params": list(stage_c.parameters()), "lr": args.learning_rate})
|
||||
|
||||
if train_text_encoder1:
|
||||
training_models.append(text_encoder1)
|
||||
params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
|
||||
|
||||
# calculate number of trainable parameters
|
||||
n_params = 0
|
||||
for params in params_to_optimize:
|
||||
for p in params["params"]:
|
||||
n_params += p.numel()
|
||||
|
||||
accelerator.print(f"train stage-C: {train_stage_c}, text_encoder1: {train_text_encoder1}")
|
||||
accelerator.print(f"number of models: {len(training_models)}")
|
||||
accelerator.print(f"number of trainable parameters: {n_params}")
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
accelerator.print(
|
||||
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
||||
)
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
||||
if args.full_fp16:
|
||||
assert (
|
||||
args.mixed_precision == "fp16"
|
||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
accelerator.print("enable full fp16 training.")
|
||||
stage_c.to(weight_dtype)
|
||||
text_encoder1.to(weight_dtype)
|
||||
elif args.full_bf16:
|
||||
assert (
|
||||
args.mixed_precision == "bf16"
|
||||
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
||||
accelerator.print("enable full bf16 training.")
|
||||
stage_c.to(weight_dtype)
|
||||
text_encoder1.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if train_stage_c:
|
||||
stage_c = accelerator.prepare(stage_c)
|
||||
if train_text_encoder1:
|
||||
text_encoder1 = accelerator.prepare(text_encoder1)
|
||||
|
||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
||||
if args.cache_text_encoder_outputs:
|
||||
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||
text_encoder1.to("cpu", dtype=torch.float32)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
else:
|
||||
# make sure Text Encoders are on GPU
|
||||
text_encoder1.to(accelerator.device)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
# resumeする
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
# 学習する
|
||||
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
accelerator.print("running training / 学習開始")
|
||||
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
accelerator.print(
|
||||
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
||||
)
|
||||
# accelerator.print(
|
||||
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
||||
# )
|
||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
# 謎のクラス GDF
|
||||
gdf = sc.GDF(
|
||||
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
|
||||
input_scaler=sc.VPScaler(),
|
||||
target=sc.EpsilonTarget(),
|
||||
noise_cond=sc.CosineTNoiseCond(),
|
||||
loss_weight=sc.AdaptiveLossWeight() if args.adaptive_loss_weight else sc.P2LossWeight(),
|
||||
)
|
||||
|
||||
# 以下2つの変数は、どうもデフォルトのままっぽい
|
||||
# gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
|
||||
# gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
|
||||
# For --sample_at_first
|
||||
sc_utils.sample_images(accelerator, args, 0, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
for m in training_models:
|
||||
m.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(*training_models):
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
# XXX Effnet preprocessing is included in encode method
|
||||
latents = effnet.encode(batch["images"].to(effnet_dtype)).latent_dist.sample().to(weight_dtype)
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
|
||||
# # debug: decode latent with previewer and save it
|
||||
# import time
|
||||
# import numpy as np
|
||||
# from PIL import Image
|
||||
# ts = time.time()
|
||||
# images = previewer(latents.to(previewer.device, dtype=previewer.dtype))
|
||||
# for i, img in enumerate(images):
|
||||
# img = img.detach().cpu().numpy().transpose(1, 2, 0)
|
||||
# img = np.clip(img, 0, 1)
|
||||
# img = (img * 255).astype(np.uint8)
|
||||
# img = Image.fromarray(img)
|
||||
# img.save(f"logs/previewer_{i}_{ts}.png")
|
||||
|
||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||
input_ids1 = batch["input_ids"]
|
||||
with torch.set_grad_enabled(args.train_text_encoder):
|
||||
# Get the text embedding for conditioning
|
||||
# TODO support weighted captions
|
||||
input_ids1 = input_ids1.to(accelerator.device)
|
||||
# unwrap_model is fine for models not wrapped by accelerator
|
||||
encoder_hidden_states, pool = train_util.get_hidden_states_stable_cascade(
|
||||
args.max_token_length,
|
||||
input_ids1,
|
||||
tokenizer,
|
||||
text_encoder1,
|
||||
None if not args.full_fp16 else weight_dtype,
|
||||
accelerator,
|
||||
)
|
||||
else:
|
||||
encoder_hidden_states = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
||||
pool = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
|
||||
|
||||
pool = pool.unsqueeze(1) # add extra dimension b,1280 -> b,1,1280
|
||||
|
||||
# FORWARD PASS
|
||||
with torch.no_grad():
|
||||
noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(latents, shift=1, loss_shift=1)
|
||||
|
||||
zero_img_emb = torch.zeros(noised.shape[0], 768, device=accelerator.device)
|
||||
with accelerator.autocast():
|
||||
pred = stage_c(
|
||||
noised, noise_cond, clip_text=encoder_hidden_states, clip_text_pooled=pool, clip_img=zero_img_emb
|
||||
)
|
||||
loss = torch.nn.functional.mse_loss(pred, target, reduction="none").mean(dim=[1, 2, 3])
|
||||
loss_adjusted = (loss * loss_weight).mean()
|
||||
|
||||
if args.adaptive_loss_weight:
|
||||
gdf.loss_weight.update_buckets(logSNR, loss) # use loss instead of loss_adjusted
|
||||
|
||||
accelerator.backward(loss_adjusted)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = []
|
||||
for m in training_models:
|
||||
params_to_clip.extend(m.parameters())
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
sc_utils.sample_images(accelerator, args, None, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
sc_utils.save_stage_c_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
False,
|
||||
accelerator,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(stage_c),
|
||||
accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
|
||||
)
|
||||
|
||||
current_loss = loss_adjusted.detach().item() # 平均なのでbatch sizeは関係ないはず
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss}
|
||||
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
|
||||
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
if accelerator.is_main_process:
|
||||
sc_utils.save_stage_c_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
True,
|
||||
accelerator,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(stage_c),
|
||||
accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
|
||||
)
|
||||
|
||||
sc_utils.sample_images(accelerator, args, epoch + 1, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
# if is_main_process:
|
||||
stage_c = accelerator.unwrap_model(stage_c)
|
||||
text_encoder1 = accelerator.unwrap_model(text_encoder1)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state: # and is_main_process:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
if is_main_process:
|
||||
sc_utils.save_stage_c_model_on_end(
|
||||
args, save_dtype, epoch, global_step, stage_c, text_encoder1 if train_text_encoder1 else None
|
||||
)
|
||||
logger.info("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
add_logging_arguments(parser)
|
||||
sc_utils.add_effnet_arguments(parser)
|
||||
sc_utils.add_stage_c_arguments(parser)
|
||||
sc_utils.add_text_model_arguments(parser)
|
||||
sc_utils.add_previewer_arguments(parser)
|
||||
sc_utils.add_training_arguments(parser)
|
||||
train_util.add_tokenizer_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
add_sdxl_training_arguments(parser) # cache text encoder outputs
|
||||
|
||||
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||
parser.add_argument(
|
||||
"--learning_rate_te1",
|
||||
type=float,
|
||||
default=None,
|
||||
help="learning rate for text encoder / text encoderの学習率",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_half_vae",
|
||||
action="store_true",
|
||||
help="do not use fp16/bf16 Effnet in mixed precision (use float Effnet) / mixed precisionでも fp16/bf16 Effnetを使わずfloat Effnetを使う",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
191
tools/stable_cascade_cache_latents.py
Normal file
191
tools/stable_cascade_cache_latents.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# Stable Cascadeのlatentsをdiskにキャッシュする
|
||||
# cache latents of Stable Cascade to disk
|
||||
|
||||
import argparse
|
||||
import math
|
||||
from multiprocessing import Value
|
||||
import os
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from library import stable_cascade_utils as sc_utils
|
||||
from library import config_util
|
||||
from library import train_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
# check cache latents arg
|
||||
assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
|
||||
|
||||
use_dreambooth_method = args.in_json is None
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed) # 乱数系列を初期化する
|
||||
|
||||
# tokenizerを準備する:datasetを動かすために必要
|
||||
tokenizer = sc_utils.load_tokenizer(args)
|
||||
tokenizers = [tokenizer]
|
||||
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
logger.info("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
||||
args.train_data_dir, args.reg_data_dir
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
logger.info("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
|
||||
|
||||
# datasetのcache_latentsを呼ばなければ、生の画像が返る
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, _ = train_util.prepare_dtype(args)
|
||||
effnet_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||
|
||||
# モデルを読み込む
|
||||
logger.info("load model")
|
||||
effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, accelerator.device)
|
||||
effnet.to(accelerator.device, dtype=effnet_dtype)
|
||||
effnet.requires_grad_(False)
|
||||
effnet.eval()
|
||||
|
||||
# dataloaderを準備する
|
||||
train_dataset_group.set_caching_mode("latents")
|
||||
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず
|
||||
train_dataloader = accelerator.prepare(train_dataloader)
|
||||
|
||||
# データ取得のためのループ
|
||||
for batch in tqdm(train_dataloader):
|
||||
b_size = len(batch["images"])
|
||||
vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
|
||||
flip_aug = batch["flip_aug"]
|
||||
random_crop = batch["random_crop"]
|
||||
bucket_reso = batch["bucket_reso"]
|
||||
|
||||
# バッチを分割して処理する
|
||||
for i in range(0, b_size, vae_batch_size):
|
||||
images = batch["images"][i : i + vae_batch_size]
|
||||
absolute_paths = batch["absolute_paths"][i : i + vae_batch_size]
|
||||
resized_sizes = batch["resized_sizes"][i : i + vae_batch_size]
|
||||
|
||||
image_infos = []
|
||||
for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)):
|
||||
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
|
||||
image_info.image = image
|
||||
image_info.bucket_reso = bucket_reso
|
||||
image_info.resized_size = resized_size
|
||||
image_info.latents_npz = os.path.splitext(absolute_path)[0] + train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX
|
||||
|
||||
if args.skip_existing:
|
||||
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug, 32):
|
||||
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
|
||||
continue
|
||||
|
||||
image_infos.append(image_info)
|
||||
|
||||
if len(image_infos) > 0:
|
||||
train_util.cache_batch_latents(effnet, True, image_infos, flip_aug, random_crop)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_tokenizer_arguments(parser)
|
||||
sc_utils.add_effnet_arguments(parser)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
config_util.add_config_arguments(parser)
|
||||
parser.add_argument(
|
||||
"--no_half_vae",
|
||||
action="store_true",
|
||||
help="do not use fp16/bf16 Effnet in mixed precision (use float Effnet) / mixed precisionでも fp16/bf16 Effnetを使わずfloat Effnetを使う",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_existing",
|
||||
action="store_true",
|
||||
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
cache_to_disk(args)
|
||||
183
tools/stable_cascade_cache_text_encoder_outputs.py
Normal file
183
tools/stable_cascade_cache_text_encoder_outputs.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance
|
||||
|
||||
import argparse
|
||||
import math
|
||||
from multiprocessing import Value
|
||||
import os
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from library import config_util
|
||||
from library import train_util
|
||||
from library import sdxl_train_util
|
||||
from library import stable_cascade_utils as sc_utils
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
# check cache arg
|
||||
assert (
|
||||
args.cache_text_encoder_outputs_to_disk
|
||||
), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります"
|
||||
|
||||
use_dreambooth_method = args.in_json is None
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed) # 乱数系列を初期化する
|
||||
|
||||
# tokenizerを準備する:datasetを動かすために必要
|
||||
tokenizer = sc_utils.load_tokenizer(args)
|
||||
tokenizers = [tokenizer]
|
||||
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
logger.info("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
||||
args.train_data_dir, args.reg_data_dir
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
logger.info("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, _ = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
logger.info("load model")
|
||||
text_encoder = sc_utils.load_clip_text_model(
|
||||
args.text_model_checkpoint_path, weight_dtype, accelerator.device, args.save_text_model
|
||||
)
|
||||
text_encoders = [text_encoder]
|
||||
for text_encoder in text_encoders:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.requires_grad_(False)
|
||||
text_encoder.eval()
|
||||
|
||||
# dataloaderを準備する
|
||||
train_dataset_group.set_caching_mode("text")
|
||||
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず
|
||||
train_dataloader = accelerator.prepare(train_dataloader)
|
||||
|
||||
# データ取得のためのループ
|
||||
for batch in tqdm(train_dataloader):
|
||||
absolute_paths = batch["absolute_paths"]
|
||||
input_ids1_list = batch["input_ids1_list"]
|
||||
|
||||
image_infos = []
|
||||
for absolute_path, input_ids1 in zip(absolute_paths, input_ids1_list):
|
||||
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
|
||||
image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + sc_utils.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
|
||||
image_info
|
||||
|
||||
if args.skip_existing:
|
||||
if os.path.exists(image_info.text_encoder_outputs_npz):
|
||||
logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
|
||||
continue
|
||||
|
||||
image_info.input_ids1 = input_ids1
|
||||
image_infos.append(image_info)
|
||||
|
||||
if len(image_infos) > 0:
|
||||
b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
|
||||
train_util.cache_batch_text_encoder_outputs(
|
||||
image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, None, weight_dtype
|
||||
)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_tokenizer_arguments(parser)
|
||||
sc_utils.add_text_model_arguments(parser)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
config_util.add_config_arguments(parser)
|
||||
sdxl_train_util.add_sdxl_training_arguments(parser)
|
||||
parser.add_argument(
|
||||
"--skip_existing",
|
||||
action="store_true",
|
||||
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
cache_to_disk(args)
|
||||
Reference in New Issue
Block a user