Compare commits


27 Commits

Author SHA1 Message Date
Kohya S
235a1ea2c6 Merge branch 'dev' into stable-cascade 2024-02-25 20:03:39 +09:00
Kohya S
cb648a2bf8 update readme 2024-02-25 20:03:00 +09:00
Kohya S
3a2a48c15d make LoRA compatible with ComfyUI #1119 2024-02-25 20:01:37 +09:00
Kohya S
e0acb10f31 Merge pull request #1137 from shirayu/replace_print_with_logger
Replaced print with logger
2024-02-25 10:34:19 +09:00
Kohya S
40f2c688db fix stage c weight is loaded in bf16/fp16 #1119 2024-02-25 09:39:53 +09:00
Kohya S
e4f8736c60 Merge branch 'dev' into stable-cascade 2024-02-25 08:58:27 +09:00
Yuta Hayashibe
5d5f39b6e6 Replaced print with logger 2024-02-25 01:24:11 +09:00
Kohya S
13f49d1e4a update readme 2024-02-22 23:50:10 +09:00
Kohya S
df7648245e update readme 2024-02-22 23:41:46 +09:00
Kohya S
3368fb1af7 Modify nn.MHA to attn with q/k/v 2024-02-22 23:39:28 +09:00
Kohya S
417f14d245 Merge pull request #1130 from sdbds/fixbugs
[stable-cascade]add save parser and fix lora scripts model name and hash
2024-02-22 12:30:59 +09:00
青龍聖者@bdsqlsz
86503cb945 add save parser and fix lora scripts model name and hash 2024-02-21 19:38:12 +08:00
Kohya S
d91b1d3793 update readme 2024-02-20 22:39:57 +09:00
Kohya S
70917077a6 update readme 2024-02-20 22:38:36 +09:00
Kohya S
69dbc50912 fix effnet encoder preprocess issue 2024-02-20 22:34:06 +09:00
Kohya S
985761ca43 fix to work without network module 2024-02-20 20:33:03 +09:00
Kohya S
71e03559e2 support LoRA training for Stable Cascade Stage C 2024-02-20 08:27:11 +09:00
Kohya S
806a6237fb minor fixes 2024-02-18 21:57:16 +09:00
Kohya S
9b0e532942 add command line sample 2024-02-18 21:40:36 +09:00
Kohya S
c26f01241f input prompt from console 2024-02-18 21:29:46 +09:00
Kohya S
ac71168939 add train_text_encoder arg 2024-02-18 21:29:10 +09:00
Kohya S
4e37d950d2 fix typos 2024-02-18 18:02:20 +09:00
Kohya S
4b5784eb44 update stable cascade stage C training #1119 2024-02-18 17:54:21 +09:00
Kohya S
856df07f49 Merge branch 'dev' into stable-cascade 2024-02-18 09:15:12 +09:00
Kohya S
80ef59c115 support text encoder training in stable cascade 2024-02-18 09:12:37 +09:00
Kohya S
319bbf8057 add stage c tmp training code 2024-02-17 23:59:20 +09:00
Kohya S
fa440208b7 add inference script 2024-02-17 17:57:30 +09:00
17 changed files with 5266 additions and 221 deletions

README.md

@@ -1,3 +1,174 @@
# Training Stable Cascade Stage C
This is an experimental feature. There may be bugs.
__Feb 25, 2024 Update:__ Fixed a bug that prevented trained LoRA weights from being loaded in ComfyUI. If you still have a problem, please let me know.
__Feb 25, 2024 Update:__ Fixed a bug where Stage C training with mixed precision behaved as if `--full_bf16` (or `--full_fp16`) had been specified, whether or not it actually was.
This was because the Stage C weights were loaded in bf16/fp16. With this fix, memory usage without `--full_bf16` (fp16) will increase, so specify `--full_bf16` (fp16) as needed.
__Feb 22, 2024 Update:__ Fixed a bug where LoRA was not applied to some Attention modules (to_q/k/v and to_out). The model structure of Stage C has also been changed so that you can choose between xformers and SDPA (previously SDPA was always used). Please specify the `--sdpa` or `--xformers` option.
__Feb 20, 2024 Update:__ There was a problem with the preprocessing of the EfficientNetEncoder, and the latents became invalid (the saturation of the training results decreased). If you have `_sc_latents.npz` files cached with `--cache_latents_to_disk`, please delete them before training.
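If you need to clear the stale caches, something like the following should work on Linux/macOS (a minimal sketch; `../train_data` is a placeholder for your dataset directory):
```batch
find ../train_data -name "*_sc_latents.npz" -delete
```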
## Usage
Training is run with `stable_cascade_train_stage_c.py`.
The main options are the same as `sdxl_train.py`. The following options have been added.
- `--effnet_checkpoint_path`: Specifies the path to the EfficientNetEncoder weights.
- `--stage_c_checkpoint_path`: Specifies the path to the Stage C weights.
- `--text_model_checkpoint_path`: Specifies the path to the Text Encoder weights. If omitted, the model from Hugging Face will be used.
- `--save_text_model`: Saves the model downloaded from Hugging Face to `--text_model_checkpoint_path`.
- `--previewer_checkpoint_path`: Specifies the path to the Previewer weights. Used to generate sample images during training.
- `--adaptive_loss_weight`: Uses [Adaptive Loss Weight](https://github.com/Stability-AI/StableCascade/blob/master/gdf/loss_weights.py). If omitted, P2LossWeight is used. The official settings use Adaptive Loss Weight.
The learning rate is set to 1e-4 in the official settings.
For the first run, specify `--text_model_checkpoint_path` and `--save_text_model` to save the Text Encoder weights. On subsequent runs, specify `--text_model_checkpoint_path` to load the saved weights.
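For example, the very first run could look like this (a hypothetical sketch; `../models/text_model.safetensors` is a placeholder path, and the remaining training options are elided with `...`):
```batch
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 stable_cascade_train_stage_c.py --text_model_checkpoint_path ../models/text_model.safetensors --save_text_model ...
```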
Sample image generation during training is done with the Previewer. The Previewer is a simple decoder that converts EfficientNetEncoder latents to images.
Some of the options for SDXL are simply ignored or cause an error (especially noise-related options such as `--noise_offset`). `--vae_batch_size` and `--no_half_vae` are applied directly to the EfficientNetEncoder (when `bf16` is specified for mixed precision, `--no_half_vae` is not necessary).
Options for latents and Text Encoder output caches can be used as is, but since the EfficientNetEncoder is much lighter than the VAE, you may not need to use the cache unless memory is particularly tight.
`--gradient_checkpointing`, `--full_bf16`, and `--full_fp16` (untested) can be used as is to reduce memory consumption.
A scale of about 4 is suitable for sample image generation.
Since the official settings use `bf16` for training, training with `fp16` may be unstable.
The code for training the Text Encoder is also written, but it is untested.
### Command line sample
```batch
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 stable_cascade_train_stage_c.py --mixed_precision bf16 --save_precision bf16 --max_data_loader_n_workers 2 --persistent_data_loader_workers --gradient_checkpointing --learning_rate 1e-4 --optimizer_type adafactor --optimizer_args "scale_parameter=False" "relative_step=False" "warmup_init=False" --max_train_epochs 10 --save_every_n_epochs 1 --output_dir ../output --output_name sc_test --stage_c_checkpoint_path ../models/stage_c_bf16.safetensors --effnet_checkpoint_path ../models/effnet_encoder.safetensors --previewer_checkpoint_path ../models/previewer.safetensors --dataset_config ../dataset/config_bs1.toml --sample_every_n_epochs 1 --sample_prompts ../dataset/prompts.txt --adaptive_loss_weight
```
### About the dataset for fine tuning
If latents cache files for SD/SDXL (extension `*.npz`) exist, they will be read and cause an error during training. Please move them to another location in advance.
After that, run `finetune/prepare_buckets_latents.py` with the `--stable_cascade` option to create latents cache files for Stable Cascade (suffix `_sc_latents.npz` is added).
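For example, an invocation might look like this (a hypothetical sketch; the positional argument order is assumed from the script's parser, paths and metadata files are placeholders, and for Stable Cascade the EfficientNetEncoder checkpoint is passed as `model_name_or_path`):
```batch
python finetune/prepare_buckets_latents.py ../train_data ../meta_clean.json ../meta_lat.json ../models/effnet_encoder.safetensors --stable_cascade --batch_size 4 --max_resolution 1024,1024 --mixed_precision bf16 --full_path
```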
## LoRA training
`stable_cascade_train_c_network.py` is used for LoRA training. The main options are the same as `train_network.py`, and the same options as `stable_cascade_train_stage_c.py` have been added.
__This is an experimental feature, so the format of the saved weights may change in the future and become incompatible.__
The weights are not compatible with the official LoRA, and the Text Encoder embedding training (Pivotal Tuning) of the official implementation is not implemented here.
Text Encoder LoRA training is implemented, but untested.
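A hypothetical training command, combining the usual `train_network.py` options with the Stage C options above (an untested sketch; `--network_dim 16` and all paths are illustrative):
```batch
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 stable_cascade_train_c_network.py --mixed_precision bf16 --save_precision bf16 --network_module networks.lora --network_dim 16 --learning_rate 1e-4 --optimizer_type adafactor --optimizer_args "scale_parameter=False" "relative_step=False" "warmup_init=False" --gradient_checkpointing --max_train_epochs 10 --save_every_n_epochs 1 --output_dir ../output --output_name sc_lora_test --stage_c_checkpoint_path ../models/stage_c_bf16.safetensors --effnet_checkpoint_path ../models/effnet_encoder.safetensors --previewer_checkpoint_path ../models/previewer.safetensors --dataset_config ../dataset/config_bs1.toml
```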
## Image generation
Basic image generation functionality is available in `stable_cascade_gen_img.py`. See `--help` for usage.
When using LoRA, specify `--network_module networks.lora --network_mul 1 --network_weights lora_weights.safetensors`.
The following prompt options are available.
* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
* `--t` Specifies the t_start of the generation.
* `--f` Specifies the shift of the generation.
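For example, the prompt options are embedded in the prompt string itself (a hypothetical invocation; the model-loading flags are elided as `[model options]`, and `--prompt` is assumed to work as in the other generation scripts):
```batch
python stable_cascade_gen_img.py [model options] --network_module networks.lora --network_mul 1 --network_weights lora_weights.safetensors --prompt "a portrait photo of a girl --n low quality, worst quality --w 1024 --h 1024 --d 1234 --l 4.0 --s 20"
```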
# Training Stable Cascade Stage C
This is an experimental feature. There may be bugs.
__Update Feb 25, 2024:__ Fixed so that trained LoRA weights can be loaded in ComfyUI. If you still have a problem, please let me know.
__Update Feb 25, 2024:__ Fixed a bug where Stage C training with mixed precision (apparently) behaved the same as when `--full_bf16` (fp16) was specified, regardless of whether it actually was.
This was because the Stage C weights were loaded in bf16/fp16. With this fix, memory usage without `--full_bf16` (fp16) increases, so specify `--full_bf16` (fp16) as needed.
__Update Feb 22, 2024:__ Fixed a bug where LoRA was not applied to some modules (to_q/k/v and to_out in Attention). The Stage C model structure has also been changed so that xformers and SDPA can be selected (SDPA was used previously). Specify the `--sdpa` or `--xformers` option.
__Update Feb 20, 2024:__ There was a problem with the preprocessing of the EfficientNetEncoder, and the latents were invalid (the saturation of training results decreased). If you have `_sc_latents.npz` files cached with `--cache_latents_to_disk`, delete them before training.
## Usage
Training is run with `stable_cascade_train_stage_c.py`.
The main options are the same as `sdxl_train.py`. The following options have been added.
- `--effnet_checkpoint_path`: Specifies the path to the EfficientNetEncoder weights.
- `--stage_c_checkpoint_path`: Specifies the path to the Stage C weights.
- `--text_model_checkpoint_path`: Specifies the path to the Text Encoder weights. If omitted, the model from Hugging Face is used.
- `--save_text_model`: Saves the model downloaded from Hugging Face to `--text_model_checkpoint_path`.
- `--previewer_checkpoint_path`: Specifies the path to the Previewer weights. Used to generate sample images during training.
- `--adaptive_loss_weight`: Uses [Adaptive Loss Weight](https://github.com/Stability-AI/StableCascade/blob/master/gdf/loss_weights.py). If omitted, P2LossWeight is used. The official settings appear to use Adaptive Loss Weight.
The learning rate appears to be 1e-4 in the official settings.
For the first run, specify `--text_model_checkpoint_path` and `--save_text_model` to save the Text Encoder weights. From then on, specify `--text_model_checkpoint_path` to load the saved weights.
Sample image generation during training uses the Previewer, a simple decoder that converts EfficientNetEncoder latents into images.
Some SDXL options are simply ignored or cause an error (especially noise-related options such as `--noise_offset`). `--vae_batch_size` and `--no_half_vae` are applied directly to the EfficientNetEncoder (`--no_half_vae` appears unnecessary when `bf16` is specified for mixed precision).
Options for caching latents and Text Encoder outputs can be used as is, but since the EfficientNetEncoder is much lighter than the VAE, caching may be unnecessary unless memory is particularly tight.
`--gradient_checkpointing`, `--full_bf16`, and `--full_fp16` (untested) can be used as is to reduce memory consumption.
A scale of about 4 seems suitable for sample image generation.
Since the official settings use `bf16` for training, training with `fp16` may be unstable.
The code for training the Text Encoder is also included, but it is untested.
### Command line sample
See [Command line sample](#command-line-sample).
### About the dataset for fine tuning
If latents cache files for SD/SDXL (extension `*.npz`) exist, they will be read and cause an error during training. Please move them to another location in advance.
Then run `finetune/prepare_buckets_latents.py` with the `--stable_cascade` option to create latents cache files for Stable Cascade (the suffix `_sc_latents.npz` is added).
## LoRA training
LoRA training is run with `stable_cascade_train_c_network.py`. The main options are the same as `train_network.py`, and the same options as for `stable_cascade_train_stage_c.py` have been added.
__This is an experimental feature, so the format of the saved weights may change in the future and lose compatibility.__
The weights are not compatible with the official LoRA, and the Text Encoder embedding training (Pivotal Tuning) of the official implementation is not implemented here.
Text Encoder LoRA training is implemented but untested.
## Image generation
Minimal image generation functionality is available in `stable_cascade_gen_img.py`. See `--help` for usage.
When using LoRA, specify it as `--network_module networks.lora --network_mul 1 --network_weights lora_weights.safetensors`.
The following prompt options are available.
* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
* `--t` Specifies the t_start of the generation.
* `--f` Specifies the shift of the generation.
---
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run `accelerate config` again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
This repository contains training, generation and utility scripts for Stable Diffusion.

finetune/prepare_buckets_latents.py

@@ -11,15 +11,19 @@ import cv2
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from torchvision import transforms
import library.model_util as model_util
import library.stable_cascade_utils as sc_utils
import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
DEVICE = get_preferred_device()
@@ -42,7 +46,7 @@ def collate_fn_remove_corrupted(batch):
return batch
def get_npz_filename(data_dir, image_key, is_full_path, recursive):
def get_npz_filename(data_dir, image_key, is_full_path, recursive, stable_cascade):
if is_full_path:
base_name = os.path.splitext(os.path.basename(image_key))[0]
relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
@@ -50,10 +54,11 @@ def get_npz_filename(data_dir, image_key, is_full_path, recursive):
base_name = image_key
relative_path = ""
ext = ".npz" if not stable_cascade else train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX
if recursive and relative_path:
return os.path.join(data_dir, relative_path, base_name) + ".npz"
return os.path.join(data_dir, relative_path, base_name) + ext
else:
return os.path.join(data_dir, base_name) + ".npz"
return os.path.join(data_dir, base_name) + ext
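# Resulting cache-file names (illustrative, assuming train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX
# is "_sc_latents.npz" as the README indicates): "data/img001.npz" for SD/SDXL,
# "data/img001_sc_latents.npz" when --stable_cascade is given.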
def main(args):
@@ -83,13 +88,20 @@ def main(args):
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
if not args.stable_cascade:
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
divisor = 8
else:
vae = sc_utils.load_effnet(args.model_name_or_path, DEVICE)
divisor = 32
vae.eval()
vae.to(DEVICE, dtype=weight_dtype)
# bucketのサイズを計算する
max_reso = tuple([int(t) for t in args.max_resolution.split(",")])
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
assert (
len(max_reso) == 2
), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
bucket_manager = train_util.BucketManager(
args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps
@@ -154,6 +166,10 @@ def main(args):
# メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
metadata[image_key]["train_resolution"] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
# 追加情報を記録
metadata[image_key]["original_size"] = (image.width, image.height)
metadata[image_key]["train_resized_size"] = resized_size
if not args.bucket_no_upscale:
# upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
assert (
@@ -168,9 +184,9 @@ def main(args):
), f"internal error resized size is small: {resized_size}, {reso}"
# 既に存在するファイルがあればshape等を確認して同じならskipする
npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive)
npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive, args.stable_cascade)
if args.skip_existing:
if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug):
if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug, divisor):
continue
# バッチへ追加
@@ -208,7 +224,14 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)")
parser.add_argument(
"--stable_cascade",
action="store_true",
help="prepare EffNet latents for stable cascade / stable cascade用のEffNetのlatentsを準備する",
)
parser.add_argument(
"--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
@@ -231,10 +254,16 @@ def setup_parser() -> argparse.ArgumentParser:
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
)
parser.add_argument(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
"--bucket_no_upscale",
action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
)
parser.add_argument(
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help="use mixed precision / 混合精度を使う場合、その精度",
)
parser.add_argument(
"--full_path",
@@ -242,7 +271,9 @@ def setup_parser() -> argparse.ArgumentParser:
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
)
parser.add_argument(
"--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
"--flip_aug",
action="store_true",
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する",
)
parser.add_argument(
"--skip_existing",

sdxl_gen_img.py

@@ -61,6 +61,12 @@ from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from library.original_unet import FlashAttentionFunction
from networks.control_net_lllite import ControlNetLLLite
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
# scheduler:
SCHEDULER_LINEAR_START = 0.00085
@@ -82,12 +88,12 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
print("Enable memory efficient attention for U-Net")
logger.info("Enable memory efficient attention for U-Net")
# これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い
unet.set_use_memory_efficient_attention(False, True)
elif xformers:
print("Enable xformers for U-Net")
logger.info("Enable xformers for U-Net")
try:
import xformers.ops
except ImportError:
@@ -95,7 +101,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
unet.set_use_memory_efficient_attention(True, False)
elif sdpa:
print("Enable SDPA for U-Net")
logger.info("Enable SDPA for U-Net")
unet.set_use_memory_efficient_attention(False, False)
unet.set_use_sdpa(True)
@@ -112,7 +118,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
def replace_vae_attn_to_memory_efficient():
print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
flash_func = FlashAttentionFunction
def forward_flash_attn(self, hidden_states, **kwargs):
@@ -168,7 +174,7 @@ def replace_vae_attn_to_memory_efficient():
def replace_vae_attn_to_xformers():
print("VAE: Attention.forward has been replaced to xformers")
logger.info("VAE: Attention.forward has been replaced to xformers")
import xformers.ops
def forward_xformers(self, hidden_states, **kwargs):
@@ -224,7 +230,7 @@ def replace_vae_attn_to_xformers():
def replace_vae_attn_to_sdpa():
print("VAE: Attention.forward has been replaced to sdpa")
logger.info("VAE: Attention.forward has been replaced to sdpa")
def forward_sdpa(self, hidden_states, **kwargs):
residual = hidden_states
@@ -386,10 +392,10 @@ class PipelineLike:
def set_gradual_latent(self, gradual_latent):
if gradual_latent is None:
print("gradual_latent is disabled")
logger.info("gradual_latent is disabled")
self.gradual_latent = None
else:
print(f"gradual_latent is enabled: {gradual_latent}")
logger.info(f"gradual_latent is enabled: {gradual_latent}")
self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step)
@torch.no_grad()
@@ -467,7 +473,7 @@ class PipelineLike:
do_classifier_free_guidance = guidance_scale > 1.0
if not do_classifier_free_guidance and negative_scale is not None:
print(f"negative_scale is ignored if guidance scalle <= 1.0")
logger.warning(f"negative_scale is ignored if guidance scalle <= 1.0")
negative_scale = None
# get unconditional embeddings for classifier free guidance
@@ -576,7 +582,7 @@ class PipelineLike:
text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt
if init_image is not None and self.clip_vision_model is not None:
print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device)
pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype)
@@ -742,8 +748,8 @@ class PipelineLike:
enable_gradual_latent = False
if self.gradual_latent:
if not hasattr(self.scheduler, "set_gradual_latent_params"):
print("gradual_latent is not supported for this scheduler. Ignoring.")
print(self.scheduler.__class__.__name__)
logger.warning("gradual_latent is not supported for this scheduler. Ignoring.")
logger.warning(f"{self.scheduler.__class__.__name__}")
else:
enable_gradual_latent = True
step_elapsed = 1000
@@ -792,7 +798,7 @@ class PipelineLike:
if not enabled or ratio >= 1.0:
continue
if ratio < i / len(timesteps):
print(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
control_net.set_cond_image(None)
each_control_net_enabled[j] = False
@@ -1013,7 +1019,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L
if word.strip() == "BREAK":
# pad until next multiple of tokenizer's max token length
pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length)
print(f"BREAK pad_len: {pad_len}")
logger.info(f"BREAK pad_len: {pad_len}")
for i in range(pad_len):
# v2のときEOSをつけるべきかどうかわからないぜ
# if i == 0:
@@ -1043,7 +1049,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L
tokens.append(text_token)
weights.append(text_weight)
if truncated:
print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
logger.warning("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
return tokens, weights
@@ -1344,7 +1350,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count):
elif len(count_range) == 2:
count_range = [int(count_range[0]), int(count_range[1])]
else:
print(f"invalid count range: {count_range}")
logger.warning(f"invalid count range: {count_range}")
count_range = [1, 1]
if count_range[0] > count_range[1]:
count_range = [count_range[1], count_range[0]]
@@ -1488,9 +1494,9 @@ def main(args):
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
if args.v_parameterization and not args.v2:
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
if args.v2 and args.clip_skip is not None:
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
# モデルを読み込む
if not os.path.exists(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う
@@ -1510,7 +1516,7 @@ def main(args):
else:
# if `text_encoder_2` subdirectory exists, sdxl
is_sdxl = os.path.isdir(os.path.join(name_or_path, "text_encoder_2"))
print(f"SDXL: {is_sdxl}")
logger.info(f"SDXL: {is_sdxl}")
if is_sdxl:
if args.clip_skip is None:
@@ -1526,10 +1532,10 @@ def main(args):
args.clip_skip = 2 if args.v2 else 1
if use_stable_diffusion_format:
print("load StableDiffusion checkpoint")
logger.info("load StableDiffusion checkpoint")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
else:
print("load Diffusers pretrained models")
logger.info("load Diffusers pretrained models")
loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
text_encoder = loading_pipe.text_encoder
vae = loading_pipe.vae
@@ -1553,7 +1559,7 @@ def main(args):
# VAEを読み込む
if args.vae is not None:
vae = model_util.load_vae(args.vae, dtype)
print("additional VAE loaded")
logger.info("additional VAE loaded")
# xformers、Hypernetwork対応
if not args.diffusers_xformers:
@@ -1562,7 +1568,7 @@ def main(args):
replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa)
# tokenizerを読み込む
print("loading tokenizer")
logger.info("loading tokenizer")
if is_sdxl:
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
tokenizers = [tokenizer1, tokenizer2]
@@ -1654,7 +1660,7 @@ def main(args):
noise = None
if noise == None:
print(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)
self.sampler_noise_index += 1
@@ -1715,7 +1721,7 @@ def main(args):
vae_dtype = dtype
if args.no_half_vae:
print("set vae_dtype to float32")
logger.info("set vae_dtype to float32")
vae_dtype = torch.float32
vae.to(vae_dtype).to(device)
vae.eval()
@@ -1739,10 +1745,10 @@ def main(args):
network_merge = args.network_merge_n_models
else:
network_merge = 0
print(f"network_merge: {network_merge}")
logger.info(f"network_merge: {network_merge}")
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
logger.info("import network module: {network_module}")
imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
@@ -1760,7 +1766,7 @@ def main(args):
raise ValueError("No weight. Weight is required.")
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
logger.info(f"load network weights from: {network_weight}")
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
@@ -1768,7 +1774,7 @@ def main(args):
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
logger.info(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoders, unet, for_inference=True, **net_kwargs
@@ -1778,20 +1784,20 @@ def main(args):
mergeable = network.is_mergeable()
if network_merge and not mergeable:
print("network is not mergiable. ignore merge option.")
logger.warning("network is not mergiable. ignore merge option.")
if not mergeable or i >= network_merge:
# not merging
network.apply_to(text_encoders, unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}")
logger.info(f"weights are loaded: {info}")
if args.opt_channels_last:
network.to(memory_format=torch.channels_last)
network.to(dtype).to(device)
if network_pre_calc:
print("backup original weights")
logger.info("backup original weights")
network.backup_weights()
networks.append(network)
@@ -1805,7 +1811,7 @@ def main(args):
# upscalerの指定があれば取得する
upscaler = None
if args.highres_fix_upscaler:
print("import upscaler module:", args.highres_fix_upscaler)
logger.info("import upscaler module: {args.highres_fix_upscaler}")
imported_module = importlib.import_module(args.highres_fix_upscaler)
us_kwargs = {}
@@ -1814,7 +1820,7 @@ def main(args):
key, value = net_arg.split("=")
us_kwargs[key] = value
print("create upscaler")
logger.info("create upscaler")
upscaler = imported_module.create_upscaler(**us_kwargs)
upscaler.to(dtype).to(device)
@@ -1833,7 +1839,7 @@ def main(args):
control_net_lllites: List[Tuple[ControlNetLLLite, float]] = []
if args.control_net_lllite_models:
for i, model_file in enumerate(args.control_net_lllite_models):
print(f"loading ControlNet-LLLite: {model_file}")
logger.info(f"loading ControlNet-LLLite: {model_file}")
from safetensors.torch import load_file
@@ -1867,7 +1873,7 @@ def main(args):
), "ControlNet and ControlNet-LLLite cannot be used at the same time"
if args.opt_channels_last:
print(f"set optimizing: channels last")
logger.info(f"set optimizing: channels last")
for text_encoder in text_encoders:
text_encoder.to(memory_format=torch.channels_last)
vae.to(memory_format=torch.channels_last)
@@ -1894,7 +1900,7 @@ def main(args):
)
pipe.set_control_nets(control_nets)
pipe.set_control_net_lllites(control_net_lllites)
print("pipeline is ready.")
logger.info("pipeline is ready.")
if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention()
@@ -1965,7 +1971,7 @@ def main(args):
token_ids1 = tokenizers[0].convert_tokens_to_ids(token_strings)
token_ids2 = tokenizers[1].convert_tokens_to_ids(token_strings) if is_sdxl else None
print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}")
logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}")
assert (
min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1
), f"token ids1 is not ordered"
@@ -2002,7 +2008,7 @@ def main(args):
# promptを取得する
prompt_list = None
if args.from_file is not None:
print(f"reading prompts from {args.from_file}")
logger.info(f"reading prompts from {args.from_file}")
with open(args.from_file, "r", encoding="utf-8") as f:
prompt_list = f.read().splitlines()
prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
@@ -2019,7 +2025,7 @@ def main(args):
spec.loader.exec_module(module)
return module
print(f"reading prompts from module: {args.from_module}")
logger.info(f"reading prompts from module: {args.from_module}")
prompt_module = load_module_from_path("prompt_module", args.from_module)
prompter = prompt_module.get_prompter(args, pipe, networks)
@@ -2050,7 +2056,7 @@ def main(args):
for p in paths:
image = Image.open(p)
if image.mode != "RGB":
print(f"convert image to RGB from {image.mode}: {p}")
logger.info(f"convert image to RGB from {image.mode}: {p}")
image = image.convert("RGB")
images.append(image)
@@ -2066,14 +2072,14 @@ def main(args):
return resized
if args.image_path is not None:
print(f"load image for img2img: {args.image_path}")
logger.info(f"load image for img2img: {args.image_path}")
init_images = load_images(args.image_path)
assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}"
print(f"loaded {len(init_images)} images for img2img")
logger.info(f"loaded {len(init_images)} images for img2img")
# CLIP Vision
if args.clip_vision_strength is not None:
print(f"load CLIP Vision model: {CLIP_VISION_MODEL}")
logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}")
vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280)
vision_model.to(device, dtype)
processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL)
@@ -2081,22 +2087,22 @@ def main(args):
pipe.clip_vision_model = vision_model
pipe.clip_vision_processor = processor
pipe.clip_vision_strength = args.clip_vision_strength
print(f"CLIP Vision model loaded.")
logger.info(f"CLIP Vision model loaded.")
else:
init_images = None
if args.mask_path is not None:
print(f"load mask for inpainting: {args.mask_path}")
logger.info(f"load mask for inpainting: {args.mask_path}")
mask_images = load_images(args.mask_path)
assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}"
print(f"loaded {len(mask_images)} mask images for inpainting")
logger.info(f"loaded {len(mask_images)} mask images for inpainting")
else:
mask_images = None
# promptがないとき、画像のPngInfoから取得する
if init_images is not None and prompter is None and not args.interactive:
print("get prompts from images' metadata")
logger.info("get prompts from images' metadata")
prompt_list = []
for img in init_images:
if "prompt" in img.text:
@@ -2127,17 +2133,17 @@ def main(args):
h = int(h * args.highres_fix_scale + 0.5)
if init_images is not None:
print(f"resize img2img source images to {w}*{h}")
logger.info(f"resize img2img source images to {w}*{h}")
init_images = resize_images(init_images, (w, h))
if mask_images is not None:
print(f"resize img2img mask images to {w}*{h}")
logger.info(f"resize img2img mask images to {w}*{h}")
mask_images = resize_images(mask_images, (w, h))
regional_network = False
if networks and mask_images:
# mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
regional_network = True
print("use mask as region")
logger.info("use mask as region")
size = None
for i, network in enumerate(networks):
@@ -2162,14 +2168,14 @@ def main(args):
prev_image = None # for VGG16 guided
if args.guide_image_path is not None:
print(f"load image for ControlNet guidance: {args.guide_image_path}")
logger.info(f"load image for ControlNet guidance: {args.guide_image_path}")
guide_images = []
for p in args.guide_image_path:
guide_images.extend(load_images(p))
print(f"loaded {len(guide_images)} guide images for guidance")
logger.info(f"loaded {len(guide_images)} guide images for guidance")
if len(guide_images) == 0:
print(
logger.warning(
f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}"
)
guide_images = None
@@ -2200,7 +2206,7 @@ def main(args):
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
for gen_iter in range(args.n_iter):
print(f"iteration {gen_iter+1}/{args.n_iter}")
logger.info(f"iteration {gen_iter+1}/{args.n_iter}")
if args.iter_same_seed:
iter_seed = seed_random.randint(0, 2**32 - 1)
else:
@@ -2219,7 +2225,7 @@ def main(args):
# 1st stageのバッチを作成して呼び出すサイズを小さくして呼び出す
is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling
print("process 1st stage")
logger.info("process 1st stage")
batch_1st = []
for _, base, ext in batch:
@@ -2264,7 +2270,7 @@ def main(args):
images_1st = process_batch(batch_1st, True, True)
# 2nd stageのバッチを作成して以下処理する
print("process 2nd stage")
logger.info("process 2nd stage")
width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height
if upscaler:
@@ -2437,7 +2443,7 @@ def main(args):
n.restore_weights()
for n in networks:
n.pre_calculation()
print("pre-calculation... done")
logger.info("pre-calculation... done")
images = pipe(
prompts,
@@ -2520,7 +2526,7 @@ def main(args):
cv2.waitKey()
cv2.destroyAllWindows()
except ImportError:
print(
logger.warning(
"opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません"
)
@@ -2535,7 +2541,7 @@ def main(args):
# interactive
valid = False
while not valid:
print("\nType prompt:")
logger.info("\nType prompt:")
try:
raw_prompt = input()
except EOFError:
@@ -2595,74 +2601,74 @@ def main(args):
prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]
length = len(prompter) if hasattr(prompter, "__len__") else 0
print(f"prompt {prompt_index+1}/{length}: {prompt}")
logger.info(f"prompt {prompt_index+1}/{length}: {prompt}")
for parg in prompt_args[1:]:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
width = int(m.group(1))
print(f"width: {width}")
logger.info(f"width: {width}")
continue
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
height = int(m.group(1))
print(f"height: {height}")
logger.info(f"height: {height}")
continue
m = re.match(r"ow (\d+)", parg, re.IGNORECASE)
if m:
original_width = int(m.group(1))
print(f"original width: {original_width}")
logger.info(f"original width: {original_width}")
continue
m = re.match(r"oh (\d+)", parg, re.IGNORECASE)
if m:
original_height = int(m.group(1))
print(f"original height: {original_height}")
logger.info(f"original height: {original_height}")
continue
m = re.match(r"nw (\d+)", parg, re.IGNORECASE)
if m:
original_width_negative = int(m.group(1))
print(f"original width negative: {original_width_negative}")
logger.info(f"original width negative: {original_width_negative}")
continue
m = re.match(r"nh (\d+)", parg, re.IGNORECASE)
if m:
original_height_negative = int(m.group(1))
print(f"original height negative: {original_height_negative}")
logger.info(f"original height negative: {original_height_negative}")
continue
m = re.match(r"ct (\d+)", parg, re.IGNORECASE)
if m:
crop_top = int(m.group(1))
print(f"crop top: {crop_top}")
logger.info(f"crop top: {crop_top}")
continue
m = re.match(r"cl (\d+)", parg, re.IGNORECASE)
if m:
crop_left = int(m.group(1))
print(f"crop left: {crop_left}")
logger.info(f"crop left: {crop_left}")
continue
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps
steps = max(1, min(1000, int(m.group(1))))
print(f"steps: {steps}")
logger.info(f"steps: {steps}")
continue
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
if m: # seed
seeds = [int(d) for d in m.group(1).split(",")]
print(f"seeds: {seeds}")
logger.info(f"seeds: {seeds}")
continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
scale = float(m.group(1))
print(f"scale: {scale}")
logger.info(f"scale: {scale}")
continue
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
@@ -2671,25 +2677,25 @@ def main(args):
negative_scale = None
else:
negative_scale = float(m.group(1))
print(f"negative scale: {negative_scale}")
logger.info(f"negative scale: {negative_scale}")
continue
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
if m: # strength
strength = float(m.group(1))
print(f"strength: {strength}")
logger.info(f"strength: {strength}")
continue
m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
negative_prompt = m.group(1)
print(f"negative prompt: {negative_prompt}")
logger.info(f"negative prompt: {negative_prompt}")
continue
m = re.match(r"c (.+)", parg, re.IGNORECASE)
if m: # clip prompt
clip_prompt = m.group(1)
print(f"clip prompt: {clip_prompt}")
logger.info(f"clip prompt: {clip_prompt}")
continue
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
@@ -2697,89 +2703,89 @@ def main(args):
network_muls = [float(v) for v in m.group(1).split(",")]
while len(network_muls) < len(networks):
network_muls.append(network_muls[-1])
print(f"network mul: {network_muls}")
logger.info(f"network mul: {network_muls}")
continue
# Deep Shrink
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 1
ds_depth_1 = int(m.group(1))
print(f"deep shrink depth 1: {ds_depth_1}")
logger.info(f"deep shrink depth 1: {ds_depth_1}")
continue
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 1
ds_timesteps_1 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 1: {ds_timesteps_1}")
logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}")
continue
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 2
ds_depth_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink depth 2: {ds_depth_2}")
logger.info(f"deep shrink depth 2: {ds_depth_2}")
continue
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 2
ds_timesteps_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 2: {ds_timesteps_2}")
logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}")
continue
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink ratio
ds_ratio = float(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink ratio: {ds_ratio}")
logger.info(f"deep shrink ratio: {ds_ratio}")
continue
# Gradual Latent
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent timesteps
gl_timesteps = int(m.group(1))
print(f"gradual latent timesteps: {gl_timesteps}")
logger.info(f"gradual latent timesteps: {gl_timesteps}")
continue
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio
gl_ratio = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent ratio: {ds_ratio}")
logger.info(f"gradual latent ratio: {ds_ratio}")
continue
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent every n steps
gl_every_n_steps = int(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent every n steps: {gl_every_n_steps}")
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
continue
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio step
gl_ratio_step = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent ratio step: {gl_ratio_step}")
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
continue
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent s noise
gl_s_noise = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent s noise: {gl_s_noise}")
logger.info(f"gradual latent s noise: {gl_s_noise}")
continue
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # gradual latent unsharp params
gl_unsharp_params = m.group(1)
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent unsharp params: {gl_unsharp_params}")
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(f"{ex}")
# override Deep Shrink
if ds_depth_1 is not None:
@@ -2825,7 +2831,7 @@ def main(args):
if seed is None:
seed = seed_random.randint(0, 2**32 - 1)
if args.interactive:
print(f"seed: {seed}")
logger.info(f"seed: {seed}")
# prepare init image, guide image and mask
init_image = mask_image = guide_image = None
@@ -2841,7 +2847,7 @@ def main(args):
width = width - width % 32
height = height - height % 32
if width != init_image.size[0] or height != init_image.size[1]:
print(
logger.warning(
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
)
@@ -2903,12 +2909,14 @@ def main(args):
process_batch(batch_data, highres_fix)
batch_data.clear()
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
parser.add_argument(
"--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む"
)

gen_img_diffusers.py

@@ -489,10 +489,10 @@ class PipelineLike:
def set_gradual_latent(self, gradual_latent):
if gradual_latent is None:
print("gradual_latent is disabled")
logger.info("gradual_latent is disabled")
self.gradual_latent = None
else:
print(f"gradual_latent is enabled: {gradual_latent}")
logger.info(f"gradual_latent is enabled: {gradual_latent}")
self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step)
# region xformersとか使う部分独自に書き換えるので関係なし
@@ -971,8 +971,8 @@ class PipelineLike:
enable_gradual_latent = False
if self.gradual_latent:
if not hasattr(self.scheduler, "set_gradual_latent_params"):
print("gradual_latent is not supported for this scheduler. Ignoring.")
print(self.scheduler.__class__.__name__)
logger.info("gradual_latent is not supported for this scheduler. Ignoring.")
logger.info(f'{self.scheduler.__class__.__name__}')
else:
enable_gradual_latent = True
step_elapsed = 1000
@@ -3314,42 +3314,42 @@ def main(args):
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent timesteps
gl_timesteps = int(m.group(1))
print(f"gradual latent timesteps: {gl_timesteps}")
logger.info(f"gradual latent timesteps: {gl_timesteps}")
continue
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio
gl_ratio = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent ratio: {ds_ratio}")
logger.info(f"gradual latent ratio: {ds_ratio}")
continue
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent every n steps
gl_every_n_steps = int(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent every n steps: {gl_every_n_steps}")
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
continue
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio step
gl_ratio_step = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent ratio step: {gl_ratio_step}")
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
continue
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent s noise
gl_s_noise = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent s noise: {gl_s_noise}")
logger.info(f"gradual latent s noise: {gl_s_noise}")
continue
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # gradual latent unsharp params
gl_unsharp_params = m.group(1)
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent unsharp params: {gl_unsharp_params}")
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
continue
except ValueError as ex:
@@ -3369,7 +3369,7 @@ def main(args):
if gl_unsharp_params is not None:
unsharp_params = gl_unsharp_params.split(",")
us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]]
print(unsharp_params)
logger.info(f'{unsharp_params}')
us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3]))
us_ksize = int(us_ksize)
else:

library/device_utils.py

@@ -3,6 +3,11 @@ import gc
import torch
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
try:
HAS_CUDA = torch.cuda.is_available()
except Exception:
@@ -59,7 +64,7 @@ def get_preferred_device() -> torch.device:
device = torch.device("mps")
else:
device = torch.device("cpu")
print(f"get_preferred_device() -> {device}")
logger.info(f"get_preferred_device() -> {device}")
return device
@@ -77,8 +82,8 @@ def init_ipex():
is_initialized, error_message = ipex_init()
if not is_initialized:
print("failed to initialize ipex:", error_message)
logger.error("failed to initialize ipex: {error_message}")
else:
return
except Exception as e:
print("failed to initialize ipex:", e)
logger.error("failed to initialize ipex: {e}")

library/sai_model_spec.py

@@ -6,8 +6,10 @@ import os
from typing import List, Optional, Tuple, Union
import safetensors
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
r"""
@@ -55,11 +57,13 @@ ARCH_SD_V1 = "stable-diffusion-v1"
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ARCH_STABLE_CASCADE = "stable-cascade"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_STABILITY_AI_STABLE_CASCADE = "https://github.com/Stability-AI/StableCascade"
IMPL_DIFFUSERS = "diffusers"
PRED_TYPE_EPSILON = "epsilon"
@@ -113,6 +117,7 @@ def build_metadata(
merged_from: Optional[str] = None,
timesteps: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
stable_cascade: Optional[bool] = None,
):
# if state_dict is None, hash is not calculated
@@ -124,7 +129,9 @@ def build_metadata(
# hash = precalculate_safetensors_hashes(state_dict)
# metadata["modelspec.hash_sha256"] = hash
if sdxl:
if stable_cascade:
arch = ARCH_STABLE_CASCADE
elif sdxl:
arch = ARCH_SD_XL_V1_BASE
elif v2:
if v_parameterization:
@@ -142,9 +149,11 @@ def build_metadata(
metadata["modelspec.architecture"] = arch
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
if stable_cascade:
impl = IMPL_STABILITY_AI_STABLE_CASCADE
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI
else:
@@ -236,7 +245,7 @@ def build_metadata(
# assert all([v is not None for v in metadata.values()]), metadata
if not all([v is not None for v in metadata.values()]):
logger.error(f"Internal error: some metadata values are None: {metadata}")
return metadata
@@ -250,7 +259,7 @@ def get_title(metadata: dict) -> Optional[str]:
def load_metadata_from_safetensors(model: str) -> dict:
if not model.endswith(".safetensors"):
return {}
with safetensors.safe_open(model, framework="pt") as f:
metadata = f.metadata()
if metadata is None:

library/stable_cascade.py (new file, 1654 lines; diff suppressed because it is too large)

library/stable_cascade_utils.py

@@ -0,0 +1,668 @@
import argparse
import json
import math
import os
import time
from typing import List
import numpy as np
import toml
import torch
import torchvision
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPTextModelWithProjection, CLIPTextConfig
from accelerate import init_empty_weights, Accelerator, PartialState
from PIL import Image
from library import stable_cascade as sc
from library.sdxl_model_util import _load_state_dict_on_device
from library.device_utils import clean_memory_on_device
from library.train_util import (
save_sd_model_on_epoch_end_or_stepwise_common,
save_sd_model_on_train_end_common,
line_to_prompt_dict,
get_hidden_states_stable_cascade,
)
from library import sai_model_spec
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLIP_TEXT_MODEL_NAME: str = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_sc_te_outputs.npz"
def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0):
resolution_multiple = 42.67
latent_height = math.ceil(height / compression_factor_b)
latent_width = math.ceil(width / compression_factor_b)
stage_c_latent_shape = (batch_size, 16, latent_height, latent_width)
latent_height = math.ceil(height / compression_factor_a)
latent_width = math.ceil(width / compression_factor_a)
stage_b_latent_shape = (batch_size, 4, latent_height, latent_width)
return stage_c_latent_shape, stage_b_latent_shape
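# Worked example with the defaults above: calculate_latent_sizes(1024, 1024, batch_size=4)
# gives stage_c_latent_shape == (4, 16, 24, 24), since ceil(1024 / 42.67) == 24,
# and stage_b_latent_shape == (4, 4, 256, 256), since ceil(1024 / 4.0) == 256.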
# region load and save
def load_effnet(effnet_checkpoint_path, loading_device="cpu") -> sc.EfficientNetEncoder:
logger.info(f"Loading EfficientNet encoder from {effnet_checkpoint_path}")
effnet = sc.EfficientNetEncoder()
effnet_checkpoint = load_file(effnet_checkpoint_path)
info = effnet.load_state_dict(effnet_checkpoint if "state_dict" not in effnet_checkpoint else effnet_checkpoint["state_dict"])
logger.info(info)
del effnet_checkpoint
return effnet
def load_tokenizer(args: argparse.Namespace):
# TODO commonize with sdxl_train_util.load_tokenizers
logger.info("prepare tokenizers")
original_paths = [CLIP_TEXT_MODEL_NAME]
tokenizers = []
for i, original_path in enumerate(original_paths):
tokenizer: CLIPTokenizer = None
if args.tokenizer_cache_dir:
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
if os.path.exists(local_tokenizer_path):
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
if tokenizer is None:
tokenizer = CLIPTokenizer.from_pretrained(original_path)
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
tokenizer.save_pretrained(local_tokenizer_path)
tokenizers.append(tokenizer)
if hasattr(args, "max_token_length") and args.max_token_length is not None:
logger.info(f"update token length: {args.max_token_length}")
return tokenizers[0]
def load_stage_c_model(stage_c_checkpoint_path, dtype=None, device="cpu") -> sc.StageC:
# Generator
logger.info(f"Instantiating Stage C generator")
with init_empty_weights():
generator_c = sc.StageC()
logger.info(f"Loading Stage C generator from {stage_c_checkpoint_path}")
stage_c_checkpoint = load_file(stage_c_checkpoint_path)
stage_c_checkpoint = convert_state_dict_mha_to_normal_attn(stage_c_checkpoint)
logger.info(f"Loading state dict")
info = _load_state_dict_on_device(generator_c, stage_c_checkpoint, device, dtype=dtype)
logger.info(info)
return generator_c
def load_stage_b_model(stage_b_checkpoint_path, dtype=None, device="cpu") -> sc.StageB:
logger.info(f"Instantiating Stage B generator")
with init_empty_weights():
generator_b = sc.StageB()
logger.info(f"Loading Stage B generator from {stage_b_checkpoint_path}")
stage_b_checkpoint = load_file(stage_b_checkpoint_path)
stage_b_checkpoint = convert_state_dict_mha_to_normal_attn(stage_b_checkpoint)
logger.info(f"Loading state dict")
info = _load_state_dict_on_device(generator_b, stage_b_checkpoint, device, dtype=dtype)
logger.info(info)
return generator_b
def load_clip_text_model(text_model_checkpoint_path, dtype=None, device="cpu", save_text_model=False):
# CLIP encoders
logger.info(f"Loading CLIP text model")
if save_text_model or text_model_checkpoint_path is None:
logger.info(f"Loading CLIP text model from {CLIP_TEXT_MODEL_NAME}")
text_model = CLIPTextModelWithProjection.from_pretrained(CLIP_TEXT_MODEL_NAME)
if save_text_model:
sd = text_model.state_dict()
logger.info(f"Saving CLIP text model to {text_model_checkpoint_path}")
save_file(sd, text_model_checkpoint_path)
else:
logger.info(f"Loading CLIP text model from {text_model_checkpoint_path}")
# copy from sdxl_model_util.py
text_model2_cfg = CLIPTextConfig(
vocab_size=49408,
hidden_size=1280,
intermediate_size=5120,
num_hidden_layers=32,
num_attention_heads=20,
max_position_embeddings=77,
hidden_act="gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=1280,
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
with init_empty_weights():
text_model = CLIPTextModelWithProjection(text_model2_cfg)
text_model_checkpoint = load_file(text_model_checkpoint_path)
info = _load_state_dict_on_device(text_model, text_model_checkpoint, device, dtype=dtype)
logger.info(info)
return text_model
def load_stage_a_model(stage_a_checkpoint_path, dtype=None, device="cpu") -> sc.StageA:
logger.info(f"Loading Stage A vqGAN from {stage_a_checkpoint_path}")
stage_a = sc.StageA().to(device)
stage_a_checkpoint = load_file(stage_a_checkpoint_path)
info = stage_a.load_state_dict(
stage_a_checkpoint if "state_dict" not in stage_a_checkpoint else stage_a_checkpoint["state_dict"]
)
logger.info(info)
return stage_a
def load_previewer_model(previewer_checkpoint_path, dtype=None, device="cpu") -> sc.Previewer:
logger.info(f"Loading Previewer from {previewer_checkpoint_path}")
previewer = sc.Previewer().to(device)
previewer_checkpoint = load_file(previewer_checkpoint_path)
info = previewer.load_state_dict(
previewer_checkpoint if "state_dict" not in previewer_checkpoint else previewer_checkpoint["state_dict"]
)
logger.info(info)
return previewer
def convert_state_dict_mha_to_normal_attn(state_dict):
# convert nn.MultiheadAttention to to_q/k/v and out_proj
print("convert_state_dict_mha_to_normal_attn")
for key in list(state_dict.keys()):
if "attention.attn." in key:
if "in_proj_bias" in key:
value = state_dict.pop(key)
qkv = torch.chunk(value, 3, dim=0)
state_dict[key.replace("in_proj_bias", "to_q.bias")] = qkv[0]
state_dict[key.replace("in_proj_bias", "to_k.bias")] = qkv[1]
state_dict[key.replace("in_proj_bias", "to_v.bias")] = qkv[2]
elif "in_proj_weight" in key:
value = state_dict.pop(key)
qkv = torch.chunk(value, 3, dim=0)
state_dict[key.replace("in_proj_weight", "to_q.weight")] = qkv[0]
state_dict[key.replace("in_proj_weight", "to_k.weight")] = qkv[1]
state_dict[key.replace("in_proj_weight", "to_v.weight")] = qkv[2]
elif "out_proj.bias" in key:
value = state_dict.pop(key)
state_dict[key.replace("out_proj.bias", "out_proj.bias")] = value
elif "out_proj.weight" in key:
value = state_dict.pop(key)
state_dict[key.replace("out_proj.weight", "out_proj.weight")] = value
return state_dict
def convert_state_dict_normal_attn_to_mha(state_dict):
# convert to_q/k/v and out_proj to nn.MultiheadAttention
for key in list(state_dict.keys()):
if "attention.attn." in key:
if "to_q.bias" in key:
q = state_dict.pop(key)
k = state_dict.pop(key.replace("to_q.bias", "to_k.bias"))
v = state_dict.pop(key.replace("to_q.bias", "to_v.bias"))
state_dict[key.replace("to_q.bias", "in_proj_bias")] = torch.cat([q, k, v])
elif "to_q.weight" in key:
q = state_dict.pop(key)
k = state_dict.pop(key.replace("to_q.weight", "to_k.weight"))
v = state_dict.pop(key.replace("to_q.weight", "to_v.weight"))
state_dict[key.replace("to_q.weight", "in_proj_weight")] = torch.cat([q, k, v])
elif "out_proj.bias" in key:
v = state_dict.pop(key)
state_dict[key.replace("out_proj.bias", "out_proj.bias")] = v
elif "out_proj.weight" in key:
v = state_dict.pop(key)
state_dict[key.replace("out_proj.weight", "out_proj.weight")] = v
return state_dict
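The two converters above are exact inverses: `nn.MultiheadAttention` stores q/k/v as one concatenated `in_proj_weight`/`in_proj_bias`, so splitting with `torch.chunk(..., 3, dim=0)` and re-joining with `torch.cat(..., dim=0)` round-trips losslessly. A minimal standalone sanity check (illustrative key names and sizes, not from the repository):

import torch

# hypothetical state dict in nn.MultiheadAttention layout, embed dim 8
sd = {
    "blocks.0.attention.attn.in_proj_weight": torch.randn(3 * 8, 8),
    "blocks.0.attention.attn.in_proj_bias": torch.randn(3 * 8),
    "blocks.0.attention.attn.out_proj.weight": torch.randn(8, 8),
}
converted = convert_state_dict_mha_to_normal_attn(dict(sd))  # -> to_q/to_k/to_v keys
restored = convert_state_dict_normal_attn_to_mha(dict(converted))  # -> back to in_proj keys
assert all(torch.equal(restored[k], v) for k, v in sd.items())  # lossless round trip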
def get_sai_model_spec(args, lora=False):
timestamp = time.time()
reso = args.resolution
title = args.metadata_title if args.metadata_title is not None else args.output_name
if args.min_timestep is not None or args.max_timestep is not None:
min_time_step = args.min_timestep if args.min_timestep is not None else 0
max_time_step = args.max_timestep if args.max_timestep is not None else 1000
timesteps = (min_time_step, max_time_step)
else:
timesteps = None
metadata = sai_model_spec.build_metadata(
None,
False,
False,
False,
lora,
False,
timestamp,
title=title,
reso=reso,
is_stable_diffusion_ckpt=False,
author=args.metadata_author,
description=args.metadata_description,
license=args.metadata_license,
tags=args.metadata_tags,
timesteps=timesteps,
clip_skip=args.clip_skip, # None or int
stable_cascade=True,
)
return metadata
def stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata):
state_dict = stage_c.state_dict()
if save_dtype is not None:
state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
state_dict = convert_state_dict_normal_attn_to_mha(state_dict)
save_file(state_dict, ckpt_file, metadata=sai_metadata)
# save text model
if text_model is not None:
text_model_sd = text_model.state_dict()
if save_dtype is not None:
text_model_sd = {k: v.to(save_dtype) for k, v in text_model_sd.items()}
text_model_ckpt_file = os.path.splitext(ckpt_file)[0] + "_text_model.safetensors"
save_file(text_model_sd, text_model_ckpt_file)
def save_stage_c_model_on_epoch_end_or_stepwise(
args: argparse.Namespace,
on_epoch_end: bool,
accelerator,
save_dtype: torch.dtype,
epoch: int,
num_train_epochs: int,
global_step: int,
stage_c,
text_model,
):
def stage_c_saver(ckpt_file, epoch_no, global_step):
sai_metadata = get_sai_model_spec(args)
stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata)
save_sd_model_on_epoch_end_or_stepwise_common(
args, on_epoch_end, accelerator, True, True, epoch, num_train_epochs, global_step, stage_c_saver, None
)
def save_stage_c_model_on_end(
args: argparse.Namespace,
save_dtype: torch.dtype,
epoch: int,
global_step: int,
stage_c,
text_model,
):
def stage_c_saver(ckpt_file, epoch_no, global_step):
sai_metadata = get_sai_model_spec(args)
stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata)
save_sd_model_on_train_end_common(args, True, True, epoch, global_step, stage_c_saver, None)
# endregion
# region sample generation
def sample_images(
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
previewer,
tokenizer,
text_encoder,
stage_c,
gdf,
prompt_replacement=None,
):
if steps == 0:
if not args.sample_at_first:
return
else:
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps is ignored when sample_every_n_epochs is specified
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return
logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts):
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
# unwrap unet and text_encoder(s)
stage_c = accelerator.unwrap_model(stage_c)
text_encoder = accelerator.unwrap_model(text_encoder)
# read prompts
if args.sample_prompts.endswith(".txt"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line.strip()[0] != "#"]
elif args.sample_prompts.endswith(".toml"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
data = toml.load(f)
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
elif args.sample_prompts.endswith(".json"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
# preprocess prompts
for i in range(len(prompts)):
prompt_dict = prompts[i]
if isinstance(prompt_dict, str):
prompt_dict = line_to_prompt_dict(prompt_dict)
prompts[i] = prompt_dict
assert isinstance(prompt_dict, dict)
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
prompt_dict["enum"] = i
prompt_dict.pop("subset", None)
# save random state to restore later
rng_state = torch.get_rng_state()
cuda_rng_state = None
try:
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
except Exception:
pass
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad():
for prompt_dict in prompts:
sample_image_inference(
accelerator,
args,
tokenizer,
text_encoder,
stage_c,
previewer,
gdf,
save_dir,
prompt_dict,
epoch,
steps,
prompt_replacement,
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
per_process_prompts = [] # list of lists
for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
accelerator,
args,
tokenizer,
text_encoder,
stage_c,
previewer,
gdf,
save_dir,
prompt_dict,
epoch,
steps,
prompt_replacement,
)
# I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here.
# with torch.cuda.device(torch.cuda.current_device()):
# torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
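For reference, the prompt-file formats accepted above all map onto the same prompt dict. A `.txt` file holds one prompt per line (with `#` comment lines skipped), while a `.toml` file puts shared keys under `[prompt]` and per-image entries under `[[prompt.subset]]`; the dict merge above flattens the two levels. A small sketch of that merge with illustrative values (not from the repository):

# equivalent of a sample_prompts.toml such as:
#   [prompt]
#   width = 1024
#   height = 1024
#   scale = 4
#   sample_steps = 20
#   [[prompt.subset]]
#   prompt = "a penguin reading a book in a cafe"
#   seed = 42
data = {"prompt": {"width": 1024, "height": 1024, "scale": 4, "sample_steps": 20,
                   "subset": [{"prompt": "a penguin reading a book in a cafe", "seed": 42}]}}
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
# each entry now carries width/height/scale/sample_steps plus its own prompt and seed;
# the leftover "subset" key is what prompt_dict.pop("subset", None) above removes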
def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
tokenizer,
text_model,
stage_c,
previewer,
gdf,
save_dir,
prompt_dict,
epoch,
steps,
prompt_replacement,
):
assert isinstance(prompt_dict, dict)
negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 20)
width = prompt_dict.get("width", 1024)
height = prompt_dict.get("height", 1024)
scale = prompt_dict.get("scale", 4)
seed = prompt_dict.get("seed")
# controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
else:
# True random sample image generation
torch.seed()
torch.cuda.seed()
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
logger.info(f"prompt: {prompt}")
logger.info(f"negative_prompt: {negative_prompt}")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"scale: {scale}")
# logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
negative_prompt = "" if negative_prompt is None else negative_prompt
cfg = scale
timesteps = sample_steps
shift = 2
t_start = 1.0
stage_c_latent_shape, _ = calculate_latent_sizes(height, width, batch_size=1)
# PREPARE CONDITIONS
input_ids = tokenizer(
[prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
)["input_ids"].to(text_model.device)
cond_text, cond_pooled = get_hidden_states_stable_cascade(tokenizer.model_max_length, input_ids, tokenizer, text_model)
input_ids = tokenizer(
[negative_prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
)["input_ids"].to(text_model.device)
uncond_text, uncond_pooled = get_hidden_states_stable_cascade(tokenizer.model_max_length, input_ids, tokenizer, text_model)
device = accelerator.device
dtype = stage_c.dtype
cond_text = cond_text.to(device, dtype=dtype)
cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype)
uncond_text = uncond_text.to(device, dtype=dtype)
uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype)
zero_img_emb = torch.zeros(1, 768, device=device)
# a dict is not ideal, but changing everything downstream of GDF is a hassle, so use a dict for now
conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
unconditions = {"clip_text_pooled": uncond_pooled, "clip": uncond_pooled, "clip_text": uncond_text, "clip_img": zero_img_emb}
with torch.no_grad(): # , torch.cuda.amp.autocast(dtype=dtype):
sampling_c = gdf.sample(
stage_c,
conditions,
stage_c_latent_shape,
unconditions,
device=device,
cfg=cfg,
shift=shift,
timesteps=timesteps,
t_start=t_start,
)
for sampled_c, _, _ in tqdm(sampling_c, total=timesteps):
pass  # iterate the sampler to the end; sampled_c holds the final latents
sampled_c = sampled_c.to(previewer.device, dtype=previewer.dtype)
image = previewer(sampled_c)[0]
image = torch.clamp(image, 0, 1)
image = image.cpu().numpy().transpose(1, 2, 0)
image = image * 255
image = image.astype(np.uint8)
image = Image.fromarray(image)
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
# but adding 'enum' to the filename should be enough
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i: int = prompt_dict["enum"]
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# send images to wandb only when it is enabled
try:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError:  # should not happen here because wandb availability is checked beforehand
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except Exception:  # wandb is disabled
pass
# endregion
def add_effnet_arguments(parser):
parser.add_argument(
"--effnet_checkpoint_path",
type=str,
required=True,
help="path to EfficientNet checkpoint / EfficientNetのチェックポイントのパス",
)
return parser
def add_text_model_arguments(parser):
parser.add_argument(
"--text_model_checkpoint_path",
type=str,
help="path to CLIP text model checkpoint / CLIPテキストモデルのチェックポイントのパス",
)
parser.add_argument("--save_text_model", action="store_true", help="if specified, save text model to corresponding path")
return parser
def add_stage_a_arguments(parser):
parser.add_argument(
"--stage_a_checkpoint_path",
type=str,
required=True,
help="path to Stage A checkpoint / Stage Aのチェックポイントのパス",
)
return parser
def add_stage_b_arguments(parser):
parser.add_argument(
"--stage_b_checkpoint_path",
type=str,
required=True,
help="path to Stage B checkpoint / Stage Bのチェックポイントのパス",
)
return parser
def add_stage_c_arguments(parser):
parser.add_argument(
"--stage_c_checkpoint_path",
type=str,
required=True,
help="path to Stage C checkpoint / Stage Cのチェックポイントのパス",
)
return parser
def add_previewer_arguments(parser):
parser.add_argument(
"--previewer_checkpoint_path",
type=str,
required=False,
help="path to previewer checkpoint / previewerのチェックポイントのパス",
)
return parser
def add_training_arguments(parser):
parser.add_argument(
"--adaptive_loss_weight",
action="store_true",
help="if specified, use adaptive loss weight. if not, use P2 loss weight"
+ " / Adaptive Loss Weightを使用する。指定しない場合はP2 Loss Weightを使用する",
)

View File

@@ -133,6 +133,7 @@ IMAGE_TRANSFORMS = transforms.Compose(
)
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
STABLE_CASCADE_LATENTS_CACHE_SUFFIX = "_sc_latents.npz"
class ImageInfo:
@@ -856,7 +857,7 @@ class BaseDataset(torch.utils.data.Dataset):
logger.info(f"mean ar error (without repeats): {mean_img_ar_error}")
# build an index for data reference; this index is used when shuffling the dataset
self.buckets_indices: List(BucketBatchIndex) = []
self.buckets_indices: List[BucketBatchIndex] = []
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
batch_count = int(math.ceil(len(bucket) / self.batch_size))
for batch_index in range(batch_count):
@@ -910,7 +911,7 @@ class BaseDataset(torch.utils.data.Dataset):
]
)
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
def cache_latents(self, vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor):
# multi-GPU is not supported; use tools/cache_latents.py for that
logger.info("caching latents.")
@@ -931,11 +932,11 @@ class BaseDataset(torch.utils.data.Dataset):
# check disk cache exists and size of latents
if cache_to_disk:
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
info.latents_npz = os.path.splitext(info.absolute_path)[0] + cache_file_suffix
if not is_main_process: # store to info only
continue
cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug)
cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug, divisor)
if cache_available: # do not add to batch
continue
@@ -967,9 +968,13 @@ class BaseDataset(torch.utils.data.Dataset):
# effective only for SDXL, but implemented here instead of sdxl_train_util.py because it must be a dataset method
# supporting SD1/2 would require a v2 flag, so it is deferred
def cache_text_encoder_outputs(
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process, cache_file_suffix
):
assert len(tokenizers) == 2, "only support SDXL"
"""
最後の Text Encoder の pool がキャッシュされる。
The last Text Encoder's pool is cached.
"""
# assert len(tokenizers) == 2, "only support SDXL"
# as with the latents cache, caching to disk is supported
# multi-GPU is not supported here either; use tools/cache_latents.py for that
@@ -981,7 +986,7 @@ class BaseDataset(torch.utils.data.Dataset):
for info in tqdm(image_infos):
# subset = self.image_to_subset[info.image_key]
if cache_to_disk:
te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
te_out_npz = os.path.splitext(info.absolute_path)[0] + cache_file_suffix
info.text_encoder_outputs_npz = te_out_npz
if not is_main_process: # store to info only
@@ -1006,7 +1011,7 @@ class BaseDataset(torch.utils.data.Dataset):
batches = []
for info in image_infos_to_cache:
input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) if len(tokenizers) > 1 else None
batch.append((info, input_ids1, input_ids2))
if len(batch) >= self.batch_size:
@@ -1021,7 +1026,7 @@ class BaseDataset(torch.utils.data.Dataset):
for batch in tqdm(batches):
infos, input_ids1, input_ids2 = zip(*batch)
input_ids1 = torch.stack(input_ids1, dim=0)
input_ids2 = torch.stack(input_ids2, dim=0)
input_ids2 = torch.stack(input_ids2, dim=0) if input_ids2[0] is not None else None
cache_batch_text_encoder_outputs(
infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype
)
@@ -1270,7 +1275,9 @@ class BaseDataset(torch.utils.data.Dataset):
# example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions])
# example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions])
example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list)
example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
example["text_encoder_outputs2_list"] = (
torch.stack(text_encoder_outputs2_list) if text_encoder_outputs2_list[0] is not None else None
)
example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)
if images[0] is not None:
@@ -1327,7 +1334,7 @@ class BaseDataset(torch.utils.data.Dataset):
if self.caching_mode == "text":
input_ids1 = self.get_input_ids(caption, self.tokenizers[0])
input_ids2 = self.get_input_ids(caption, self.tokenizers[1])
input_ids2 = self.get_input_ids(caption, self.tokenizers[1]) if len(self.tokenizers) > 1 else None
else:
input_ids1 = None
input_ids2 = None
@@ -1599,12 +1606,15 @@ class FineTuningDataset(BaseDataset):
# if the image is not found, look for an npz file
if abs_path is None:
if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
abs_path = os.path.splitext(image_key)[0] + ".npz"
else:
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
if os.path.exists(npz_path):
abs_path = npz_path
abs_path = os.path.splitext(image_key)[0] + ".npz"
if not os.path.exists(abs_path):
abs_path = os.path.splitext(image_key)[0] + STABLE_CASCADE_LATENTS_CACHE_SUFFIX
if not os.path.exists(abs_path):
abs_path = os.path.join(subset.image_dir, image_key + ".npz")
if not os.path.exists(abs_path):
abs_path = os.path.join(subset.image_dir, image_key + STABLE_CASCADE_LATENTS_CACHE_SUFFIX)
if not os.path.exists(abs_path):
abs_path = None
assert abs_path is not None, f"no image / 画像がありません: {image_key}"
@@ -1624,7 +1634,7 @@ class FineTuningDataset(BaseDataset):
if not subset.color_aug and not subset.random_crop:
# if npz exists, use them
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
image_info.latents_npz = self.image_key_to_npz_file(subset, image_key)
self.register_image(image_info, subset)
@@ -1638,7 +1648,7 @@ class FineTuningDataset(BaseDataset):
# check existence of all npz files
use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
if use_npz_latents:
flip_aug_in_subset = False
# flip_aug_in_subset = False
npz_any = False
npz_all = True
@@ -1648,9 +1658,12 @@ class FineTuningDataset(BaseDataset):
has_npz = image_info.latents_npz is not None
npz_any = npz_any or has_npz
if subset.flip_aug:
has_npz = has_npz and image_info.latents_npz_flipped is not None
flip_aug_in_subset = True
# flipped latents are now stored in the same .npz:
# a missing check here could therefore cause a runtime error, so this needs review
# if subset.flip_aug:
# has_npz = has_npz and image_info.latents_npz_flipped is not None
# flip_aug_in_subset = True
npz_all = npz_all and has_npz
if npz_any and not npz_all:
@@ -1664,8 +1677,8 @@ class FineTuningDataset(BaseDataset):
logger.warning(
f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します"
)
if flip_aug_in_subset:
logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
# if flip_aug_in_subset:
# logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
# else:
# logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
@@ -1714,34 +1727,29 @@ class FineTuningDataset(BaseDataset):
# clean up the npz info
if not use_npz_latents:
for image_info in self.image_data.values():
image_info.latents_npz = image_info.latents_npz_flipped = None
image_info.latents_npz = None # image_info.latents_npz_flipped =
def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
base_name = os.path.splitext(image_key)[0]
npz_file_norm = base_name + ".npz"
if not os.path.exists(npz_file_norm):
npz_file_norm = base_name + STABLE_CASCADE_LATENTS_CACHE_SUFFIX
if os.path.exists(npz_file_norm):
# image_key is full path
npz_file_flip = base_name + "_flip.npz"
if not os.path.exists(npz_file_flip):
npz_file_flip = None
return npz_file_norm, npz_file_flip
return npz_file_norm
# if not full path, check image_dir. if image_dir is None, return None
if subset.image_dir is None:
return None, None
return None
# image_key is relative path
npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
if not os.path.exists(npz_file_norm):
npz_file_norm = None
npz_file_flip = None
elif not os.path.exists(npz_file_flip):
npz_file_flip = None
npz_file_norm = os.path.join(subset.image_dir, base_name + STABLE_CASCADE_LATENTS_CACHE_SUFFIX)
if os.path.exists(npz_file_norm):
return npz_file_norm
return npz_file_norm, npz_file_flip
return None
class ControlNetDataset(BaseDataset):
@@ -1943,17 +1951,26 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets:
dataset.enable_XTI(*args, **kwargs)
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, cache_file_suffix=".npz", divisor=8):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor)
def cache_text_encoder_outputs(
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
self,
tokenizers,
text_encoders,
device,
weight_dtype,
cache_to_disk=False,
is_main_process=True,
cache_file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)
dataset.cache_text_encoder_outputs(
tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process, cache_file_suffix
)
def set_caching_mode(self, caching_mode):
for dataset in self.datasets:
@@ -1986,8 +2003,8 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
dataset.disable_token_padding()
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
expected_latents_size = (reso[1] // 8, reso[0] // 8)  # note: bucket_reso is (W, H)
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, divisor: int = 8) -> bool:
expected_latents_size = (reso[1] // divisor, reso[0] // divisor)  # note: bucket_reso is (W, H)
if not os.path.exists(npz_path):
return False
@@ -2079,7 +2096,7 @@ def debug_dataset(train_dataset, show_input_ids=False):
if show_input_ids:
logger.info(f"input ids: {iid}")
if "input_ids2" in example:
if "input_ids2" in example and example["input_ids2"] is not None:
logger.info(f"input ids2: {example['input_ids2'][j]}")
if example["images"] is not None:
im = example["images"][j]
@@ -2256,7 +2273,7 @@ def trim_and_resize_if_required(
def cache_batch_latents(
vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
vae: Union[AutoencoderKL, torch.nn.Module], cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
) -> None:
r"""
requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
@@ -2311,23 +2328,36 @@ def cache_batch_text_encoder_outputs(
image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
):
input_ids1 = input_ids1.to(text_encoders[0].device)
input_ids2 = input_ids2.to(text_encoders[1].device)
input_ids2 = input_ids2.to(text_encoders[1].device) if input_ids2 is not None else None
with torch.no_grad():
b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
max_token_length,
input_ids1,
input_ids2,
tokenizers[0],
tokenizers[1],
text_encoders[0],
text_encoders[1],
dtype,
)
# TODO unify the SDXL and Stable Cascade code paths
if len(tokenizers) == 1:
# Stable Cascade
b_hidden_state1, b_pool2 = get_hidden_states_stable_cascade(
max_token_length, input_ids1, tokenizers[0], text_encoders[0], dtype
)
b_hidden_state1 = b_hidden_state1.detach().to("cpu") # b,n*75+2,768
b_pool2 = b_pool2.detach().to("cpu") # b,1280
b_hidden_state2 = [None] * input_ids1.shape[0]
else:
# SDXL
b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
max_token_length,
input_ids1,
input_ids2,
tokenizers[0],
tokenizers[1],
text_encoders[0],
text_encoders[1],
dtype,
)
# move to CPU here, otherwise the values would be overwritten
b_hidden_state1 = b_hidden_state1.detach().to("cpu") # b,n*75+2,768
b_hidden_state2 = b_hidden_state2.detach().to("cpu") # b,n*75+2,1280
b_hidden_state2 = b_hidden_state2.detach().to("cpu") if b_hidden_state2[0] is not None else b_hidden_state2 # b,n*75+2,1280
b_pool2 = b_pool2.detach().to("cpu") # b,1280
for info, hidden_state1, hidden_state2, pool2 in zip(image_infos, b_hidden_state1, b_hidden_state2, b_pool2):
@@ -2340,18 +2370,25 @@ def cache_batch_text_encoder_outputs(
def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2):
np.savez(
npz_path,
hidden_state1=hidden_state1.cpu().float().numpy(),
hidden_state2=hidden_state2.cpu().float().numpy(),
pool2=pool2.cpu().float().numpy(),
)
save_kwargs = {
"hidden_state1": hidden_state1.cpu().float().numpy(),
"pool2": pool2.cpu().float().numpy(),
}
if hidden_state2 is not None:
save_kwargs["hidden_state2"] = hidden_state2.cpu().float().numpy()
np.savez(npz_path, **save_kwargs)
# np.savez(
# npz_path,
# hidden_state1=hidden_state1.cpu().float().numpy(),
# hidden_state2=hidden_state2.cpu().float().numpy() if hidden_state2 is not None else None,
# pool2=pool2.cpu().float().numpy(),
# )
def load_text_encoder_outputs_from_disk(npz_path):
with np.load(npz_path) as f:
hidden_state1 = torch.from_numpy(f["hidden_state1"])
hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f else None
hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f and f["hidden_state2"] is not None else None
pool2 = torch.from_numpy(f["pool2"]) if "pool2" in f else None
return hidden_state1, hidden_state2, pool2
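A quick round-trip sketch of the save/load pair above, exercising the Stable Cascade case where `hidden_state2` is `None` (standalone, with illustrative shapes):

import torch

h1 = torch.randn(77, 768)  # hidden_state1: n*75+2 tokens x 768
pool = torch.randn(1280)   # pool2
save_text_encoder_outputs_to_disk("sample_te_outputs.npz", h1, None, pool)  # no hidden_state2
h1_loaded, h2_loaded, pool_loaded = load_text_encoder_outputs_from_disk("sample_te_outputs.npz")
assert h2_loaded is None                   # the key was never written, so it loads as None
assert torch.equal(h1_loaded, h1.float())  # stored as float32 numpy, loaded back as a tensor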
@@ -2698,6 +2735,15 @@ def get_sai_model_spec(
return metadata
def add_tokenizer_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--tokenizer_cache_dir",
type=str,
default=None,
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリネット接続なしでの学習のため",
)
def add_sd_models_arguments(parser: argparse.ArgumentParser):
# for pretrained models
parser.add_argument(
@@ -2712,12 +2758,7 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
default=None,
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル",
)
parser.add_argument(
"--tokenizer_cache_dir",
type=str,
default=None,
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリネット接続なしでの学習のため",
)
add_tokenizer_arguments(parser)
def add_optimizer_arguments(parser: argparse.ArgumentParser):
@@ -3205,6 +3246,16 @@ def verify_training_args(args: argparse.Namespace):
global HIGH_VRAM
HIGH_VRAM = True
if args.cache_latents_to_disk and not args.cache_latents:
args.cache_latents = True
logger.warning(
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
)
if not hasattr(args, "v_parameterization"):
# Stable Cascade: skip following checks
return
if args.v_parameterization and not args.v2:
logger.warning(
"v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません"
@@ -3212,12 +3263,6 @@ def verify_training_args(args: argparse.Namespace):
if args.v2 and args.clip_skip is not None:
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
if args.cache_latents_to_disk and not args.cache_latents:
args.cache_latents = True
logger.warning(
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
)
# noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time
# # could count these with a List, but just spell them out
# if args.noise_offset is not None and args.multires_noise_iterations is not None:
@@ -4297,6 +4342,54 @@ def get_hidden_states_sdxl(
return hidden_states1, hidden_states2, pool2
def get_hidden_states_stable_cascade(
max_token_length: int,
input_ids2: torch.Tensor,
tokenizer2: CLIPTokenizer,
text_encoder2: CLIPTextModel,
weight_dtype: Optional[torch.dtype] = None,
accelerator: Optional[Accelerator] = None,
):
# ここに Stable Cascade 用のコードがあるのはとても気持ち悪いが、変に整理するよりわかりやすいので、とりあえずこのまま
# It's very awkward to have Stable Cascade code here, but it's easier to understand than to organize it in a strange way, so for now it's as it is.
# input_ids: b,n,77 -> b*n, 77
b_size = input_ids2.size()[0]
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
# text_encoder2
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
hidden_states2 = enc_out["hidden_states"][-1] # ** last layer **
# pool2 = enc_out["text_embeds"]
unwrapped_text_encoder2 = text_encoder2 if accelerator is None else accelerator.unwrap_model(text_encoder2)
pool2 = pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
n_size = 1 if max_token_length is None else max_token_length // 75
hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
if max_token_length is not None:
# bs*3, 77, 768 or 1024
# v2: restore the triplets <BOS>...<EOS> <PAD> ... into a single <BOS>...<EOS> <PAD> ... sequence; honestly not sure this implementation is right
states_list = [hidden_states2[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, tokenizer2.model_max_length):
chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2]  # from after <BOS> to just before the last token
states_list.append(chunk)  # from after <BOS> to just before <EOS>
states_list.append(hidden_states2[:, -1].unsqueeze(1))  # either <EOS> or <PAD>
hidden_states2 = torch.cat(states_list, dim=1)
# use the first pool of each group of n
pool2 = pool2[::n_size]
if weight_dtype is not None:
# this is required for additional network training
hidden_states2 = hidden_states2.to(weight_dtype)
return hidden_states2, pool2
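A usage sketch for this function, mirroring the call sites in `stable_cascade_gen_img.py` later in this diff (the `tokenizer` and `text_model` objects are assumed to have been loaded elsewhere, e.g. via `sc_utils`):

# tokenizer: CLIPTokenizer, text_model: CLIPTextModelWithProjection, loaded elsewhere
input_ids = tokenizer(
    ["a photo of a cat"], truncation=True, padding="max_length",
    max_length=tokenizer.model_max_length, return_tensors="pt",
)["input_ids"].to(text_model.device)
cond_text, cond_pooled = get_hidden_states_stable_cascade(
    tokenizer.model_max_length, input_ids, tokenizer, text_model
)
# cond_text is (batch, 77, hidden); cond_pooled is (batch, projection_dim),
# unsqueezed to (batch, 1, projection_dim) before being used as conditioning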
def default_if_none(value, default):
return default if value is None else value

View File

@@ -327,10 +327,10 @@ class DyLoRANetwork(torch.nn.Module):
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
index = i + 1
print(f"create LoRA for Text Encoder {index}")
logger.info(f"create LoRA for Text Encoder {index}")
else:
index = None
print(f"create LoRA for Text Encoder")
logger.info("create LoRA for Text Encoder")
text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)

View File

@@ -841,9 +841,14 @@ class LoRANetwork(torch.nn.Module):
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
is_group_conv2d = is_conv2d and child_module.groups > 1
if is_linear or is_conv2d:
lora_name = prefix + "." + name + "." + child_name
# if is_group_conv2d:
# logger.info(f"skip group conv2d: {name}.{child_name}")
# continue
if is_linear or (is_conv2d and not is_group_conv2d):
lora_name = prefix + "." + name + ("." + child_name if child_name else "")
lora_name = lora_name.replace(".", "_")
dim = None
@@ -915,6 +920,11 @@ class LoRANetwork(torch.nn.Module):
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
# XXX temporary solution for Stable Cascade Stage C: replace all modules
if "StageC" in unet.__class__.__name__:
logger.info("replace all modules for Stable Cascade Stage C")
target_modules = ["Linear", "Conv2d"]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

View File

@@ -380,10 +380,10 @@ class PipelineLike:
def set_gradual_latent(self, gradual_latent):
if gradual_latent is None:
print("gradual_latent is disabled")
logger.info("gradual_latent is disabled")
self.gradual_latent = None
else:
print(f"gradual_latent is enabled: {gradual_latent}")
logger.info(f"gradual_latent is enabled: {gradual_latent}")
self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step)
@torch.no_grad()
@@ -789,8 +789,8 @@ class PipelineLike:
enable_gradual_latent = False
if self.gradual_latent:
if not hasattr(self.scheduler, "set_gradual_latent_params"):
print("gradual_latent is not supported for this scheduler. Ignoring.")
print(self.scheduler.__class__.__name__)
logger.info("gradual_latent is not supported for this scheduler. Ignoring.")
logger.info(self.scheduler.__class__.__name__)
else:
enable_gradual_latent = True
step_elapsed = 1000
@@ -2614,84 +2614,84 @@ def main(args):
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent timesteps
gl_timesteps = int(m.group(1))
print(f"gradual latent timesteps: {gl_timesteps}")
logger.info(f"gradual latent timesteps: {gl_timesteps}")
continue
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio
gl_ratio = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent ratio: {ds_ratio}")
logger.info(f"gradual latent ratio: {ds_ratio}")
continue
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent every n steps
gl_every_n_steps = int(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent every n steps: {gl_every_n_steps}")
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
continue
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio step
gl_ratio_step = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent ratio step: {gl_ratio_step}")
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
continue
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent s noise
gl_s_noise = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent s noise: {gl_s_noise}")
logger.info(f"gradual latent s noise: {gl_s_noise}")
continue
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # gradual latent unsharp params
gl_unsharp_params = m.group(1)
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent unsharp params: {gl_unsharp_params}")
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
continue
# Gradual Latent
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent timesteps
gl_timesteps = int(m.group(1))
print(f"gradual latent timesteps: {gl_timesteps}")
logger.info(f"gradual latent timesteps: {gl_timesteps}")
continue
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio
gl_ratio = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent ratio: {ds_ratio}")
logger.info(f"gradual latent ratio: {ds_ratio}")
continue
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent every n steps
gl_every_n_steps = int(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent every n steps: {gl_every_n_steps}")
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
continue
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio step
gl_ratio_step = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent ratio step: {gl_ratio_step}")
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
continue
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent s noise
gl_s_noise = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent s noise: {gl_s_noise}")
logger.info(f"gradual latent s noise: {gl_s_noise}")
continue
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # gradual latent unsharp params
gl_unsharp_params = m.group(1)
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent unsharp params: {gl_unsharp_params}")
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
continue
except ValueError as ex:

367
stable_cascade_gen_img.py Normal file
View File

@@ -0,0 +1,367 @@
import argparse
import importlib
import math
import os
import random
import time
import numpy as np
from safetensors.torch import load_file, save_file
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTextConfig
from PIL import Image
from accelerate import init_empty_weights
import library.stable_cascade as sc
import library.stable_cascade_utils as sc_utils
import library.device_utils as device_utils
from library import train_util
from library.sdxl_model_util import _load_state_dict_on_device
def main(args):
device = device_utils.get_preferred_device()
loading_device = device if not args.lowvram else "cpu"
text_model_device = "cpu"
dtype = torch.float32
if args.bf16:
dtype = torch.bfloat16
elif args.fp16:
dtype = torch.float16
text_model_dtype = torch.float32
# EfficientNet encoder
effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device)
effnet.eval().requires_grad_(False).to(loading_device)
generator_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=dtype, device=loading_device)
generator_c.eval().requires_grad_(False).to(loading_device)
# if args.xformers or args.sdpa:
print(f"Stage C: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
generator_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
generator_b = sc_utils.load_stage_b_model(args.stage_b_checkpoint_path, dtype=dtype, device=loading_device)
generator_b.eval().requires_grad_(False).to(loading_device)
# if args.xformers or args.sdpa:
print(f"Stage B: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
generator_b.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
# CLIP encoders
tokenizer = sc_utils.load_tokenizer(args)
text_model = sc_utils.load_clip_text_model(
args.text_model_checkpoint_path, text_model_dtype, text_model_device, args.save_text_model
)
text_model = text_model.requires_grad_(False).to(text_model_dtype).to(text_model_device)
# image_model = (
# CLIPVisionModelWithProjection.from_pretrained(clip_image_model_name).requires_grad_(False).to(dtype).to(device)
# )
# vqGAN
stage_a = sc_utils.load_stage_a_model(args.stage_a_checkpoint_path, dtype=dtype, device=loading_device)
stage_a.eval().requires_grad_(False)
# previewer
if args.previewer_checkpoint_path is not None:
previewer = sc_utils.load_previewer_model(args.previewer_checkpoint_path, dtype=dtype, device=loading_device)
previewer.eval().requires_grad_(False)
else:
previewer = None
# LoRA
if args.network_module:
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
net_kwargs = {}
if args.network_args and i < len(args.network_args):
network_args = args.network_args[i]
# TODO escape special chars
network_args = network_args.split(";")
for net_arg in network_args:
key, value = net_arg.split("=")
net_kwargs[key] = value
if args.network_weights is None or len(args.network_weights) <= i:
raise ValueError("No network weights specified. --network_weights is required for each network module.")
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, effnet, text_model, generator_c, for_inference=True, **net_kwargs
)
if network is None:
return
mergeable = network.is_mergeable()
assert mergeable, "not-mergeable network is not supported yet."
network.merge_to(text_model, generator_c, weights_sd, dtype, device)
# the mysterious GDF class
gdf_c = sc.GDF(
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
input_scaler=sc.VPScaler(),
target=sc.EpsilonTarget(),
noise_cond=sc.CosineTNoiseCond(),
loss_weight=None,
)
gdf_b = sc.GDF(
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
input_scaler=sc.VPScaler(),
target=sc.EpsilonTarget(),
noise_cond=sc.CosineTNoiseCond(),
loss_weight=None,
)
# Stage C Parameters
# extras.sampling_configs["cfg"] = 4
# extras.sampling_configs["shift"] = 2
# extras.sampling_configs["timesteps"] = 20
# extras.sampling_configs["t_start"] = 1.0
# # Stage B Parameters
# extras_b.sampling_configs["cfg"] = 1.1
# extras_b.sampling_configs["shift"] = 1
# extras_b.sampling_configs["timesteps"] = 10
# extras_b.sampling_configs["t_start"] = 1.0
b_cfg = 1.1
b_shift = 1
b_timesteps = 10
b_t_start = 1.0
# caption = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee"
# height, width = 1024, 1024
while True:
print("type caption:")
# EOF (Ctrl+Z on Windows, Ctrl+D on Unix) raises EOFError
try:
caption = input()
except EOFError:
break
caption = caption.strip()
if caption == "":
continue
# parse options: '--w'/'--h' for size, '--l' for cfg scale, '--s' for timesteps, '--f' for shift, '--t' for t_start, '--n' for negative prompt, '--d' for seed. defaults are used if not specified
# e.g. "a photo of a cat --w 1024 --h 1024 --l 4 --s 20 --f 2.0"
tokens = caption.split()
width = height = 1024
cfg = 4
timesteps = 20
shift = 2
t_start = 1.0
negative_prompt = ""
seed = None
caption_tokens = []
i = 0
while i < len(tokens):
token = tokens[i]
if i == len(tokens) - 1:
caption_tokens.append(token)
i += 1
continue
if token == "--w":
width = int(tokens[i + 1])
elif token == "--h":
height = int(tokens[i + 1])
elif token == "--l":
cfg = float(tokens[i + 1])
elif token == "--s":
timesteps = int(tokens[i + 1])
elif token == "--f":
shift = float(tokens[i + 1])
elif token == "--t":
t_start = float(tokens[i + 1])
elif token == "--n":
negative_prompt = tokens[i + 1]
elif token == "--d":
seed = int(tokens[i + 1])
else:
caption_tokens.append(token)
i += 1
continue
i += 2
caption = " ".join(caption_tokens)
stage_c_latent_shape, stage_b_latent_shape = sc_utils.calculate_latent_sizes(height, width, batch_size=1)
# PREPARE CONDITIONS
# cond_text, cond_pooled = sc.get_clip_conditions([caption], None, tokenizer, text_model)
input_ids = tokenizer(
[caption], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
)["input_ids"].to(text_model.device)
cond_text, cond_pooled = train_util.get_hidden_states_stable_cascade(
tokenizer.model_max_length, input_ids, tokenizer, text_model
)
cond_text = cond_text.to(device, dtype=dtype)
cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype)
# uncond_text, uncond_pooled = sc.get_clip_conditions([""], None, tokenizer, text_model)
input_ids = tokenizer(
[negative_prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
)["input_ids"].to(text_model.device)
uncond_text, uncond_pooled = train_util.get_hidden_states_stable_cascade(
tokenizer.model_max_length, input_ids, tokenizer, text_model
)
uncond_text = uncond_text.to(device, dtype=dtype)
uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype)
zero_img_emb = torch.zeros(1, 768, device=device)
# a dict is not ideal, but changing everything downstream of GDF is a hassle, so use a dict for now
conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
unconditions = {
"clip_text_pooled": uncond_pooled,
"clip": uncond_pooled,
"clip_text": uncond_text,
"clip_img": zero_img_emb,
}
conditions_b = {}
conditions_b.update(conditions)
unconditions_b = {}
unconditions_b.update(unconditions)
# seed everything
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
if args.lowvram:
generator_c = generator_c.to(device)
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
sampling_c = gdf_c.sample(
generator_c,
conditions,
stage_c_latent_shape,
unconditions,
device=device,
cfg=cfg,
shift=shift,
timesteps=timesteps,
t_start=t_start,
)
for sampled_c, _, _ in tqdm(sampling_c, total=timesteps):
pass  # iterate to the end; sampled_c holds the final Stage C latents
conditions_b["effnet"] = sampled_c
unconditions_b["effnet"] = torch.zeros_like(sampled_c)
if previewer is not None:
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
preview = previewer(sampled_c)
preview = preview.clamp(0, 1)
preview = preview.permute(0, 2, 3, 1).squeeze(0)
preview = preview.detach().float().cpu().numpy()
preview = Image.fromarray((preview * 255).astype(np.uint8))
timestamp_str = time.strftime("%Y%m%d_%H%M%S")
os.makedirs(args.outdir, exist_ok=True)
preview.save(os.path.join(args.outdir, f"preview_{timestamp_str}.png"))
if args.lowvram:
generator_c = generator_c.to(loading_device)
device_utils.clean_memory_on_device(device)
generator_b = generator_b.to(device)
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
sampling_b = gdf_b.sample(
generator_b,
conditions_b,
stage_b_latent_shape,
unconditions_b,
device=device,
cfg=b_cfg,
shift=b_shift,
timesteps=b_timesteps,
t_start=b_t_start,
)
for sampled_b, _, _ in tqdm(sampling_b, total=b_timesteps):
pass  # iterate to the end; sampled_b holds the final Stage B latents
if args.lowvram:
generator_b = generator_b.to(loading_device)
device_utils.clean_memory_on_device(device)
stage_a = stage_a.to(device)
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
sampled = stage_a.decode(sampled_b).float()
# print(sampled.shape, sampled.min(), sampled.max())
if args.lowvram:
stage_a = stage_a.to(loading_device)
device_utils.clean_memory_on_device(device)
# float 0-1 to PIL Image
sampled = sampled.clamp(0, 1)
sampled = sampled.mul(255).to(dtype=torch.uint8)
sampled = sampled.permute(0, 2, 3, 1)
sampled = sampled.cpu().numpy()
sampled = Image.fromarray(sampled[0])
timestamp_str = time.strftime("%Y%m%d_%H%M%S")
os.makedirs(args.outdir, exist_ok=True)
sampled.save(os.path.join(args.outdir, f"sampled_{timestamp_str}.png"))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
sc_utils.add_effnet_arguments(parser)
train_util.add_tokenizer_arguments(parser)
sc_utils.add_stage_a_arguments(parser)
sc_utils.add_stage_b_arguments(parser)
sc_utils.add_stage_c_arguments(parser)
sc_utils.add_previewer_arguments(parser)
sc_utils.add_text_model_arguments(parser)
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--xformers", action="store_true")
parser.add_argument("--sdpa", action="store_true")
parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先")
parser.add_argument("--lowvram", action="store_true", help="if specified, use low VRAM mode")
parser.add_argument(
"--network_module",
type=str,
default=None,
nargs="*",
help="additional network module to use / 追加ネットワークを使う時そのモジュール名",
)
parser.add_argument(
"--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み"
)
parser.add_argument(
"--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率"
)
parser.add_argument(
"--network_args",
type=str,
default=None,
nargs="*",
help="additional arguments for network (key=value) / ネットワークへの追加の引数",
)
args = parser.parse_args()
main(args)
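A plausible end-to-end invocation, assuming locally downloaded official checkpoints (the file names here are illustrative):

python stable_cascade_gen_img.py --bf16 --sdpa --lowvram \
  --effnet_checkpoint_path models/effnet_encoder.safetensors \
  --stage_a_checkpoint_path models/stage_a.safetensors \
  --stage_b_checkpoint_path models/stage_b_bf16.safetensors \
  --stage_c_checkpoint_path models/stage_c_bf16.safetensors \
  --previewer_checkpoint_path models/previewer.safetensors \
  --outdir outputs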

File diff suppressed because it is too large

View File

@@ -0,0 +1,564 @@
# training with captions
import argparse
import math
import os
from multiprocessing import Value
from typing import List
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
import library.train_util as train_util
from library.sdxl_train_util import add_sdxl_training_arguments
import library.stable_cascade_utils as sc_utils
import library.stable_cascade as sc
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
setup_logging(args, reset=True)
# assert (
# not args.weighted_captions
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
# TODO add assertions for other unsupported options
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed)  # initialize the random number sequence
tokenizer = sc_utils.load_tokenizer(args)
# prepare the dataset
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
logger.info("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer])
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer])
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(32)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group, True)
return
if len(train_dataset_group) == 0:
logger.error(
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# prepare the accelerator
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# prepare dtypes for mixed precision and cast as appropriate
weight_dtype, save_dtype = train_util.prepare_dtype(args)
effnet_dtype = torch.float32 if args.no_half_vae else weight_dtype
# load the models
loading_device = accelerator.device if args.lowram else "cpu"
effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device)
stage_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, device=loading_device) # dtype is as it is
text_encoder1 = sc_utils.load_clip_text_model(args.text_model_checkpoint_path, dtype=weight_dtype, device=loading_device)
if args.sample_at_first or args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
# Previewer is small enough to be loaded on CPU
previewer = sc_utils.load_previewer_model(args.previewer_checkpoint_path, dtype=torch.float32, device="cpu")
previewer.eval()
else:
previewer = None
# enable xformers or memory-efficient attention in the model
stage_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
# prepare for training
if cache_latents:
effnet.to(accelerator.device, dtype=effnet_dtype)
effnet.requires_grad_(False)
effnet.eval()
with torch.no_grad():
train_dataset_group.cache_latents(
effnet,
args.vae_batch_size,
args.cache_latents_to_disk,
accelerator.is_main_process,
train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX,
32,
)
effnet.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
# prepare for training: put the models into the proper state
if args.gradient_checkpointing:
accelerator.print("enable gradient checkpointing")
stage_c.set_gradient_checkpointing(True)
train_stage_c = args.learning_rate > 0
train_text_encoder1 = False
if args.train_text_encoder:
accelerator.print("enable text encoder training")
if args.gradient_checkpointing:
text_encoder1.gradient_checkpointing_enable()
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
train_text_encoder1 = lr_te1 > 0
assert (
train_text_encoder1
), "text_encoder1 learning rate is 0. Please set a positive value / text_encoder1の学習率が0です。正の値を設定してください。"
if not train_text_encoder1:
text_encoder1.to(weight_dtype)
text_encoder1.requires_grad_(train_text_encoder1)
text_encoder1.train(train_text_encoder1)
else:
text_encoder1.to(weight_dtype)
text_encoder1.requires_grad_(False)
text_encoder1.eval()
# cache the text encoder outputs
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
with torch.no_grad(), accelerator.autocast():
train_dataset_group.cache_text_encoder_outputs(
(tokenizer,),
(text_encoder1,),
accelerator.device,
None,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
sc_utils.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
)
accelerator.wait_for_everyone()
if not cache_latents:
effnet.requires_grad_(False)
effnet.eval()
effnet.to(accelerator.device, dtype=effnet_dtype)
stage_c.requires_grad_(True)
if not train_stage_c:
stage_c.to(accelerator.device, dtype=weight_dtype)  # stage_c will not be prepared by accelerator, so move it manually
training_models = []
params_to_optimize = []
if train_stage_c:
training_models.append(stage_c)
params_to_optimize.append({"params": list(stage_c.parameters()), "lr": args.learning_rate})
if train_text_encoder1:
training_models.append(text_encoder1)
params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
# calculate number of trainable parameters
n_params = 0
for params in params_to_optimize:
for p in params["params"]:
n_params += p.numel()
accelerator.print(f"train stage-C: {train_stage_c}, text_encoder1: {train_text_encoder1}")
accelerator.print(f"number of models: {len(training_models)}")
accelerator.print(f"number of trainable parameters: {n_params}")
# prepare the classes needed for training
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
# prepare the dataloader
# a DataLoader with 0 worker processes runs in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, capped at the specified number
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# calculate the number of training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
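# e.g. (hypothetical numbers) 800 dataloader batches, 1 process, and
# gradient_accumulation_steps=4 give ceil(800 / 1 / 4) = 200 steps per epoch,
# so --max_train_epochs=10 sets max_train_steps = 2000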
# also tell the dataset side the number of training steps
train_dataset_group.set_max_train_steps(args.max_train_steps)
# prepare the lr scheduler
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# experimental feature: perform fp16/bf16 training including gradients; cast the entire model to fp16/bf16
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
stage_c.to(weight_dtype)
text_encoder1.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
stage_c.to(weight_dtype)
text_encoder1.to(weight_dtype)
# accelerator apparently takes care of things for us here
if train_stage_c:
stage_c = accelerator.prepare(stage_c)
if train_text_encoder1:
text_encoder1 = accelerator.prepare(text_encoder1)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
# move the Text Encoder to CPU when its outputs are cached
if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
# experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resume training if specified
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# calculate the number of epochs
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
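# e.g. (hypothetical numbers) num_train_epochs=10 with --save_n_epoch_ratio=3
# gives save_every_n_epochs = floor(10 / 3) = 3, i.e. about three saves per run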
# train
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
accelerator.print("running training / 学習開始")
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# accelerator.print(
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
# )
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
# the mysterious GDF class
gdf = sc.GDF(
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
input_scaler=sc.VPScaler(),
target=sc.EpsilonTarget(),
noise_cond=sc.CosineTNoiseCond(),
loss_weight=sc.AdaptiveLossWeight() if args.adaptive_loss_weight else sc.P2LossWeight(),
)
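# a hedged sketch of what this configuration implies for gdf.diffuse below
# (inferred from the component names, not a verified copy of the GDF code):
#   logSNR ~ CosineSchedule, with t clamped to [0.0001, 0.9999]
#   a, b = VPScaler(logSNR)  # variance preserving: a^2 + b^2 = 1
#   noised = a * latents + b * noise
#   target = noise           # EpsilonTarget: the model predicts epsilon
#   loss_weight = P2LossWeight(logSNR), or an AdaptiveLossWeight bucket lookup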
# the following two variables appear to be left at their defaults
# gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
# gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
# For --sample_at_first
sc_utils.sample_images(accelerator, args, 0, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
for m in training_models:
m.train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
with torch.no_grad():
# convert to latents
# XXX Effnet preprocessing is included in encode method
latents = effnet.encode(batch["images"].to(effnet_dtype)).latent_dist.sample().to(weight_dtype)
# if NaNs are present, warn and replace them with zeros
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
# # debug: decode latent with previewer and save it
# import time
# import numpy as np
# from PIL import Image
# ts = time.time()
# images = previewer(latents.to(previewer.device, dtype=previewer.dtype))
# for i, img in enumerate(images):
# img = img.detach().cpu().numpy().transpose(1, 2, 0)
# img = np.clip(img, 0, 1)
# img = (img * 255).astype(np.uint8)
# img = Image.fromarray(img)
# img.save(f"logs/previewer_{i}_{ts}.png")
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
input_ids1 = batch["input_ids"]
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
# TODO support weighted captions
input_ids1 = input_ids1.to(accelerator.device)
# unwrap_model is fine for models not wrapped by accelerator
encoder_hidden_states, pool = train_util.get_hidden_states_stable_cascade(
args.max_token_length,
input_ids1,
tokenizer,
text_encoder1,
None if not args.full_fp16 else weight_dtype,
accelerator,
)
else:
encoder_hidden_states = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
pool = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
pool = pool.unsqueeze(1) # add extra dimension b,1280 -> b,1,1280
# FORWARD PASS
with torch.no_grad():
noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(latents, shift=1, loss_shift=1)
zero_img_emb = torch.zeros(noised.shape[0], 768, device=accelerator.device)
with accelerator.autocast():
pred = stage_c(
noised, noise_cond, clip_text=encoder_hidden_states, clip_text_pooled=pool, clip_img=zero_img_emb
)
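# per-sample MSE averaged over (C, H, W); loss_weight (P2 or adaptive, derived
# from the sampled logSNR) rebalances the noise levels before taking the mean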
loss = torch.nn.functional.mse_loss(pred, target, reduction="none").mean(dim=[1, 2, 3])
loss_adjusted = (loss * loss_weight).mean()
if args.adaptive_loss_weight:
gdf.loss_weight.update_buckets(logSNR, loss) # use loss instead of loss_adjusted
accelerator.backward(loss_adjusted)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
sc_utils.sample_images(accelerator, args, None, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)
# save the model every specified number of steps
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
sc_utils.save_stage_c_model_on_epoch_end_or_stepwise(
args,
False,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(stage_c),
accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
)
current_loss = loss_adjusted.detach().item()  # this is already a mean, so batch size shouldn't matter
if args.logging_dir is not None:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
sc_utils.save_stage_c_model_on_epoch_end_or_stepwise(
args,
True,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(stage_c),
accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
)
sc_utils.sample_images(accelerator, args, epoch + 1, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)
is_main_process = accelerator.is_main_process
# if is_main_process:
stage_c = accelerator.unwrap_model(stage_c)
text_encoder1 = accelerator.unwrap_model(text_encoder1)
accelerator.end_training()
if args.save_state: # and is_main_process:
train_util.save_state_on_train_end(args, accelerator)
del accelerator  # delete this because memory is needed afterwards
if is_main_process:
sc_utils.save_stage_c_model_on_end(
args, save_dtype, epoch, global_step, stage_c, text_encoder1 if train_text_encoder1 else None
)
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
sc_utils.add_effnet_arguments(parser)
sc_utils.add_stage_c_arguments(parser)
sc_utils.add_text_model_arguments(parser)
sc_utils.add_previewer_arguments(parser)
sc_utils.add_training_arguments(parser)
train_util.add_tokenizer_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
add_sdxl_training_arguments(parser) # cache text encoder outputs
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
parser.add_argument(
"--learning_rate_te1",
type=float,
default=None,
help="learning rate for text encoder / text encoderの学習率",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 Effnet in mixed precision (use float Effnet) / mixed precisionでも fp16/bf16 Effnetを使わずfloat Effnetを使う",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)
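# a minimal example invocation (hypothetical paths/values; the checkpoint flags
# are added by sc_utils.add_effnet/stage_c/text_model_arguments, so see --help
# for the exact names in your checkout):
#   accelerate launch stable_cascade_train_stage_c.py \
#     --effnet_checkpoint_path effnet_encoder.safetensors \
#     --dataset_config dataset.toml --sdpa \
#     --mixed_precision bf16 --full_bf16 --learning_rate 1e-5 \
#     --max_train_epochs 10 --save_every_n_epochs 1 \
#     --gradient_checkpointing --cache_latents_to_disk --adaptive_loss_weight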


@@ -0,0 +1,191 @@
# cache latents of Stable Cascade to disk
import argparse
import math
from multiprocessing import Value
import os
from accelerate.utils import set_seed
import torch
from tqdm import tqdm
from library import stable_cascade_utils as sc_utils
from library import config_util
from library import train_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
train_util.prepare_dataset_args(args, True)
# check cache latents arg
assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed)  # initialize the random seed
# prepare the tokenizer (needed to run the dataset)
tokenizer = sc_utils.load_tokenizer(args)
tokenizers = [tokenizer]
# prepare the dataset
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
logger.info("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
# if the dataset's cache_latents is not called, raw images are returned
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# prepare the accelerator
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# prepare a dtype matching the mixed precision setting and cast as appropriate
weight_dtype, _ = train_util.prepare_dtype(args)
effnet_dtype = torch.float32 if args.no_half_vae else weight_dtype
# load the model
logger.info("load model")
effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, accelerator.device)
effnet.to(accelerator.device, dtype=effnet_dtype)
effnet.requires_grad_(False)
effnet.eval()
# prepare the dataloader
train_dataset_group.set_caching_mode("latents")
# a DataLoader worker count of 0 means data loading runs in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, capped at the specified value
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# prepare the dataloader with accelerator; this should enable multi-GPU use
train_dataloader = accelerator.prepare(train_dataloader)
# loop over the data
for batch in tqdm(train_dataloader):
b_size = len(batch["images"])
vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
flip_aug = batch["flip_aug"]
random_crop = batch["random_crop"]
bucket_reso = batch["bucket_reso"]
# process the batch in chunks
for i in range(0, b_size, vae_batch_size):
images = batch["images"][i : i + vae_batch_size]
absolute_paths = batch["absolute_paths"][i : i + vae_batch_size]
resized_sizes = batch["resized_sizes"][i : i + vae_batch_size]
image_infos = []
for image, absolute_path, resized_size in zip(images, absolute_paths, resized_sizes):  # don't shadow the outer chunk index i
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
image_info.image = image
image_info.bucket_reso = bucket_reso
image_info.resized_size = resized_size
image_info.latents_npz = os.path.splitext(absolute_path)[0] + train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX
if args.skip_existing:
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug, 32):
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
continue
image_infos.append(image_info)
if len(image_infos) > 0:
train_util.cache_batch_latents(effnet, True, image_infos, flip_aug, random_crop)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_tokenizer_arguments(parser)
sc_utils.add_effnet_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
config_util.add_config_arguments(parser)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 Effnet in mixed precision (use float Effnet) / mixed precisionでも fp16/bf16 Effnetを使わずfloat Effnetを使う",
)
parser.add_argument(
"--skip_existing",
action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
cache_to_disk(args)
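# a minimal example invocation (hypothetical paths; the script's filename is
# not visible in this diff, and the flags shown are the ones parsed above):
#   accelerate launch <script>.py --effnet_checkpoint_path effnet_encoder.safetensors \
#     --dataset_config dataset.toml --cache_latents_to_disk \
#     --mixed_precision bf16 --skip_existing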


@@ -0,0 +1,183 @@
# cache text encoder outputs to disk in advance
import argparse
import math
from multiprocessing import Value
import os
from accelerate.utils import set_seed
import torch
from tqdm import tqdm
from library import config_util
from library import train_util
from library import sdxl_train_util
from library import stable_cascade_utils as sc_utils
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
train_util.prepare_dataset_args(args, True)
# check cache arg
assert (
args.cache_text_encoder_outputs_to_disk
), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります"
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed)  # initialize the random seed
# prepare the tokenizer (needed to run the dataset)
tokenizer = sc_utils.load_tokenizer(args)
tokenizers = [tokenizer]
# prepare the dataset
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
logger.info("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# prepare the accelerator
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# prepare a dtype matching the mixed precision setting and cast as appropriate
weight_dtype, _ = train_util.prepare_dtype(args)
# load the model
logger.info("load model")
text_encoder = sc_utils.load_clip_text_model(
args.text_model_checkpoint_path, weight_dtype, accelerator.device, args.save_text_model
)
text_encoders = [text_encoder]
for text_encoder in text_encoders:
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
# prepare the dataloader
train_dataset_group.set_caching_mode("text")
# a DataLoader worker count of 0 means data loading runs in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, capped at the specified value
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# prepare the dataloader with accelerator; this should enable multi-GPU use
train_dataloader = accelerator.prepare(train_dataloader)
# loop over the data
for batch in tqdm(train_dataloader):
absolute_paths = batch["absolute_paths"]
input_ids1_list = batch["input_ids1_list"]
image_infos = []
for absolute_path, input_ids1 in zip(absolute_paths, input_ids1_list):
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + sc_utils.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
if args.skip_existing:
if os.path.exists(image_info.text_encoder_outputs_npz):
logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
continue
image_info.input_ids1 = input_ids1
image_infos.append(image_info)
if len(image_infos) > 0:
b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
train_util.cache_batch_text_encoder_outputs(
image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, None, weight_dtype
)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_tokenizer_arguments(parser)
sc_utils.add_text_model_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
config_util.add_config_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument(
"--skip_existing",
action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
cache_to_disk(args)
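# a minimal example invocation (hypothetical paths; the script's filename is
# not visible in this diff, and the flags shown are the ones parsed above):
#   accelerate launch <script>.py --text_model_checkpoint_path text_model.safetensors \
#     --dataset_config dataset.toml --cache_text_encoder_outputs_to_disk \
#     --mixed_precision bf16 --skip_existing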