diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index b6865dbf..0149dcdd 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -18,4 +18,4 @@ jobs: - uses: actions/checkout@v4 - name: typos-action - uses: crate-ci/typos@v1.16.15 + uses: crate-ci/typos@v1.24.3 diff --git a/README-ja.md b/README-ja.md index 29c33a65..27cc56c3 100644 --- a/README-ja.md +++ b/README-ja.md @@ -1,12 +1,12 @@ -SDXLがサポートされました。sdxlブランチはmainブランチにマージされました。リポジトリを更新したときにはUpgradeの手順を実行してください。また accelerate のバージョンが上がっていますので、accelerate config を再度実行してください。 - -SDXL学習については[こちら](./README.md#sdxl-training)をご覧ください(英語です)。 - ## リポジトリについて Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。 [README in English](./README.md) ←更新情報はこちらにあります +開発中のバージョンはdevブランチにあります。最新の変更点はdevブランチをご確認ください。 + +FLUX.1およびSD3/SD3.5対応はsd3ブランチで行っています。それらの学習を行う場合はsd3ブランチをご利用ください。 + GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。 以下のスクリプトがあります。 @@ -21,6 +21,7 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma * [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど * [データセット設定](./docs/config_README-ja.md) +* [SDXL学習](./docs/train_SDXL-en.md) (英語版) * [DreamBoothの学習について](./docs/train_db_README-ja.md) * [fine-tuningのガイド](./docs/fine_tune_README_ja.md): * [LoRAの学習について](./docs/train_network_README-ja.md) @@ -44,9 +45,7 @@ PowerShellを使う場合、venvを使えるようにするためには以下の ## Windows環境でのインストール -スクリプトはPyTorch 2.0.1でテストしています。PyTorch 1.12.1でも動作すると思われます。 - -以下の例ではPyTorchは2.0.1/CUDA 11.8版をインストールします。CUDA 11.6版やPyTorch 1.12.1を使う場合は適宜書き換えください。 +スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.0.1、1.12.1でも動作すると思われます。 (なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。) @@ -59,21 +58,21 @@ cd sd-scripts python -m venv venv .\venv\Scripts\activate -pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 +pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118 pip install --upgrade -r requirements.txt -pip install xformers==0.0.20 +pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118 accelerate config ``` コマンドプロンプトでも同一です。 -(注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。) +注:`bitsandbytes==0.43.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` は `requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。 + +この例では PyTorch および xfomers は2.1.2/CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121` としてください。 accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。) -※0.15.0から日本語環境では選択のためにカーソルキーを押すと落ちます(……)。数字キーの0、1、2……で選択できますので、そちらを使ってください。 - ```txt - This machine - No distributed training @@ -87,41 +86,6 @@ accelerate configの質問には以下のように答えてください。(bf1 ※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問( ``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? 
[all]:``)に「0」と答えてください。(id `0`のGPUが使われます。)

-### オプション:`bitsandbytes`(8bit optimizer)を使う
-
-`bitsandbytes`はオプションになりました。Linuxでは通常通りpipでインストールできます(0.41.1または以降のバージョンを推奨)。
-
-Windowsでは0.35.0または0.41.1を推奨します。
-
-- `bitsandbytes` 0.35.0: 安定しているとみられるバージョンです。AdamW8bitは使用できますが、他のいくつかの8bit optimizer、学習時の`full_bf16`オプションは使用できません。
-- `bitsandbytes` 0.41.1: Lion8bit、PagedAdamW8bit、PagedLion8bitをサポートします。`full_bf16`が使用できます。
-
-注:`bitsandbytes` 0.35.0から0.41.0までのバージョンには問題があるようです。 https://github.com/TimDettmers/bitsandbytes/issues/659
-
-以下の手順に従い、`bitsandbytes`をインストールしてください。
-
-### 0.35.0を使う場合
-
-PowerShellの例です。コマンドプロンプトではcpの代わりにcopyを使ってください。
-
-```powershell
-cd sd-scripts
-.\venv\Scripts\activate
-pip install bitsandbytes==0.35.0
-
-cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
-cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
-cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
-```
-
-### 0.41.1を使う場合
-
-jllllll氏の配布されている[こちら](https://github.com/jllllll/bitsandbytes-windows-webui) または他の場所から、Windows用のwhlファイルをインストールしてください。
-
-```powershell
-python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
-```
-
## アップグレード

新しいリリースがあった場合、以下のコマンドで更新できます。

@@ -151,4 +115,47 @@ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora)

[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause

+## その他の情報
+### LoRAの名称について
+
+`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
+
+1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
+
+    Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
+
+2. __LoRA-C3Lier__ : (LoRA for __C__ onvolutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
+
+    1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
+
+デフォルトではLoRA-LierLaが使われます。LoRA-C3Lierを使う場合は `--network_args` に `conv_dim` を指定してください。
+
+
+
+### 学習中のサンプル画像生成
+
+プロンプトファイルは例えば以下のようになります。
+
+```
+# prompt 1
+masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
+
+# prompt 2
+masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
+```
+
+ `#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能です。
+
+ * `--n` Negative prompt up to the next option.
+ * `--w` Specifies the width of the generated image.
+ * `--h` Specifies the height of the generated image.
+ * `--d` Specifies the seed of the generated image.
+ * `--l` Specifies the CFG scale of the generated image.
+ * `--s` Specifies the number of steps in the generation.
+
+ `( )` や `[ ]` などの重みづけも動作します。
diff --git a/README.md b/README.md
index 0edaca25..51ac07a1 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,3 @@
-__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
-
This repository contains training, generation and utility scripts for Stable Diffusion.

[__Change History__](#change-history) is moved to the bottom of the page.
@@ -7,6 +5,11 @@ This repository contains training, generation and utility scripts for Stable Dif

[日本語版READMEはこちら](./README-ja.md)

+The development version is in the `dev` branch. Please check the dev branch for the latest changes.
+
+FLUX.1 and SD3/SD3.5 support is being developed in the `sd3` branch. If you want to train them, please use the sd3 branch.
+
+
For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!

This repository contains the scripts for:

@@ -20,9 +23,9 @@ This repository contains the scripts for:

## About requirements.txt

-These files do not contain requirements for PyTorch. Because the versions of them depend on your environment. Please install PyTorch at first (see installation guide below.)
+The file does not contain the requirements for PyTorch, because the required version depends on your environment. Please install PyTorch first according to your environment (see the installation instructions below).

-The scripts are tested with Pytorch 2.0.1. 1.12.1 is not tested but should work.
+The scripts are tested with PyTorch 2.1.2. PyTorch 2.0.1 and 1.12.1 are not tested but should work.

## Links to usage documentation

@@ -32,11 +35,13 @@ Most of the documents are written in Japanese.

* [Training guide - common](./docs/train_README-ja.md) : data preparation, options etc...
* [Chinese version](./docs/train_README-zh.md)
+* [SDXL training](./docs/train_SDXL-en.md) (English version)
* [Dataset config](./docs/config_README-ja.md)
+  * [English version](./docs/config_README-en.md)
* [DreamBooth training guide](./docs/train_db_README-ja.md)
* [Step by Step fine-tuning guide](./docs/fine_tune_README_ja.md):
-* [training LoRA](./docs/train_network_README-ja.md)
-* [training Textual Inversion](./docs/train_ti_README-ja.md)
+* [Training LoRA](./docs/train_network_README-ja.md)
+* [Training Textual Inversion](./docs/train_ti_README-ja.md)
* [Image generation](./docs/gen_img_README-ja.md)
* note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)

@@ -64,14 +69,18 @@ cd sd-scripts
python -m venv venv
.\venv\Scripts\activate

-pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
+pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
pip install --upgrade -r requirements.txt
-pip install xformers==0.0.20
+pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118

accelerate config
```

-__Note:__ Now bitsandbytes is optional. Please install any version of bitsandbytes as needed. Installation instructions are in the following section.
+If `python -m venv` only shows `python`, change `python` to `py`.
+
+__Note:__ Now `bitsandbytes==0.43.0`, `prodigyopt==1.0` and `lion-pytorch==0.0.6` are included in requirements.txt. If you'd like to use another version, please install it manually.
+
+This installation is for CUDA 11.8. If you use a different version of CUDA, please install the appropriate versions of PyTorch and xformers. For example, for CUDA 12.1, run `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` and `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121`.

-### LoRAの名称について
-
-`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
-
-1. 
__LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます) - - Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA - -2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます) - - 1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA - -LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additional-networks)、またはAUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。 - -LoRA-C3Lierを使いWeb UIで生成するには拡張を使用してください。 - -## Sample image generation during training +### Sample image generation during training A prompt file might look like this, for example ``` @@ -340,26 +392,3 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b * `--s` Specifies the number of steps in the generation. The prompt weighting such as `( )` and `[ ]` are working. - -## サンプル画像生成 -プロンプトファイルは例えば以下のようになります。 - -``` -# prompt 1 -masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28 - -# prompt 2 -masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40 -``` - - `#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。 - - * `--n` Negative prompt up to the next option. - * `--w` Specifies the width of the generated image. - * `--h` Specifies the height of the generated image. - * `--d` Specifies the seed of the generated image. - * `--l` Specifies the CFG scale of the generated image. - * `--s` Specifies the number of steps in the generation. - - `( )` や `[ ]` などの重みづけも動作します。 - diff --git a/XTI_hijack.py b/XTI_hijack.py index ec084945..93bc1c0b 100644 --- a/XTI_hijack.py +++ b/XTI_hijack.py @@ -1,11 +1,7 @@ import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.device_utils import init_ipex +init_ipex() + from typing import Union, List, Optional, Dict, Any, Tuple from diffusers.models.unet_2d_condition import UNet2DConditionOutput diff --git a/docs/config_README-en.md b/docs/config_README-en.md new file mode 100644 index 00000000..83bea329 --- /dev/null +++ b/docs/config_README-en.md @@ -0,0 +1,384 @@ +Original Source by kohya-ss + +First version: +A.I Translation by Model: NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO, editing by Darkstorm2150 + +Some parts are manually added. + +# Config Readme + +This README is about the configuration files that can be passed with the `--dataset_config` option. + +## Overview + +By passing a configuration file, users can make detailed settings. + +* Multiple datasets can be configured + * For example, by setting `resolution` for each dataset, they can be mixed and trained. + * In training methods that support both the DreamBooth approach and the fine-tuning approach, datasets of the DreamBooth method and the fine-tuning method can be mixed. +* Settings can be changed for each subset + * A subset is a partition of the dataset by image directory or metadata. Several subsets make up a dataset. + * Options such as `keep_tokens` and `flip_aug` can be set for each subset. On the other hand, options such as `resolution` and `batch_size` can be set for each dataset, and their values are common among subsets belonging to the same dataset. More details will be provided later. 
+ +The configuration file format can be JSON or TOML. Considering the ease of writing, it is recommended to use [TOML](https://toml.io/ja/v1.0.0-rc.2). The following explanation assumes the use of TOML. + + +Here is an example of a configuration file written in TOML. + +```toml +[general] +shuffle_caption = true +caption_extension = '.txt' +keep_tokens = 1 + +# This is a DreamBooth-style dataset +[[datasets]] +resolution = 512 +batch_size = 4 +keep_tokens = 2 + + [[datasets.subsets]] + image_dir = 'C:\hoge' + class_tokens = 'hoge girl' + # This subset uses keep_tokens = 2 (the value of the parent datasets) + + [[datasets.subsets]] + image_dir = 'C:\fuga' + class_tokens = 'fuga boy' + keep_tokens = 3 + + [[datasets.subsets]] + is_reg = true + image_dir = 'C:\reg' + class_tokens = 'human' + keep_tokens = 1 + +# This is a fine-tuning dataset +[[datasets]] +resolution = [768, 768] +batch_size = 2 + + [[datasets.subsets]] + image_dir = 'C:\piyo' + metadata_file = 'C:\piyo\piyo_md.json' + # This subset uses keep_tokens = 1 (the value of [general]) +``` + +In this example, three directories are trained as a DreamBooth-style dataset at 512x512 (batch size 4), and one directory is trained as a fine-tuning dataset at 768x768 (batch size 2). + +## Settings for datasets and subsets + +Settings for datasets and subsets are divided into several registration locations. + +* `[general]` + * This is where options that apply to all datasets or all subsets are specified. + * If there are options with the same name in the dataset-specific or subset-specific settings, the dataset-specific or subset-specific settings take precedence. +* `[[datasets]]` + * `datasets` is where settings for datasets are registered. This is where options that apply individually to each dataset are specified. + * If there are subset-specific settings, the subset-specific settings take precedence. +* `[[datasets.subsets]]` + * `datasets.subsets` is where settings for subsets are registered. This is where options that apply individually to each subset are specified. + +Here is an image showing the correspondence between image directories and registration locations in the previous example. + +``` +C:\ +├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐ +├─ fuga -> [[datasets.subsets]] No.2 |-> [[datasets]] No.1 |-> [general] +├─ reg -> [[datasets.subsets]] No.3 ┘ | +└─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘ +``` + +The image directory corresponds to each `[[datasets.subsets]]`. Then, multiple `[[datasets.subsets]]` are combined to form one `[[datasets]]`. All `[[datasets]]` and `[[datasets.subsets]]` belong to `[general]`. + +The available options for each registration location may differ, but if the same option is specified, the value in the lower registration location will take precedence. You can check how the `keep_tokens` option is handled in the previous example for better understanding. + +Additionally, the available options may vary depending on the method that the learning approach supports. + +* Options specific to the DreamBooth method +* Options specific to the fine-tuning method +* Options available when using the caption dropout technique + +When using both the DreamBooth method and the fine-tuning method, they can be used together with a learning approach that supports both. +When using them together, a point to note is that the method is determined based on the dataset, so it is not possible to mix DreamBooth method subsets and fine-tuning method subsets within the same dataset. 
+In other words, if you want to use both methods together, you need to set up subsets of different methods belonging to different datasets. + +In terms of program behavior, if the `metadata_file` option exists, it is determined to be a subset of fine-tuning. Therefore, for subsets belonging to the same dataset, as long as they are either "all have the `metadata_file` option" or "all have no `metadata_file` option," there is no problem. + +Below, the available options will be explained. For options with the same name as the command-line argument, the explanation will be omitted in principle. Please refer to other READMEs. + +### Common options for all learning methods + +These are options that can be specified regardless of the learning method. + +#### Data set specific options + +These are options related to the configuration of the data set. They cannot be described in `datasets.subsets`. + + +| Option Name | Example Setting | `[general]` | `[[datasets]]` | +| ---- | ---- | ---- | ---- | +| `batch_size` | `1` | o | o | +| `bucket_no_upscale` | `true` | o | o | +| `bucket_reso_steps` | `64` | o | o | +| `enable_bucket` | `true` | o | o | +| `max_bucket_reso` | `1024` | o | o | +| `min_bucket_reso` | `128` | o | o | +| `resolution` | `256`, `[512, 512]` | o | o | + +* `batch_size` + * This corresponds to the command-line argument `--train_batch_size`. + +These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each. + +#### Options for Subsets + +These options are related to subset configuration. + +| Option Name | Example | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `color_aug` | `false` | o | o | o | +| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o | +| `flip_aug` | `true` | o | o | o | +| `keep_tokens` | `2` | o | o | o | +| `num_repeats` | `10` | o | o | o | +| `random_crop` | `false` | o | o | o | +| `shuffle_caption` | `true` | o | o | o | +| `caption_prefix` | `"masterpiece, best quality, "` | o | o | o | +| `caption_suffix` | `", from side"` | o | o | o | +| `caption_separator` | (not specified) | o | o | o | +| `keep_tokens_separator` | `“|||”` | o | o | o | +| `secondary_separator` | `“;;;”` | o | o | o | +| `enable_wildcard` | `true` | o | o | o | + +* `num_repeats` + * Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method. +* `caption_prefix`, `caption_suffix` + * Specifies the prefix and suffix strings to be appended to the captions. Shuffling is performed with these strings included. Be cautious when using `keep_tokens`. +* `caption_separator` + * Specifies the string to separate the tags. The default is `,`. This option is usually not necessary to set. +* `keep_tokens_separator` + * Specifies the string to separate the parts to be fixed in the caption. For example, if you specify `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh`, the parts `aaa, bbb` and `ggg, hhh` will remain, and the rest will be shuffled and dropped. The comma in between is not necessary. As a result, the prompt will be `aaa, bbb, eee, ccc, fff, ggg, hhh` or `aaa, bbb, fff, ccc, eee, ggg, hhh`, etc. +* `secondary_separator` + * Specifies an additional separator. 
The part separated by this separator is treated as a single tag during shuffling and dropping, and the separator is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, the part is replaced by `aaa,bbb,ccc` or dropped as a whole.
+* `enable_wildcard`
+  * Enables wildcard notation. This will be explained later.
+
+### DreamBooth-specific options
+
+DreamBooth-specific options exist only as subset-specific options.
+
+#### Subset-specific options
+
+Options related to the configuration of DreamBooth subsets.
+
+| Option Name | Example Setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
+| ---- | ---- | ---- | ---- | ---- |
+| `image_dir` | `'C:\hoge'` | - | - | o (required) |
+| `caption_extension` | `".txt"` | o | o | o |
+| `class_tokens` | `"sks girl"` | - | - | o |
+| `cache_info` | `false` | o | o | o |
+| `is_reg` | `false` | - | - | o |
+
+First, note that `image_dir` must point to a directory that contains the image files directly. In the previous DreamBooth method, images had to be placed in subdirectories; this specification is not compatible with that. Also, even if you name the folder something like `5_cat`, the repeat count and class name encoded in the folder name will not be applied. If you want to set these individually, specify them explicitly with `num_repeats` and `class_tokens`.
+
+* `image_dir`
+  * Specifies the path to the image directory. This is a required option.
+  * Images must be placed directly under the directory.
+* `class_tokens`
+  * Sets the class tokens.
+  * Only used during training when a corresponding caption file does not exist. Whether to use it is determined on a per-image basis. If `class_tokens` is not specified and no caption file is found either, an error will occur.
+* `cache_info`
+  * Specifies whether to cache the image sizes and captions. If not specified, it defaults to `false`. The cache is saved as `metadata_cache.json` in `image_dir`.
+  * Caching speeds up the loading of the dataset after the first run. It is effective when dealing with thousands of images or more.
+* `is_reg`
+  * Specifies whether the images in the subset are regularization images. If not specified, it defaults to `false`, meaning the images are not treated as regularization images.
+
+### Fine-tuning method specific options
+
+Fine-tuning method options also exist only as subset-specific options.
+
+#### Subset-specific options
+
+These options are related to the configuration of the fine-tuning method's subsets.
+
+| Option name | Example setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
+| ---- | ---- | ---- | ---- | ---- |
+| `image_dir` | `'C:\hoge'` | - | - | o |
+| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o (required) |
+
+* `image_dir`
+  * Specify the path to the image directory. Unlike the DreamBooth method, it is not mandatory, but specifying it is recommended.
+  * It does not need to be specified if `--full_path` was added to the command line when generating the metadata file.
+  * The images must be placed directly under the directory.
+* `metadata_file`
+  * Specify the path to the metadata file used for the subset. This is a required option.
+  * It is equivalent to the command-line argument `--in_json`.
+  * Because a metadata file must be specified for each subset, it is recommended to avoid creating one metadata file that mixes images from different directories. 
It is strongly recommended to prepare a separate metadata file for each image directory and register them as separate subsets. + +### Options available when caption dropout method can be used + +The options available when the caption dropout method can be used exist only for subsets. Regardless of whether it's the DreamBooth method or fine-tuning method, if it supports caption dropout, it can be specified. + +#### Subset-specific options + +Options related to the setting of subsets that caption dropout can be used for. + +| Option Name | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | +| `caption_dropout_every_n_epochs` | o | o | o | +| `caption_dropout_rate` | o | o | o | +| `caption_tag_dropout_rate` | o | o | o | + +## Behavior when there are duplicate subsets + +In the case of the DreamBooth dataset, if there are multiple `image_dir` directories with the same content, they are considered to be duplicate subsets. For the fine-tuning dataset, if there are multiple `metadata_file` files with the same content, they are considered to be duplicate subsets. If duplicate subsets exist in the dataset, subsequent subsets will be ignored. + +However, if they belong to different datasets, they are not considered duplicates. For example, if you have subsets with the same `image_dir` in different datasets, they will not be considered duplicates. This is useful when you want to train with the same image but with different resolutions. + +```toml +# If data sets exist separately, they are not considered duplicates and are both used for training. + +[[datasets]] +resolution = 512 + + [[datasets.subsets]] + image_dir = 'C:\hoge' + +[[datasets]] +resolution = 768 + + [[datasets.subsets]] + image_dir = 'C:\hoge' +``` + +## Command Line Argument and Configuration File + +There are options in the configuration file that have overlapping roles with command line argument options. + +The following command line argument options are ignored if a configuration file is passed: + +* `--train_data_dir` +* `--reg_data_dir` +* `--in_json` + +The following command line argument options are given priority over the configuration file options if both are specified simultaneously. In most cases, they have the same names as the corresponding options in the configuration file. + +| Command Line Argument Option | Prioritized Configuration File Option | +| ------------------------------- | ------------------------------------- | +| `--bucket_no_upscale` | | +| `--bucket_reso_steps` | | +| `--caption_dropout_every_n_epochs` | | +| `--caption_dropout_rate` | | +| `--caption_extension` | | +| `--caption_tag_dropout_rate` | | +| `--color_aug` | | +| `--dataset_repeats` | `num_repeats` | +| `--enable_bucket` | | +| `--face_crop_aug_range` | | +| `--flip_aug` | | +| `--keep_tokens` | | +| `--min_bucket_reso` | | +| `--random_crop` | | +| `--resolution` | | +| `--shuffle_caption` | | +| `--train_batch_size` | `batch_size` | + +## Error Guide + +Currently, we are using an external library to check if the configuration file is written correctly, but the development has not been completed, and there is a problem that the error message is not clear. In the future, we plan to improve this problem. + +As a temporary measure, we will list common errors and their solutions. If you encounter an error even though it should be correct or if the error content is not understandable, please contact us as it may be a bug. 
+ +* `voluptuous.error.MultipleInvalid: required key not provided @ ...`: This error occurs when a required option is not provided. It is highly likely that you forgot to specify the option or misspelled the option name. + * The error location is indicated by `...` in the error message. For example, if you encounter an error like `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']`, it means that the `image_dir` option does not exist in the 0th `subsets` of the 0th `datasets` setting. +* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: This error occurs when the specified value format is incorrect. It is highly likely that the value format is incorrect. The `int` part changes depending on the target option. The example configurations in this README may be helpful. +* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: This error occurs when there is an option name that is not supported. It is highly likely that you misspelled the option name or mistakenly included it. + +## Miscellaneous + +### Multi-line captions + +By setting `enable_wildcard = true`, multiple-line captions are also enabled. If the caption file consists of multiple lines, one line is randomly selected as the caption. + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage +a girl with a microphone standing on a stage +detailed digital art of a girl with a microphone on a stage +``` + +It can be combined with wildcard notation. + +In metadata files, you can also specify multiple-line captions. In the `.json` metadata file, use `\n` to represent a line break. If the caption file consists of multiple lines, `merge_captions_to_metadata.py` will create a metadata file in this format. + +The tags in the metadata (`tags`) are added to each line of the caption. + +```json +{ + "/path/to/image.png": { + "caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2", + "tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus" + }, + ... +} +``` + +In this case, the actual caption will be `a cartoon of a frog with the word frog on it, open mouth, simple background ...`, `test multiline caption1, open mouth, simple background ...`, `test multiline caption2, open mouth, simple background ...`, etc. + +### Example of configuration file : `secondary_separator`, wildcard notation, `keep_tokens_separator`, etc. 
+ +```toml +[general] +flip_aug = true +color_aug = false +resolution = [1024, 1024] + +[[datasets]] +batch_size = 6 +enable_bucket = true +bucket_no_upscale = true +caption_extension = ".txt" +keep_tokens_separator= "|||" +shuffle_caption = true +caption_tag_dropout_rate = 0.1 +secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side +enable_wildcard = true # 同上 / same as above + + [[datasets.subsets]] + image_dir = "/path/to/image_dir" + num_repeats = 1 + + # ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically) + caption_prefix = "1girl, hatsune miku, vocaloid |||" + + # ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains + # 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself + caption_suffix = ", anime screencap ||| masterpiece, rating: general" +``` + +### Example of caption, secondary_separator notation: `secondary_separator = ";;;"` + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors +``` +The part `sky;;;cloud;;;day` is replaced with `sky,cloud,day` without shuffling or dropping. When shuffling and dropping are enabled, it is processed as a whole (as one tag). For example, it becomes `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (shuffled) or `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (dropped). + +### Example of caption, enable_wildcard notation: `enable_wildcard = true` + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background +``` +`simple` or `white` is randomly selected, and it becomes `simple background` or `white background`. + +```txt +1girl, hatsune miku, vocaloid, {{retro style}} +``` +If you want to include `{` or `}` in the tag string, double them like `{{` or `}}` (in this example, the actual caption used for training is `{retro style}`). + +### Example of caption, `keep_tokens_separator` notation: `keep_tokens_separator = "|||"` + +```txt +1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general +``` +It becomes `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` or `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` etc. + diff --git a/docs/config_README-ja.md b/docs/config_README-ja.md index 69a03f6c..cc74c341 100644 --- a/docs/config_README-ja.md +++ b/docs/config_README-ja.md @@ -1,5 +1,3 @@ -For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future. 
- `--dataset_config` で渡すことができる設定ファイルに関する説明です。 ## 概要 @@ -140,12 +138,28 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 | `shuffle_caption` | `true` | o | o | o | | `caption_prefix` | `“masterpiece, best quality, ”` | o | o | o | | `caption_suffix` | `“, from side”` | o | o | o | +| `caption_separator` | (通常は設定しません) | o | o | o | +| `keep_tokens_separator` | `“|||”` | o | o | o | +| `secondary_separator` | `“;;;”` | o | o | o | +| `enable_wildcard` | `true` | o | o | o | * `num_repeats` * サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。 * `caption_prefix`, `caption_suffix` * キャプションの前、後に付与する文字列を指定します。シャッフルはこれらの文字列を含めた状態で行われます。`keep_tokens` を指定する場合には注意してください。 +* `caption_separator` + * タグを区切る文字列を指定します。デフォルトは `,` です。このオプションは通常は設定する必要はありません。 + +* `keep_tokens_separator` + * キャプションで固定したい部分を区切る文字列を指定します。たとえば `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh` のように指定すると、`aaa, bbb` と `ggg, hhh` の部分はシャッフル、drop されず残ります。間のカンマは不要です。結果としてプロンプトは `aaa, bbb, eee, ccc, fff, ggg, hhh` や `aaa, bbb, fff, ccc, eee, ggg, hhh` などになります。 + +* `secondary_separator` + * 追加の区切り文字を指定します。この区切り文字で区切られた部分は一つのタグとして扱われ、シャッフル、drop されます。その後、`caption_separator` に置き換えられます。たとえば `aaa;;;bbb;;;ccc` のように指定すると、`aaa,bbb,ccc` に置き換えられるか、まとめて drop されます。 + +* `enable_wildcard` + * ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。 + ### DreamBooth 方式専用のオプション DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。 @@ -159,6 +173,7 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。 | `image_dir` | `‘C:\hoge’` | - | - | o(必須) | | `caption_extension` | `".txt"` | o | o | o | | `class_tokens` | `“sks girl”` | - | - | o | +| `cache_info` | `false` | o | o | o | | `is_reg` | `false` | - | - | o | まず注意点として、 `image_dir` には画像ファイルが直下に置かれているパスを指定する必要があります。従来の DreamBooth の手法ではサブディレクトリに画像を置く必要がありましたが、そちらとは仕様に互換性がありません。また、`5_cat` のようなフォルダ名にしても、画像の繰り返し回数とクラス名は反映されません。これらを個別に設定したい場合、`num_repeats` と `class_tokens` で明示的に指定する必要があることに注意してください。 @@ -169,6 +184,9 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。 * `class_tokens` * クラストークンを設定します。 * 画像に対応する caption ファイルが存在しない場合にのみ学習時に利用されます。利用するかどうかの判定は画像ごとに行います。`class_tokens` を指定しなかった場合に caption ファイルも見つからなかった場合にはエラーになります。 +* `cache_info` + * 画像サイズ、キャプションをキャッシュするかどうかを指定します。指定しなかった場合は `false` になります。キャッシュは `image_dir` に `metadata_cache.json` というファイル名で保存されます。 + * キャッシュを行うと、二回目以降のデータセット読み込みが高速化されます。数千枚以上の画像を扱う場合には有効です。 * `is_reg` * サブセットの画像が正規化用かどうかを指定します。指定しなかった場合は `false` として、つまり正規化画像ではないとして扱います。 @@ -280,4 +298,89 @@ resolution = 768 * `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: 指定する値の形式が不正というエラーです。値の形式が間違っている可能性が高いです。`int` の部分は対象となるオプションによって変わります。この README に載っているオプションの「設定例」が役立つかもしれません。 * `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: 対応していないオプション名が存在している場合に発生するエラーです。オプション名を間違って記述しているか、誤って紛れ込んでいる可能性が高いです。 +## その他 +### 複数行キャプション + +`enable_wildcard = true` を設定することで、複数行キャプションも同時に有効になります。キャプションファイルが複数の行からなる場合、ランダムに一つの行が選ばれてキャプションとして利用されます。 + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage +a girl with a microphone standing on a stage +detailed digital art of a girl with a microphone on a stage +``` + +ワイルドカード記法と組み合わせることも可能です。 + +メタデータファイルでも同様に複数行キャプションを指定することができます。メタデータの .json 内には、`\n` を使って改行を表現してください。キャプションファイルが複数行からなる場合、`merge_captions_to_metadata.py` を使うと、この形式でメタデータファイルが作成されます。 + +メタデータのタグ (`tags`) は、キャプションの各行に追加されます。 + +```json +{ + "/path/to/image.png": { + "caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2", + "tags": "open mouth, simple background, 
standing, no humans, animal, black background, frog, animal costume, animal focus" + }, + ... +} +``` + +この場合、実際のキャプションは `a cartoon of a frog with the word frog on it, open mouth, simple background ...` または `test multiline caption1, open mouth, simple background ...`、 `test multiline caption2, open mouth, simple background ...` 等になります。 + +### 設定ファイルの記述例:追加の区切り文字、ワイルドカード記法、`keep_tokens_separator` 等 + +```toml +[general] +flip_aug = true +color_aug = false +resolution = [1024, 1024] + +[[datasets]] +batch_size = 6 +enable_bucket = true +bucket_no_upscale = true +caption_extension = ".txt" +keep_tokens_separator= "|||" +shuffle_caption = true +caption_tag_dropout_rate = 0.1 +secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side +enable_wildcard = true # 同上 / same as above + + [[datasets.subsets]] + image_dir = "/path/to/image_dir" + num_repeats = 1 + + # ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically) + caption_prefix = "1girl, hatsune miku, vocaloid |||" + + # ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains + # 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself + caption_suffix = ", anime screencap ||| masterpiece, rating: general" +``` + +### キャプション記述例、secondary_separator 記法:`secondary_separator = ";;;"` の場合 + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors +``` +`sky;;;cloud;;;day` の部分はシャッフル、drop されず `sky,cloud,day` に置換されます。シャッフル、drop が有効な場合、まとめて(一つのタグとして)処理されます。つまり `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (シャッフル)や `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (drop されたケース)などになります。 + +### キャプション記述例、ワイルドカード記法: `enable_wildcard = true` の場合 + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background +``` +ランダムに `simple` または `white` が選ばれ、`simple background` または `white background` になります。 + +```txt +1girl, hatsune miku, vocaloid, {{retro style}} +``` +タグ文字列に `{` や `}` そのものを含めたい場合は `{{` や `}}` のように二つ重ねてください(この例では実際に学習に用いられるキャプションは `{retro style}` になります)。 + +### キャプション記述例、`keep_tokens_separator` 記法: `keep_tokens_separator = "|||"` の場合 + +```txt +1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general +``` +`1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` や `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` などになります。 diff --git a/docs/gen_img_README-ja.md b/docs/gen_img_README-ja.md index cf35f1df..8f4442d0 100644 --- a/docs/gen_img_README-ja.md +++ b/docs/gen_img_README-ja.md @@ -452,3 +452,36 @@ python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt - `--network_show_meta` : 追加ネットワークのメタデータを表示します。 + +--- + +# About Gradual Latent + +Gradual Latent is a Hires fix that gradually increases the size of the latent. `gen_img.py`, `sdxl_gen_img.py`, and `gen_img_diffusers.py` have the following options. + +- `--gradual_latent_timesteps`: Specifies the timestep to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. Please try around 750 at first. +- `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which means it starts with half the default latent size. +- `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. 
The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0.
+- `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps.
+
+Each option can also be specified with the prompt options `--glt`, `--glr`, `--gls`, `--gle`.
+
+__Please specify `euler_a` for the sampler.__ The sampler's source code is modified, so it will not work with other samplers.
+
+It is more effective with SD 1.5. The effect is quite subtle with SDXL.
+
+# Gradual Latent について
+
+latentのサイズを徐々に大きくしていくHires fixです。`gen_img.py`、`sdxl_gen_img.py`、`gen_img_diffusers.py` に以下のオプションが追加されています。
+
+- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。750 くらいから始めてみてください。
+- `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。
+- `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。
+- `--gradual_latent_ratio_every_n_steps`: latentのサイズを大きくする間隔を指定します。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。
+
+それぞれのオプションは、プロンプトオプション `--glt`、`--glr`、`--gls`、`--gle` でも指定できます。
+
+サンプラーに手を加えているため、__サンプラーに `euler_a` を指定してください。__ 他のサンプラーでは動作しません。
+
+SD 1.5 のほうが効果があります。SDXL ではかなり微妙です。
+
diff --git a/docs/train_README-ja.md b/docs/train_README-ja.md
index c871f076..d186bf24 100644
--- a/docs/train_README-ja.md
+++ b/docs/train_README-ja.md
@@ -374,6 +374,10 @@ classがひとつで対象が複数の場合、正則化画像フォルダはひ

サンプル出力するステップ数またはエポック数を指定します。この数ごとにサンプル出力します。両方指定するとエポック数が優先されます。

+- `--sample_at_first`
+
+  学習開始前にサンプル出力します。学習前との比較ができます。
+
- `--sample_prompts`

  サンプル出力用プロンプトのファイルを指定します。
diff --git a/docs/train_SDXL-en.md b/docs/train_SDXL-en.md
new file mode 100644
index 00000000..a4c55b3f
--- /dev/null
+++ b/docs/train_SDXL-en.md
@@ -0,0 +1,84 @@
+## SDXL training
+
+The documentation will be moved to the training documentation in the future. The following is a brief explanation of the training scripts for SDXL.
+
+### Training scripts for SDXL
+
+- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth datasets.
+  - `--full_bf16` option is added. Thanks to KohakuBlueleaf!
+    - This option enables full bfloat16 training (including gradients). It is useful to reduce GPU memory usage.
+    - Full bfloat16 training might be unstable. Please use it at your own risk.
+  - Different learning rates for each U-Net block are now supported in `sdxl_train.py`. Specify 23 comma-separated values with the `--block_lr` option, like `--block_lr 1e-3,1e-3 ... 1e-3`.
+    - The 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`.
+- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
+
+- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
+
+- Both scripts have the following additional options:
+  - `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This option is useful to reduce GPU memory usage. It cannot be used together with options for shuffling or dropping the captions.
+  - `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. The SDXL VAE seems to produce NaNs in some cases; this option is useful to avoid them.
+
+- `--weighted_captions` option is not supported yet for either script.
+
+- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
+  - `--cache_text_encoder_outputs` is not supported.
+  - There are two options for captions:
+    1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
+    2. Use the `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
+  - See below for the format of the embeddings.
+
+- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net on a restricted range of timesteps. The default values are 0 and 1000.
+
+### Utility scripts for SDXL
+
+- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance.
+  - The options are almost the same as `sdxl_train.py`. See the help message for the usage.
+  - Please launch the script as follows:
+    `accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...`
+  - This script should work with multi-GPU, but it is not tested in my environment.
+
+- `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance.
+  - The options are almost the same as `cache_latents.py` and `sdxl_train.py`. See the help message for the usage.
+
+- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA, Textual Inversion and ControlNet-LLLite. See the help message for the usage.
+
+### Tips for SDXL training
+
+- The default resolution of SDXL is 1024x1024.
+- Fine-tuning is possible with 24GB GPU memory at a batch size of 1. The following options are recommended __for fine-tuning with 24GB GPU memory__:
+  - Train U-Net only.
+  - Use gradient checkpointing.
+  - Use the `--cache_text_encoder_outputs` option and cache latents.
+  - Use the Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work.
+- LoRA training can be done with 8GB GPU memory (10GB recommended). To reduce GPU memory usage, the following options are recommended:
+  - Train U-Net only.
+  - Use gradient checkpointing.
+  - Use the `--cache_text_encoder_outputs` option and cache latents.
+  - Use an 8bit optimizer or the Adafactor optimizer.
+  - Use a lower dim (4 to 8 for an 8GB GPU).
+- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, training them may produce unexpected results.
+- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1.
+- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Values smaller than 32 will not work for SDXL training.
+
+Example of the optimizer settings for Adafactor with a fixed learning rate:
+```toml
+optimizer_type = "adafactor"
+optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
+lr_scheduler = "constant_with_warmup"
+lr_warmup_steps = 100
+learning_rate = 4e-7 # SDXL original learning rate
+```
+
+### Format of Textual Inversion embeddings for SDXL
+
+```python
+from safetensors.torch import save_file
+
+state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
+save_file(state_dict, file)
+```
+
+### ControlNet-LLLite
+
+ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See the [documentation](./train_lllite_README.md) for details.
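+
+As a supplement to the embedding format shown above: a minimal sketch of loading such a file back for inspection (the file name `emb.safetensors` is only an example):
+
+```python
+from safetensors.torch import load_file
+
+# Load an SDXL Textual Inversion embedding saved in the format shown above.
+state_dict = load_file("emb.safetensors")  # example path
+
+# "clip_l" holds the 768-dim embeddings, "clip_g" the 1280-dim embeddings.
+print(state_dict["clip_l"].shape)  # e.g. torch.Size([num_tokens, 768])
+print(state_dict["clip_g"].shape)  # e.g. torch.Size([num_tokens, 1280])
+```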
+ diff --git a/docs/train_lllite_README-ja.md b/docs/train_lllite_README-ja.md index dbdc1fea..1f6a78d5 100644 --- a/docs/train_lllite_README-ja.md +++ b/docs/train_lllite_README-ja.md @@ -21,9 +21,13 @@ ComfyUIのカスタムノードを用意しています。: https://github.com/k ## モデルの学習 ### データセットの準備 -通常のdatasetに加え、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。 +DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。 -たとえば DreamBooth 方式でキャプションファイルを用いる場合の設定ファイルは以下のようになります。 +(finetuning 方式の dataset はサポートしていません。) + +conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。 + +たとえば、キャプションにフォルダ名ではなくキャプションファイルを用いる場合の設定ファイルは以下のようになります。 ```toml [[datasets.subsets]] diff --git a/docs/train_lllite_README.md b/docs/train_lllite_README.md index 04dc12da..a05f87f5 100644 --- a/docs/train_lllite_README.md +++ b/docs/train_lllite_README.md @@ -26,7 +26,9 @@ Due to the limitations of the inference environment, only CrossAttention (attn1 ### Preparing the dataset -In addition to the normal dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file. +In addition to the normal DreamBooth method dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file. + +(We do not support the finetuning method dataset.) ```toml [[datasets.subsets]] diff --git a/docs/wd14_tagger_README-en.md b/docs/wd14_tagger_README-en.md new file mode 100644 index 00000000..34f44882 --- /dev/null +++ b/docs/wd14_tagger_README-en.md @@ -0,0 +1,88 @@ +# Image Tagging using WD14Tagger + +This document is based on the information from this github page (https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger). + +Using onnx for inference is recommended. Please install onnx with the following command: + +```powershell +pip install onnx==1.15.0 onnxruntime-gpu==1.17.1 +``` + +The model weights will be automatically downloaded from Hugging Face. + +# Usage + +Run the script to perform tagging. + +```powershell +python finetune/tag_images_by_wd14_tagger.py --onnx --repo_id --batch_size +``` + +For example, if using the repository `SmilingWolf/wd-swinv2-tagger-v3` with a batch size of 4, and the training data is located in the parent folder `train_data`, it would be: + +```powershell +python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3 --batch_size 4 ..\train_data +``` + +On the first run, the model files will be automatically downloaded to the `wd14_tagger_model` folder (the folder can be changed with an option). + +Tag files will be created in the same directory as the training data images, with the same filename and a `.txt` extension. 
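+
+For illustration, a generated tag file might contain a single comma-separated line like the following (hypothetical output; the actual tags depend on the image and the thresholds):
+
+```
+1girl, solo, long hair, looking at viewer, smile, outdoors
+```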
+ +![Generated tag files](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png) + +![Tags and image](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png) + +## Example + +To output in the Animagine XL 3.1 format, it would be as follows (enter on a single line in practice): + +``` +python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3 + --batch_size 4 --remove_underscore --undesired_tags "PUT,YOUR,UNDESIRED,TAGS" --recursive + --use_rating_tags_as_last_tag --character_tags_first --character_tag_expand + --always_first_tags "1girl,1boy" ..\train_data +``` + +## Available Repository IDs + +[SmilingWolf's V2 and V3 models](https://huggingface.co/SmilingWolf) are available for use. Specify them in the format like `SmilingWolf/wd-vit-tagger-v3`. The default when omitted is `SmilingWolf/wd-v1-4-convnext-tagger-v2`. + +# Options + +## General Options + +- `--onnx`: Use ONNX for inference. If not specified, TensorFlow will be used. If using TensorFlow, please install TensorFlow separately. +- `--batch_size`: Number of images to process at once. Default is 1. Adjust according to VRAM capacity. +- `--caption_extension`: File extension for caption files. Default is `.txt`. +- `--max_data_loader_n_workers`: Maximum number of workers for DataLoader. Specifying a value of 1 or more will use DataLoader to speed up image loading. If unspecified, DataLoader will not be used. +- `--thresh`: Confidence threshold for outputting tags. Default is 0.35. Lowering the value will assign more tags but accuracy will decrease. +- `--general_threshold`: Confidence threshold for general tags. If omitted, same as `--thresh`. +- `--character_threshold`: Confidence threshold for character tags. If omitted, same as `--thresh`. +- `--recursive`: If specified, subfolders within the specified folder will also be processed recursively. +- `--append_tags`: Append tags to existing tag files. +- `--frequency_tags`: Output tag frequencies. +- `--debug`: Debug mode. Outputs debug information if specified. + +## Model Download + +- `--model_dir`: Folder to save model files. Default is `wd14_tagger_model`. +- `--force_download`: Re-download model files if specified. + +## Tag Editing + +- `--remove_underscore`: Remove underscores from output tags. +- `--undesired_tags`: Specify tags not to output. Multiple tags can be specified, separated by commas. For example, `black eyes,black hair`. +- `--use_rating_tags`: Output rating tags at the beginning of the tags. +- `--use_rating_tags_as_last_tag`: Add rating tags at the end of the tags. +- `--character_tags_first`: Output character tags first. +- `--character_tag_expand`: Expand character tag series names. For example, split the tag `chara_name_(series)` into `chara_name, series`. +- `--always_first_tags`: Specify tags to always output first when a certain tag appears in an image. Multiple tags can be specified, separated by commas. For example, `1girl,1boy`. +- `--caption_separator`: Separate tags with this string in the output file. Default is `, `. +- `--tag_replacement`: Perform tag replacement. Specify in the format `tag1,tag2;tag3,tag4`. If using `,` and `;`, escape them with `\`. \ + For example, specify `aira tsubase,aira tsubase (uniform)` (when you want to train a specific costume), `aira tsubase,aira tsubase\, heir of shadows` (when the series name is not included in the tag). 
+
+When using `tag_replacement`, it is applied after `character_tag_expand`.
+
+When specifying `remove_underscore`, specify `undesired_tags`, `always_first_tags`, and `tag_replacement` without including underscores.
+
+When specifying `caption_separator`, separate `undesired_tags` and `always_first_tags` with `caption_separator`. Always separate `tag_replacement` with `,`.
diff --git a/docs/wd14_tagger_README-ja.md b/docs/wd14_tagger_README-ja.md
new file mode 100644
index 00000000..58e9ede9
--- /dev/null
+++ b/docs/wd14_tagger_README-ja.md
@@ -0,0 +1,88 @@
+# WD14Taggerによるタグ付け
+
+こちらのgithubページ(https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger )の情報を参考にさせていただきました。
+
+onnx を用いた推論を推奨します。以下のコマンドで onnx をインストールしてください。
+
+```powershell
+pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
+```
+
+モデルの重みはHugging Faceから自動的にダウンロードしてきます。
+
+# 使い方
+
+スクリプトを実行してタグ付けを行います。
+```
+python finetune/tag_images_by_wd14_tagger.py --onnx --repo_id <モデルのrepo id> --batch_size <バッチサイズ> <教師データフォルダ>
+```
+
+レポジトリに `SmilingWolf/wd-swinv2-tagger-v3` を使用し、バッチサイズを4にして、教師データを親フォルダの `train_data`に置いた場合、以下のようになります。
+
+```
+python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3 --batch_size 4 ..\train_data
+```
+
+初回起動時にはモデルファイルが `wd14_tagger_model` フォルダに自動的にダウンロードされます(フォルダはオプションで変えられます)。
+
+タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。
+
+![生成されたタグファイル](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png)
+
+![タグと画像](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png)
+
+## 記述例
+
+Animagine XL 3.1 方式で出力する場合、以下のようになります(実際には 1 行で入力してください)。
+
+```
+python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3
+    --batch_size 4 --remove_underscore --undesired_tags "PUT,YOUR,UNDESIRED,TAGS" --recursive
+    --use_rating_tags_as_last_tag --character_tags_first --character_tag_expand
+    --always_first_tags "1girl,1boy" ..\train_data
+```
+
+## 使用可能なリポジトリID
+
+[SmilingWolf 氏の V2、V3 のモデル](https://huggingface.co/SmilingWolf)が使用可能です。`SmilingWolf/wd-vit-tagger-v3` のように指定してください。省略時のデフォルトは `SmilingWolf/wd-v1-4-convnext-tagger-v2` です。
+
+# オプション
+
+## 一般オプション
+
+- `--onnx` : ONNX を使用して推論します。指定しない場合は TensorFlow を使用します。TensorFlow 使用時は別途 TensorFlow をインストールしてください。
+- `--batch_size` : 一度に処理する画像の数。デフォルトは1です。VRAMの容量に応じて増減してください。
+- `--caption_extension` : キャプションファイルの拡張子。デフォルトは `.txt` です。
+- `--max_data_loader_n_workers` : DataLoader の最大ワーカー数です。このオプションに 1 以上の数値を指定すると、DataLoader を用いて画像読み込みを高速化します。未指定時は DataLoader を用いません。
+- `--thresh` : 出力するタグの信頼度の閾値。デフォルトは0.35です。値を下げるとより多くのタグが付与されますが、精度は下がります。
+- `--general_threshold` : 一般タグの信頼度の閾値。省略時は `--thresh` と同じです。
+- `--character_threshold` : キャラクタータグの信頼度の閾値。省略時は `--thresh` と同じです。
+- `--recursive` : 指定すると、指定したフォルダ内のサブフォルダも再帰的に処理します。
+- `--append_tags` : 既存のタグファイルにタグを追加します。
+- `--frequency_tags` : タグの頻度を出力します。
+- `--debug` : デバッグモード。指定するとデバッグ情報を出力します。
+
+## モデルのダウンロード
+
+- `--model_dir` : モデルファイルの保存先フォルダ。デフォルトは `wd14_tagger_model` です。
+- `--force_download` : 指定するとモデルファイルを再ダウンロードします。
+
+## タグ編集関連
+
+- `--remove_underscore` : 出力するタグからアンダースコアを削除します。
+- `--undesired_tags` : 出力しないタグを指定します。カンマ区切りで複数指定できます。たとえば `black eyes,black hair` のように指定します。
+- `--use_rating_tags` : タグの最初にレーティングタグを出力します。
+- `--use_rating_tags_as_last_tag` : タグの最後にレーティングタグを追加します。
+- `--character_tags_first` : キャラクタータグを最初に出力します。
+- `--character_tag_expand` : キャラクタータグのシリーズ名を展開します。たとえば `chara_name_(series)` のタグを `chara_name, series` に分割します。 
+- `--always_first_tags` : あるタグが画像に出力されたとき、そのタグを最初に出力するタグを指定します。カンマ区切りで複数指定できます。たとえば `1girl,1boy` のように指定します。 +- `--caption_separator` : 出力するファイルでタグをこの文字列で区切ります。デフォルトは `, ` です。 +- `--tag_replacement` : タグの置換を行います。`tag1,tag2;tag3,tag4` のように指定します。`,` および `;` を使う場合は `\` でエスケープしてください。\ + たとえば `aira tsubase,aira tsubase (uniform)` (特定の衣装を学習させたいとき)、`aira tsubase,aira tsubase\, heir of shadows` (シリーズ名がタグに含まれないとき)のように指定します。 + +`tag_replacement` は `character_tag_expand` の後に適用されます。 + +`remove_underscore` 指定時は、`undesired_tags`、`always_first_tags`、`tag_replacement` はアンダースコアを含めずに指定してください。 + +`caption_separator` 指定時は、`undesired_tags`、`always_first_tags` は `caption_separator` で区切ってください。`tag_replacement` は必ず `,` で区切ってください。 + diff --git a/fine_tune.py b/fine_tune.py index 52e84c43..c7e6bbd2 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -2,27 +2,29 @@ # XXX dropped option: hypernetwork training import argparse -import gc import math import os from multiprocessing import Value import toml from tqdm import tqdm + import torch +from library import deepspeed_utils +from library.device_utils import init_ipex, clean_memory_on_device -try: - import intel_extension_for_pytorch as ipex +init_ipex() - if torch.xpu.is_available(): - from library.ipex import ipex_init - - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + import library.train_util as train_util import library.config_util as config_util from library.config_util import ( @@ -42,6 +44,8 @@ from library.custom_train_functions import ( def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) cache_latents = args.cache_latents @@ -54,11 +58,11 @@ def train(args): if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -91,7 +95,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify the metadata file and train_data_dir option. 
/ 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" ) return @@ -102,11 +106,12 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype # モデルを読み込む text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator) @@ -157,15 +162,13 @@ def train(args): # 学習を準備する if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -192,7 +195,7 @@ def train(args): if not cache_latents: vae.requires_grad_(False) vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=vae_dtype) for m in training_models: m.requires_grad_(True) @@ -212,8 +215,8 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, @@ -228,7 +231,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -245,16 +250,23 @@ def train(args): unet.to(weight_dtype) text_encoder.to(weight_dtype) - # acceleratorがなんかよろしくやってくれるらしい - if args.train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler + if args.deepspeed: + if args.train_text_encoder: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder) + else: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet) + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler ) + training_models = [ds_model] else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) - - # transform DDP after prepare - text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet) + # acceleratorがなんかよろしくやってくれるらしい + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: @@ -294,10 +306,15 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + # For --sample_at_first + train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") @@ -308,13 +325,13 @@ def train(args): for step, batch in enumerate(train_dataloader): current_step.value = global_step - with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく + with accelerator.accumulate(*training_models): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype) + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype) latents = latents * 0.18215 b_size = latents.shape[0] @@ -337,7 +354,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -351,11 +368,11 @@ def train(args): if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or 
args.debiased_estimation_loss: # do not mean over batch dimension for snr weight or scale v-pred loss - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: @@ -363,7 +380,7 @@ def train(args): loss = loss.mean() # mean over batch dimension else: - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") + loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: @@ -454,7 +471,7 @@ def train(args): accelerator.end_training() - if args.save_state and is_main_process: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す @@ -464,21 +481,25 @@ def train(args): train_util.save_sd_model_on_train_end( args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae ) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) - parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") + parser.add_argument( + "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + ) parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") parser.add_argument( "--learning_rate_te", @@ -486,6 +507,11 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ", ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) return parser @@ -494,6 +520,7 @@ if __name__ == "__main__": parser = setup_parser() args = parser.parse_args() + train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) train(args) diff --git a/finetune/blip/blip.py b/finetune/blip/blip.py index 7851fb08..13b69ffd 100644 --- a/finetune/blip/blip.py +++ b/finetune/blip/blip.py @@ -21,6 +21,10 @@ import torch.nn.functional as F import os from urllib.parse import urlparse from timm.models.hub import download_cached_file +from library.utils import setup_logging +setup_logging() +import logging +logger = 
logging.getLogger(__name__) class BLIP_Base(nn.Module): def __init__(self, @@ -130,8 +134,9 @@ class BLIP_Decoder(nn.Module): def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0): image_embeds = self.visual_encoder(image) - if not sample: - image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) + # recent version of transformers seems to do repeat_interleave automatically + # if not sample: + # image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts} @@ -235,6 +240,6 @@ def load_checkpoint(model,url_or_filename): del state_dict[key] msg = model.load_state_dict(state_dict,strict=False) - print('load checkpoint from %s'%url_or_filename) + logger.info('load checkpoint from %s'%url_or_filename) return model,msg diff --git a/finetune/clean_captions_and_tags.py b/finetune/clean_captions_and_tags.py index 68839ecc..5aeb1742 100644 --- a/finetune/clean_captions_and_tags.py +++ b/finetune/clean_captions_and_tags.py @@ -8,6 +8,10 @@ import json import re from tqdm import tqdm +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ') PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ') @@ -36,13 +40,13 @@ def clean_tags(image_key, tags): tokens = tags.split(", rating") if len(tokens) == 1: # WD14 taggerのときはこちらになるのでメッセージは出さない - # print("no rating:") - # print(f"{image_key} {tags}") + # logger.info("no rating:") + # logger.info(f"{image_key} {tags}") pass else: if len(tokens) > 2: - print("multiple ratings:") - print(f"{image_key} {tags}") + logger.info("multiple ratings:") + logger.info(f"{image_key} {tags}") tags = tokens[0] tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策 @@ -124,43 +128,43 @@ def clean_caption(caption): def main(args): if os.path.exists(args.in_json): - print(f"loading existing metadata: {args.in_json}") + logger.info(f"loading existing metadata: {args.in_json}") with open(args.in_json, "rt", encoding='utf-8') as f: metadata = json.load(f) else: - print("no metadata / メタデータファイルがありません") + logger.error("no metadata / メタデータファイルがありません") return - print("cleaning captions and tags.") + logger.info("cleaning captions and tags.") image_keys = list(metadata.keys()) for image_key in tqdm(image_keys): tags = metadata[image_key].get('tags') if tags is None: - print(f"image does not have tags / メタデータにタグがありません: {image_key}") + logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}") else: org = tags tags = clean_tags(image_key, tags) metadata[image_key]['tags'] = tags if args.debug and org != tags: - print("FROM: " + org) - print("TO: " + tags) + logger.info("FROM: " + org) + logger.info("TO: " + tags) caption = metadata[image_key].get('caption') if caption is None: - print(f"image does not have caption / メタデータにキャプションがありません: {image_key}") + logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}") else: org = caption caption = clean_caption(caption) metadata[image_key]['caption'] = caption if args.debug and org != caption: - print("FROM: " + org) - print("TO: " + caption) + logger.info("FROM: " + org) + logger.info("TO: " + caption) # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") + logger.info(f"writing metadata: {args.out_json}") with 
open(args.out_json, "wt", encoding='utf-8') as f: json.dump(metadata, f, indent=2) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: @@ -178,10 +182,10 @@ if __name__ == '__main__': args, unknown = parser.parse_known_args() if len(unknown) == 1: - print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.") - print("All captions and tags in the metadata are processed.") - print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。") - print("メタデータ内のすべてのキャプションとタグが処理されます。") + logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.") + logger.warning("All captions and tags in the metadata are processed.") + logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。") + logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。") args.in_json = args.out_json args.out_json = unknown[0] elif len(unknown) > 0: diff --git a/finetune/make_captions.py b/finetune/make_captions.py index b20c4106..489bdbcc 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -9,14 +9,22 @@ from pathlib import Path from PIL import Image from tqdm import tqdm import numpy as np + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from torchvision import transforms from torchvision.transforms.functional import InterpolationMode sys.path.append(os.path.dirname(__file__)) -from blip.blip import blip_decoder +from blip.blip import blip_decoder, is_url import library.train_util as train_util +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +DEVICE = get_preferred_device() IMAGE_SIZE = 384 @@ -47,7 +55,7 @@ class ImageLoadingTransformDataset(torch.utils.data.Dataset): # convert to tensor temporarily so dataloader will accept it tensor = IMAGE_TRANSFORM(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") return None return (tensor, img_path) @@ -74,19 +82,21 @@ def main(args): args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path cwd = os.getcwd() - print("Current Working Directory is: ", cwd) + logger.info(f"Current Working Directory is: {cwd}") os.chdir("finetune") + if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights): + args.caption_weights = os.path.join("..", args.caption_weights) - print(f"load images from {args.train_data_dir}") + logger.info(f"load images from {args.train_data_dir}") train_data_dir_path = Path(args.train_data_dir) image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") - print(f"loading BLIP caption: {args.caption_weights}") + logger.info(f"loading BLIP caption: {args.caption_weights}") model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json") model.eval() model = model.to(DEVICE) - print("BLIP loaded") + logger.info("BLIP loaded") # captioningする def run_batch(path_imgs): @@ -106,7 +116,7 @@ def main(args): 
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: f.write(caption + "\n") if args.debug: - print(image_path, caption) + logger.info(f'{image_path} {caption}') # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -136,7 +146,7 @@ def main(args): raw_image = raw_image.convert("RGB") img_tensor = IMAGE_TRANSFORM(raw_image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue b_imgs.append((image_path, img_tensor)) @@ -146,7 +156,7 @@ def main(args): if len(b_imgs) > 0: run_batch(b_imgs) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index b3c5cc42..edeebadf 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -5,12 +5,19 @@ import re from pathlib import Path from PIL import Image from tqdm import tqdm + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from transformers import AutoProcessor, AutoModelForCausalLM from transformers.generation.utils import GenerationMixin import library.train_util as train_util - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -35,8 +42,8 @@ def remove_words(captions, debug): for pat in PATTERN_REPLACE: cap = pat.sub("", cap) if debug and cap != caption: - print(caption) - print(cap) + logger.info(caption) + logger.info(cap) removed_caps.append(cap) return removed_caps @@ -70,16 +77,16 @@ def main(args): GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch """ - print(f"load images from {args.train_data_dir}") + logger.info(f"load images from {args.train_data_dir}") train_data_dir_path = Path(args.train_data_dir) image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") # できればcacheに依存せず明示的にダウンロードしたい - print(f"loading GIT: {args.model_id}") + logger.info(f"loading GIT: {args.model_id}") git_processor = AutoProcessor.from_pretrained(args.model_id) git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE) - print("GIT loaded") + logger.info("GIT loaded") # captioningする def run_batch(path_imgs): @@ -97,7 +104,7 @@ def main(args): with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: f.write(caption + "\n") if args.debug: - print(image_path, caption) + logger.info(f"{image_path} {caption}") # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -126,7 +133,7 @@ def main(args): if image.mode != "RGB": image = image.convert("RGB") except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue b_imgs.append((image_path, image)) @@ -137,7 +144,7 @@ def main(args): if len(b_imgs) > 0: run_batch(b_imgs) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/merge_captions_to_metadata.py b/finetune/merge_captions_to_metadata.py index 241f6f90..89f71747 100644 --- a/finetune/merge_captions_to_metadata.py 
+++ b/finetune/merge_captions_to_metadata.py @@ -5,72 +5,96 @@ from typing import List from tqdm import tqdm import library.train_util as train_util import os +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + def main(args): - assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" + assert not args.recursive or ( + args.recursive and args.full_path + ), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" - train_data_dir_path = Path(args.train_data_dir) - image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + train_data_dir_path = Path(args.train_data_dir) + image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) + logger.info(f"found {len(image_paths)} images.") - if args.in_json is None and Path(args.out_json).is_file(): - args.in_json = args.out_json + if args.in_json is None and Path(args.out_json).is_file(): + args.in_json = args.out_json - if args.in_json is not None: - print(f"loading existing metadata: {args.in_json}") - metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) - print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") - else: - print("new metadata will be created / 新しいメタデータファイルが作成されます") - metadata = {} + if args.in_json is not None: + logger.info(f"loading existing metadata: {args.in_json}") + metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8")) + logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") + else: + logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") + metadata = {} - print("merge caption texts to metadata json.") - for image_path in tqdm(image_paths): - caption_path = image_path.with_suffix(args.caption_extension) - caption = caption_path.read_text(encoding='utf-8').strip() + logger.info("merge caption texts to metadata json.") + for image_path in tqdm(image_paths): + caption_path = image_path.with_suffix(args.caption_extension) + caption = caption_path.read_text(encoding="utf-8").strip() - if not os.path.exists(caption_path): - caption_path = os.path.join(image_path, args.caption_extension) + if not os.path.exists(caption_path): + caption_path = os.path.join(image_path, args.caption_extension) - image_key = str(image_path) if args.full_path else image_path.stem - if image_key not in metadata: - metadata[image_key] = {} + image_key = str(image_path) if args.full_path else image_path.stem + if image_key not in metadata: + metadata[image_key] = {} - metadata[image_key]['caption'] = caption - if args.debug: - print(image_key, caption) + metadata[image_key]["caption"] = caption + if args.debug: + logger.info(f"{image_key} {caption}") - # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") - Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') - print("done!") + # metadataを書き出して終わり + logger.info(f"writing metadata: {args.out_json}") + Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") - 
parser.add_argument("--in_json", type=str, - help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") - parser.add_argument("--caption_extention", type=str, default=None, - help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)") - parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子") - parser.add_argument("--full_path", action="store_true", - help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") - parser.add_argument("--recursive", action="store_true", - help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") - parser.add_argument("--debug", action="store_true", help="debug mode") + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") + parser.add_argument( + "--in_json", + type=str, + help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)", + ) + parser.add_argument( + "--caption_extention", + type=str, + default=None, + help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)", + ) + parser.add_argument( + "--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子" + ) + parser.add_argument( + "--full_path", + action="store_true", + help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", + ) + parser.add_argument( + "--recursive", + action="store_true", + help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す", + ) + parser.add_argument("--debug", action="store_true", help="debug mode") - return parser + return parser -if __name__ == '__main__': - parser = setup_parser() +if __name__ == "__main__": + parser = setup_parser() - args = parser.parse_args() + args = parser.parse_args() - # スペルミスしていたオプションを復元する - if args.caption_extention is not None: - args.caption_extension = args.caption_extention + # スペルミスしていたオプションを復元する + if args.caption_extention is not None: + args.caption_extension = args.caption_extention - main(args) + main(args) diff --git a/finetune/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py index db1bff6d..ce22d990 100644 --- a/finetune/merge_dd_tags_to_metadata.py +++ b/finetune/merge_dd_tags_to_metadata.py @@ -5,67 +5,89 @@ from typing import List from tqdm import tqdm import library.train_util as train_util import os +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + def main(args): - assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" + assert not args.recursive or ( + args.recursive and args.full_path + ), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" - train_data_dir_path = Path(args.train_data_dir) - image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + train_data_dir_path = Path(args.train_data_dir) + image_paths: 
List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) + logger.info(f"found {len(image_paths)} images.") - if args.in_json is None and Path(args.out_json).is_file(): - args.in_json = args.out_json + if args.in_json is None and Path(args.out_json).is_file(): + args.in_json = args.out_json - if args.in_json is not None: - print(f"loading existing metadata: {args.in_json}") - metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) - print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") - else: - print("new metadata will be created / 新しいメタデータファイルが作成されます") - metadata = {} + if args.in_json is not None: + logger.info(f"loading existing metadata: {args.in_json}") + metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8")) + logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") + else: + logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") + metadata = {} - print("merge tags to metadata json.") - for image_path in tqdm(image_paths): - tags_path = image_path.with_suffix(args.caption_extension) - tags = tags_path.read_text(encoding='utf-8').strip() + logger.info("merge tags to metadata json.") + for image_path in tqdm(image_paths): + tags_path = image_path.with_suffix(args.caption_extension) + tags = tags_path.read_text(encoding="utf-8").strip() - if not os.path.exists(tags_path): - tags_path = os.path.join(image_path, args.caption_extension) + if not os.path.exists(tags_path): + tags_path = os.path.join(image_path, args.caption_extension) - image_key = str(image_path) if args.full_path else image_path.stem - if image_key not in metadata: - metadata[image_key] = {} + image_key = str(image_path) if args.full_path else image_path.stem + if image_key not in metadata: + metadata[image_key] = {} - metadata[image_key]['tags'] = tags - if args.debug: - print(image_key, tags) + metadata[image_key]["tags"] = tags + if args.debug: + logger.info(f"{image_key} {tags}") - # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") - Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') + # metadataを書き出して終わり + logger.info(f"writing metadata: {args.out_json}") + Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8") - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") - parser.add_argument("--in_json", type=str, - help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") - parser.add_argument("--full_path", action="store_true", - help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") - parser.add_argument("--recursive", action="store_true", - help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") - parser.add_argument("--caption_extension", type=str, default=".txt", - help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子") - parser.add_argument("--debug", action="store_true", help="debug mode, print tags") + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + 
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") + parser.add_argument( + "--in_json", + type=str, + help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)", + ) + parser.add_argument( + "--full_path", + action="store_true", + help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", + ) + parser.add_argument( + "--recursive", + action="store_true", + help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す", + ) + parser.add_argument( + "--caption_extension", + type=str, + default=".txt", + help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子", + ) + parser.add_argument("--debug", action="store_true", help="debug mode, print tags") - return parser + return parser -if __name__ == '__main__': - parser = setup_parser() +if __name__ == "__main__": + parser = setup_parser() - args = parser.parse_args() - main(args) + args = parser.parse_args() + main(args) diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 1bccb1d3..0389da38 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -8,13 +8,21 @@ from tqdm import tqdm import numpy as np from PIL import Image import cv2 + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from torchvision import transforms import library.model_util as model_util import library.train_util as train_util +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +DEVICE = get_preferred_device() IMAGE_TRANSFORMS = transforms.Compose( [ @@ -51,22 +59,22 @@ def get_npz_filename(data_dir, image_key, is_full_path, recursive): def main(args): # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります" if args.bucket_reso_steps % 8 > 0: - print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") + logger.warning(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") if args.bucket_reso_steps % 32 > 0: - print( + logger.warning( f"WARNING: bucket_reso_steps is not divisible by 32. 
It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません" ) train_data_dir_path = Path(args.train_data_dir) image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)] - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") if os.path.exists(args.in_json): - print(f"loading existing metadata: {args.in_json}") + logger.info(f"loading existing metadata: {args.in_json}") with open(args.in_json, "rt", encoding="utf-8") as f: metadata = json.load(f) else: - print(f"no metadata / メタデータファイルがありません: {args.in_json}") + logger.error(f"no metadata / メタデータファイルがありません: {args.in_json}") return weight_dtype = torch.float32 @@ -89,7 +97,7 @@ def main(args): if not args.bucket_no_upscale: bucket_manager.make_buckets() else: - print( + logger.warning( "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます" ) @@ -130,7 +138,7 @@ def main(args): if image.mode != "RGB": image = image.convert("RGB") except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] @@ -183,15 +191,15 @@ def main(args): for i, reso in enumerate(bucket_manager.resos): count = bucket_counts.get(reso, 0) if count > 0: - print(f"bucket {i} {reso}: {count}") + logger.info(f"bucket {i} {reso}: {count}") img_ar_errors = np.array(img_ar_errors) - print(f"mean ar error: {np.mean(img_ar_errors)}") + logger.info(f"mean ar error: {np.mean(img_ar_errors)}") # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") + logger.info(f"writing metadata: {args.out_json}") with open(args.out_json, "wt", encoding="utf-8") as f: json.dump(metadata, f, indent=2) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 965edd7e..a327bbd6 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -11,6 +11,12 @@ from PIL import Image from tqdm import tqdm import library.train_util as train_util +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # from wd14 tagger IMAGE_SIZE = 448 @@ -56,12 +62,12 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset): try: image = Image.open(img_path).convert("RGB") image = preprocess_image(image) - tensor = torch.tensor(image) + # tensor = torch.tensor(image) # これ Tensor に変換する必要ないな……(;・∀・) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") return None - return (tensor, img_path) + return (image, img_path) def collate_fn_remove_corrupted(batch): @@ -75,36 +81,44 @@ def collate_fn_remove_corrupted(batch): def main(args): + # model location is model_dir + repo_id + # repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash + model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_")) + # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする # depreacatedの警告が出るけどなくなったらその時 # 
https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 - if not os.path.exists(args.model_dir) or args.force_download: - print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") + if not os.path.exists(model_location) or args.force_download: + os.makedirs(args.model_dir, exist_ok=True) + logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") files = FILES if args.onnx: + files = ["selected_tags.csv"] files += FILES_ONNX + else: + for file in SUB_DIR_FILES: + hf_hub_download( + args.repo_id, + file, + subfolder=SUB_DIR, + cache_dir=os.path.join(model_location, SUB_DIR), + force_download=True, + force_filename=file, + ) for file in files: - hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) - for file in SUB_DIR_FILES: - hf_hub_download( - args.repo_id, - file, - subfolder=SUB_DIR, - cache_dir=os.path.join(args.model_dir, SUB_DIR), - force_download=True, - force_filename=file, - ) + hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file) else: - print("using existing wd14 tagger model") + logger.info("using existing wd14 tagger model") - # 画像を読み込む + # モデルを読み込む if args.onnx: + import torch import onnx import onnxruntime as ort - onnx_path = f"{args.model_dir}/model.onnx" - print("Running wd14 tagger with onnx") - print(f"loading onnx model: {onnx_path}") + onnx_path = f"{model_location}/model.onnx" + logger.info("Running wd14 tagger with onnx") + logger.info(f"loading onnx model: {onnx_path}") if not os.path.exists(onnx_path): raise Exception( @@ -116,58 +130,112 @@ def main(args): input_name = model.graph.input[0].name try: batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value - except: + except Exception: batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param - if args.batch_size != batch_size and type(batch_size) != str: + if args.batch_size != batch_size and not isinstance(batch_size, str) and batch_size > 0: # some rebatch model may use 'N' as dynamic axes - print( + logger.warning( f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}" ) args.batch_size = batch_size del model - ort_sess = ort.InferenceSession( - onnx_path, - providers=["CUDAExecutionProvider"] - if "CUDAExecutionProvider" in ort.get_available_providers() - else ["CPUExecutionProvider"], - ) + if "OpenVINOExecutionProvider" in ort.get_available_providers(): + # requires provider options for gpu support + # fp16 causes nonsense outputs + ort_sess = ort.InferenceSession( + onnx_path, + providers=(["OpenVINOExecutionProvider"]), + provider_options=[{'device_type' : "GPU_FP32"}], + ) + else: + ort_sess = ort.InferenceSession( + onnx_path, + providers=( + ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else + ["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else + ["CPUExecutionProvider"] + ), + ) else: from tensorflow.keras.models import load_model - model = load_model(f"{args.model_dir}") + model = load_model(f"{model_location}") # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") # 依存ライブラリを増やしたくないので自力で読むよ - with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: + with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f: reader = csv.reader(f) - l = [row for row in reader] - header = l[0] # tag_id,name,category,count - rows = l[1:] + line 
= [row for row in reader] + header = line[0] # tag_id,name,category,count + rows = line[1:] assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}" - general_tags = [row[1] for row in rows[1:] if row[2] == "0"] - character_tags = [row[1] for row in rows[1:] if row[2] == "4"] + rating_tags = [row[1] for row in rows[0:] if row[2] == "9"] + general_tags = [row[1] for row in rows[0:] if row[2] == "0"] + character_tags = [row[1] for row in rows[0:] if row[2] == "4"] + + # preprocess tags in advance + if args.character_tag_expand: + for i, tag in enumerate(character_tags): + if tag.endswith(")"): + # chara_name_(series) -> chara_name, series + # chara_name_(costume)_(series) -> chara_name_(costume), series + tags = tag.split("(") + character_tag = "(".join(tags[:-1]) + if character_tag.endswith("_"): + character_tag = character_tag[:-1] + series_tag = tags[-1].replace(")", "") + character_tags[i] = character_tag + args.caption_separator + series_tag + + if args.remove_underscore: + rating_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in rating_tags] + general_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in general_tags] + character_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in character_tags] + + if args.tag_replacement is not None: + # escape , and ; in tag_replacement: wd14 tag names may contain , and ; + escaped_tag_replacements = args.tag_replacement.replace("\\,", "@@@@").replace("\\;", "####") + tag_replacements = escaped_tag_replacements.split(";") + for tag_replacement in tag_replacements: + tags = tag_replacement.split(",") # source, target + assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}" + + source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags] + logger.info(f"replacing tag: {source} -> {target}") + + if source in general_tags: + general_tags[general_tags.index(source)] = target + elif source in character_tags: + character_tags[character_tags.index(source)] = target + elif source in rating_tags: + rating_tags[rating_tags.index(source)] = target # 画像を読み込む - train_data_dir_path = Path(args.train_data_dir) image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") tag_freq = {} - undesired_tags = set(args.undesired_tags.split(",")) + caption_separator = args.caption_separator + stripped_caption_separator = caption_separator.strip() + undesired_tags = args.undesired_tags.split(stripped_caption_separator) + undesired_tags = set([tag.strip() for tag in undesired_tags if tag.strip() != ""]) + + always_first_tags = None + if args.always_first_tags is not None: + always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""] def run_batch(path_imgs): imgs = np.array([im for _, im in path_imgs]) if args.onnx: - if len(imgs) < args.batch_size: - imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0) + # if len(imgs) < args.batch_size: + # imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0) probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy probs = probs[: len(path_imgs)] else: @@ -175,46 +243,64 @@ def main(args): probs = probs.numpy() for (image_path, _), prob in 
zip(path_imgs, probs): - # 最初の4つはratingなので無視する - # # First 4 labels are actually ratings: pick one with argmax - # ratings_names = label_names[:4] - # rating_index = ratings_names["probs"].argmax() - # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]] - - # それ以降はタグなのでconfidenceがthresholdより高いものを追加する - # Everything else is tags: pick any where prediction confidence > threshold combined_tags = [] - general_tag_text = "" + rating_tag_text = "" character_tag_text = "" + general_tag_text = "" + + # 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する + # First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold for i, p in enumerate(prob[4:]): if i < len(general_tags) and p >= args.general_threshold: tag_name = general_tags[i] - if args.remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^ - tag_name = tag_name.replace("_", " ") if tag_name not in undesired_tags: tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - general_tag_text += ", " + tag_name + general_tag_text += caption_separator + tag_name combined_tags.append(tag_name) elif i >= len(general_tags) and p >= args.character_threshold: tag_name = character_tags[i - len(general_tags)] - if args.remove_underscore and len(tag_name) > 3: - tag_name = tag_name.replace("_", " ") if tag_name not in undesired_tags: tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - character_tag_text += ", " + tag_name - combined_tags.append(tag_name) + character_tag_text += caption_separator + tag_name + if args.character_tags_first: # insert to the beginning + combined_tags.insert(0, tag_name) + else: + combined_tags.append(tag_name) + + # 最初の4つはratingなのでargmaxで選ぶ + # First 4 labels are actually ratings: pick one with argmax + if args.use_rating_tags or args.use_rating_tags_as_last_tag: + ratings_probs = prob[:4] + rating_index = ratings_probs.argmax() + found_rating = rating_tags[rating_index] + + if found_rating not in undesired_tags: + tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1 + rating_tag_text = found_rating + if args.use_rating_tags: + combined_tags.insert(0, found_rating) # insert to the beginning + else: + combined_tags.append(found_rating) + + # 一番最初に置くタグを指定する + # Always put some tags at the beginning + if always_first_tags is not None: + for tag in always_first_tags: + if tag in combined_tags: + combined_tags.remove(tag) + combined_tags.insert(0, tag) # 先頭のカンマを取る if len(general_tag_text) > 0: - general_tag_text = general_tag_text[2:] + general_tag_text = general_tag_text[len(caption_separator) :] if len(character_tag_text) > 0: - character_tag_text = character_tag_text[2:] + character_tag_text = character_tag_text[len(caption_separator) :] caption_file = os.path.splitext(image_path)[0] + args.caption_extension - tag_text = ", ".join(combined_tags) + tag_text = caption_separator.join(combined_tags) if args.append_tags: # Check if file exists @@ -224,18 +310,22 @@ def main(args): existing_content = f.read().strip("\n") # Remove newlines # Split the content into tags and store them in a list - existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()] + existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()] # Check and remove repeating tags in tag_text new_tags = [tag for tag in combined_tags if tag not in existing_tags] # Create new tag_text - tag_text = ", ".join(existing_tags + new_tags) + tag_text = caption_separator.join(existing_tags + new_tags) with open(caption_file, 
"wt", encoding="utf-8") as f: f.write(tag_text + "\n") if args.debug: - print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}") + logger.info("") + logger.info(f"{image_path}:") + logger.info(f"\tRating tags: {rating_tag_text}") + logger.info(f"\tCharacter tags: {character_tag_text}") + logger.info(f"\tGeneral tags: {general_tag_text}") # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -258,16 +348,14 @@ def main(args): continue image, image_path = data - if image is not None: - image = image.detach().numpy() - else: + if image is None: try: image = Image.open(image_path) if image.mode != "RGB": image = image.convert("RGB") image = preprocess_image(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue b_imgs.append((image_path, image)) @@ -282,16 +370,18 @@ def main(args): if args.frequency_tags: sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True) - print("\nTag frequencies:") + print("Tag frequencies:") for tag, freq in sorted_tags: print(f"{tag}: {freq}") - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument( + "train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ" + ) parser.add_argument( "--repo_id", type=str, @@ -305,9 +395,13 @@ def setup_parser() -> argparse.ArgumentParser: help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ", ) parser.add_argument( - "--force_download", action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします" + "--force_download", + action="store_true", + help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします", + ) + parser.add_argument( + "--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ" ) - parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument( "--max_data_loader_n_workers", type=int, @@ -320,8 +414,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)", ) - parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") - parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値") + parser.add_argument( + "--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子" + ) + parser.add_argument( + "--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値" + ) parser.add_argument( "--general_threshold", type=float, @@ -334,22 +432,67 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ", ) - parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") + parser.add_argument( + "--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する" + ) parser.add_argument( 
"--remove_underscore", action="store_true", help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える", ) - parser.add_argument("--debug", action="store_true", help="debug mode") + parser.add_argument( + "--debug", action="store_true", help="debug mode" + ) parser.add_argument( "--undesired_tags", type=str, default="", help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト", ) - parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") - parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する") - parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する") + parser.add_argument( + "--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する" + ) + parser.add_argument( + "--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する" + ) + parser.add_argument( + "--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する" + ) + parser.add_argument( + "--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する", + ) + parser.add_argument( + "--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する", + ) + parser.add_argument( + "--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する", + ) + parser.add_argument( + "--always_first_tags", + type=str, + default=None, + help="comma-separated list of tags to always put at the beginning, e.g. `1girl,1boy`" + + " / 必ず先頭に置くタグのカンマ区切りリスト、例 : `1girl,1boy`", + ) + parser.add_argument( + "--caption_separator", + type=str, + default=", ", + help="Separator for captions, include space if needed / キャプションの区切り文字、必要ならスペースを含めてください", + ) + parser.add_argument( + "--tag_replacement", + type=str, + default=None, + help="tag replacement in the format of `source1,target1;source2,target2; ...`. Escape `,` and `;` with `\`. e.g. `tag1,tag2;tag3,tag4`" + + " / タグの置換を `置換元1,置換先1;置換元2,置換先2; ...`で指定する。`\` で `,` と `;` をエスケープできる。例: `tag1,tag2;tag3,tag4`", + ) + parser.add_argument( + "--character_tag_expand", + action="store_true", + help="expand tag tail parenthesis to another tag for character tags. 
`chara_name_(series)` becomes `chara_name, series`" + + " / キャラクタタグの末尾の括弧を別のタグに展開する。`chara_name_(series)` は `chara_name, series` になる", + ) return parser diff --git a/gen_img.py b/gen_img.py new file mode 100644 index 00000000..4fe89871 --- /dev/null +++ b/gen_img.py @@ -0,0 +1,3334 @@ +import itertools +import json +from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable +import glob +import importlib +import importlib.util +import sys +import inspect +import time +import zipfile +from diffusers.utils import deprecate +from diffusers.configuration_utils import FrozenDict +import argparse +import math +import os +import random +import re + +import diffusers +import numpy as np +import torch + +from library.device_utils import init_ipex, clean_memory, get_preferred_device + +init_ipex() + +import torchvision +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + # UNet2DConditionModel, + StableDiffusionPipeline, +) +from einops import rearrange +from tqdm import tqdm +from torchvision import transforms +from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor +import PIL +from PIL import Image +from PIL.PngImagePlugin import PngInfo + +import library.model_util as model_util +import library.train_util as train_util +import library.sdxl_model_util as sdxl_model_util +import library.sdxl_train_util as sdxl_train_util +from networks.lora import LoRANetwork +import tools.original_control_net as original_control_net +from tools.original_control_net import ControlNetInfo +from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel +from library.sdxl_original_unet import InferSdxlUNet2DConditionModel +from library.original_unet import FlashAttentionFunction +from networks.control_net_lllite import ControlNetLLLite +from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +# scheduler: +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = "scaled_linear" + +# その他の設定 +LATENT_CHANNELS = 4 +DOWNSAMPLING_FACTOR = 8 + +CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + +# region モジュール入れ替え部 +""" +高速化のためのモジュール入れ替え +""" + + +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + logger.info("Enable memory efficient attention for U-Net") + + # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い + unet.set_use_memory_efficient_attention(False, True) + elif xformers: + logger.info("Enable xformers for U-Net") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") + + unet.set_use_memory_efficient_attention(True, False) + elif sdpa: + logger.info("Enable SDPA for U-Net") + unet.set_use_memory_efficient_attention(False, False) + unet.set_use_sdpa(True) + + +# TODO common train_util.py +def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + replace_vae_attn_to_memory_efficient() + elif xformers: + # replace_vae_attn_to_xformers() # 
解像度によってxformersがエラーを出す? + vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う + elif sdpa: + replace_vae_attn_to_sdpa() + + +def replace_vae_attn_to_memory_efficient(): + logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + flash_func = FlashAttentionFunction + + def forward_flash_attn(self, hidden_states, **kwargs): + q_bucket_size = 512 + k_bucket_size = 1024 + + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_flash_attn_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_flash_attn(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_flash_attn + + +def replace_vae_attn_to_xformers(): + logger.info("VAE: Attention.forward has been replaced to xformers") + import xformers.ops + + def forward_xformers(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + query_proj = query_proj.contiguous() + key_proj = key_proj.contiguous() + value_proj = value_proj.contiguous() + out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_xformers_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads 
= self.num_heads + return forward_xformers(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_xformers + + +def replace_vae_attn_to_sdpa(): + logger.info("VAE: Attention.forward has been replaced to sdpa") + + def forward_sdpa(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = torch.nn.functional.scaled_dot_product_attention( + query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + out = rearrange(out, "b n h d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_sdpa_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_sdpa(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_sdpa + + +# endregion + +# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 +# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py +# Pipelineだけ独立して使えないのと機能追加するのとでコピーして修正 + + +class PipelineLike: + def __init__( + self, + is_sdxl, + device, + vae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + tokenizers: List[CLIPTokenizer], + unet: InferSdxlUNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + clip_skip: int, + ): + super().__init__() + self.is_sdxl = is_sdxl + self.device = device + self.clip_skip = clip_skip + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + self.vae = vae + self.text_encoders = text_encoders + self.tokenizers = tokenizers + self.unet: Union[InferUNet2DConditionModel, InferSdxlUNet2DConditionModel] = unet + self.scheduler = scheduler + self.safety_checker = None + + self.clip_vision_model: CLIPVisionModelWithProjection = None + self.clip_vision_processor: CLIPImageProcessor = None + self.clip_vision_strength = 0.0 + + # Textual Inversion + self.token_replacements_list = [] + for _ in range(len(self.text_encoders)): + self.token_replacements_list.append({}) + + # ControlNet + self.control_nets: List[ControlNetInfo] = [] # only for SD 1.5 + self.control_net_lllites: List[ControlNetLLLite] = [] + self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない + + self.gradual_latent: GradualLatent = None + + # Textual Inversion + def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): + self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids + + def set_enable_control_net(self, en: bool): + self.control_net_enabled = en + + def get_token_replacer(self, tokenizer): + tokenizer_index = self.tokenizers.index(tokenizer) + token_replacements = self.token_replacements_list[tokenizer_index] + + def replace_tokens(tokens): + # print("replace_tokens", tokens, "=>", token_replacements) + if isinstance(tokens, torch.Tensor): + tokens = tokens.tolist() + + new_tokens = [] + for token in tokens: + if token in token_replacements: + replacement = token_replacements[token] + new_tokens.extend(replacement) + else: + new_tokens.append(token) + return new_tokens + + return replace_tokens + + def set_control_nets(self, ctrl_nets): + self.control_nets = ctrl_nets + + def set_control_net_lllites(self, ctrl_net_lllites): + self.control_net_lllites = ctrl_net_lllites + + def set_gradual_latent(self, gradual_latent): + if gradual_latent is None: + logger.info("gradual_latent is disabled") + self.gradual_latent = None + else: + logger.info(f"gradual_latent is enabled: {gradual_latent}") + self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + init_image: Union[torch.FloatTensor, PIL.Image.Image, 
List[PIL.Image.Image]] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + height: int = 1024, + width: int = 1024, + original_height: int = None, + original_width: int = None, + original_height_negative: int = None, + original_width_negative: int = None, + crop_top: int = 0, + crop_left: int = 0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_scale: float = None, + strength: float = 0.8, + # num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + vae_batch_size: float = None, + return_latents: bool = False, + # return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: Optional[int] = 1, + img2img_noise=None, + clip_guide_images=None, + emb_normalize_mode: str = "original", + **kwargs, + ): + # TODO support secondary prompt + num_images_per_prompt = 1 # fixed because already prompt is repeated + + if isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + regional_network = " AND " in prompt[0] + + vae_batch_size = ( + batch_size + if vae_batch_size is None + else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) + ) + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." + ) + + # get prompt text embeddings + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if not do_classifier_free_guidance and negative_scale is not None: + logger.warning(f"negative_scale is ignored if guidance scalle <= 1.0") + negative_scale = None + + # get unconditional embeddings for classifier free guidance + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
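For reference, the guidance combination these embeddings later feed into can be sketched standalone (illustrative; the actual code chunks one batched `noise_pred` rather than taking separate tensors):

```python
import torch

def combine_cfg(noise_uncond, noise_text, guidance_scale, noise_negative=None, negative_scale=None):
    # standard classifier-free guidance: push away from the unconditional prediction
    if negative_scale is None:
        return noise_uncond + guidance_scale * (noise_text - noise_uncond)
    # negative_scale variant: uncond is a real empty prompt, and the negative
    # prompt is pushed away with its own scale
    return (
        noise_uncond
        + guidance_scale * (noise_text - noise_uncond)
        - negative_scale * (noise_negative - noise_uncond)
    )

u, t, n = torch.randn(3, 4, 128, 128).unbind(0)
print(combine_cfg(u, t, guidance_scale=7.5).shape)                         # torch.Size([4, 128, 128])
print(combine_cfg(u, t, 7.5, noise_negative=n, negative_scale=5.5).shape)  # torch.Size([4, 128, 128])
```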
+ ) + + tes_text_embs = [] + tes_uncond_embs = [] + tes_real_uncond_embs = [] + + for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): + token_replacer = self.get_token_replacer(tokenizer) + + # use last text_pool, because it is from text encoder 2 + text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings( + self.is_sdxl, + tokenizer, + text_encoder, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + token_replacer=token_replacer, + device=self.device, + emb_normalize_mode=emb_normalize_mode, + **kwargs, + ) + tes_text_embs.append(text_embeddings) + tes_uncond_embs.append(uncond_embeddings) + + if negative_scale is not None: + _, real_uncond_embeddings, _ = get_weighted_text_embeddings( + self.is_sdxl, + token_replacer, + prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 + uncond_prompt=[""] * batch_size, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + token_replacer=token_replacer, + device=self.device, + emb_normalize_mode=emb_normalize_mode, + **kwargs, + ) + tes_real_uncond_embs.append(real_uncond_embeddings) + + # concat text encoder outputs + text_embeddings = tes_text_embs[0] + uncond_embeddings = tes_uncond_embs[0] + for i in range(1, len(tes_text_embs)): + text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 + if do_classifier_free_guidance: + uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 + + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) + + if self.control_net_lllites: + # ControlNetのhintにguide imageを流用する。ControlNetの場合はControlNet側で行う + if isinstance(clip_guide_images, PIL.Image.Image): + clip_guide_images = [clip_guide_images] + if isinstance(clip_guide_images[0], PIL.Image.Image): + clip_guide_images = [preprocess_image(im) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images) + if isinstance(clip_guide_images, list): + clip_guide_images = torch.stack(clip_guide_images) + + clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype) + + # create size embs + if original_height is None: + original_height = height + if original_width is None: + original_width = width + if original_height_negative is None: + original_height_negative = original_height + if original_width_negative is None: + original_width_negative = original_width + if crop_top is None: + crop_top = 0 + if crop_left is None: + crop_left = 0 + if self.is_sdxl: + emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) + uc_emb1 = sdxl_train_util.get_timestep_embedding( + torch.FloatTensor([original_height_negative, original_width_negative]).unsqueeze(0), 256 + ) + emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) + emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256) + c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) + uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) + + if regional_network: + # use last pool 
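The SDXL branch here conditions on three pairs: original size, crop offsets, and target size. A self-contained sketch of the shape arithmetic follows, with a plain sinusoidal embedding standing in for `sdxl_train_util.get_timestep_embedding` (whose exact formula may differ):

```python
import math
import torch

def sinusoidal_embedding(values: torch.Tensor, dim: int = 256) -> torch.Tensor:
    # embeds each scalar in `values` (shape (1, 2)) into `dim` channels,
    # then flattens to (1, 2 * dim)
    half = dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half) / half)
    args = values.flatten().float()[:, None] * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1).reshape(1, -1)

emb1 = sinusoidal_embedding(torch.tensor([[1024.0, 1024.0]]))  # original size
emb2 = sinusoidal_embedding(torch.tensor([[0.0, 0.0]]))        # crop top/left
emb3 = sinusoidal_embedding(torch.tensor([[1024.0, 1024.0]]))  # target size
c_vector = torch.cat([emb1, emb2, emb3], dim=1)
print(c_vector.shape)  # torch.Size([1, 1536]); the 1280-dim pooled text
                       # embedding is prepended later, 2816 channels in total
```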
for conditioning + num_sub_prompts = len(text_pool) // batch_size + text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt + + if init_image is not None and self.clip_vision_model is not None: + logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") + vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device) + pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype) + + clip_vision_embeddings = self.clip_vision_model( + pixel_values=pixel_values, output_hidden_states=True, return_dict=True + ) + clip_vision_embeddings = clip_vision_embeddings.image_embeds + + if len(clip_vision_embeddings) == 1 and batch_size > 1: + clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1)) + + clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength + assert clip_vision_embeddings.shape == text_pool.shape, f"{clip_vision_embeddings.shape} != {text_pool.shape}" + text_pool = clip_vision_embeddings # replace: same as ComfyUI (?) + + c_vector = torch.cat([text_pool, c_vector], dim=1) + if do_classifier_free_guidance: + uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) + vector_embeddings = torch.cat([uc_vector, c_vector]) + else: + vector_embeddings = c_vector + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, self.device) + + latents_dtype = text_embeddings.dtype + init_latents_orig = None + mask = None + + if init_image is None: + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. 
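The latent initialization that follows boils down to this seeding pattern (a simplified sketch with illustrative names): on `mps`, noise is drawn on CPU and moved over, as the comment above notes; elsewhere it is drawn directly on the target device for per-seed reproducibility.

```python
import torch

def make_initial_latents(batch_size, in_channels, height, width, seed, device):
    shape = (batch_size, in_channels, height // 8, width // 8)
    if device.type == "mps":
        # randn with a generator is unavailable on mps; draw on CPU and move
        g = torch.Generator(device="cpu").manual_seed(seed)
        return torch.randn(shape, generator=g, device="cpu").to(device)
    g = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(shape, generator=g, device=device)

latents = make_initial_latents(1, 4, 1024, 1024, seed=42, device=torch.device("cpu"))
print(latents.shape)  # torch.Size([1, 4, 128, 128]); scaled by init_noise_sigma next
```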
+ latents_shape = ( + batch_size * num_images_per_prompt, + self.unet.in_channels, + height // 8, + width // 8, + ) + + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn( + latents_shape, + generator=generator, + device="cpu", + dtype=latents_dtype, + ).to(self.device) + else: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.device, + dtype=latents_dtype, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + timesteps = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + else: + # image to tensor + if isinstance(init_image, PIL.Image.Image): + init_image = [init_image] + if isinstance(init_image[0], PIL.Image.Image): + init_image = [preprocess_image(im) for im in init_image] + init_image = torch.cat(init_image) + if isinstance(init_image, list): + init_image = torch.stack(init_image) + + # mask image to tensor + if mask_image is not None: + if isinstance(mask_image, PIL.Image.Image): + mask_image = [mask_image] + if isinstance(mask_image[0], PIL.Image.Image): + mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint + + # encode the init image into latents and scale the latents + init_image = init_image.to(device=self.device, dtype=latents_dtype) + if init_image.size()[-2:] == (height // 8, width // 8): + init_latents = init_image + else: + if vae_batch_size >= batch_size: + init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + init_latents = [] + for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): + init_latent_dist = self.vae.encode( + (init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)).to( + self.vae.dtype + ) + ).latent_dist + init_latents.append(init_latent_dist.sample(generator=generator)) + init_latents = torch.cat(init_latents) + + init_latents = (sdxl_model_util.VAE_SCALE_FACTOR if self.is_sdxl else 0.18215) * init_latents + + if len(init_latents) == 1: + init_latents = init_latents.repeat((batch_size, 1, 1, 1)) + init_latents_orig = init_latents + + # preprocess mask + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=latents_dtype) + if len(mask) == 1: + mask = mask.repeat((batch_size, 1, 1, 1)) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only 
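A worked example of the img2img timestep arithmetic in this block: `strength` decides what fraction of the schedule is actually denoised.

```python
num_inference_steps = 50
strength = 0.8
offset = 1  # scheduler.config steps_offset

init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)  # 41
t_start = max(num_inference_steps - init_timestep + offset, 0)                          # 10
print(init_timestep, t_start, num_inference_steps - t_start)  # 41 10 40 -> 40 denoising steps run
```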
used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 + + if self.control_nets: + guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) + + if self.control_net_lllites: + # guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + if self.control_net_enabled: + for control_net, _ in self.control_net_lllites: + with torch.no_grad(): + control_net.set_cond_image(clip_guide_images) + else: + for control_net, _ in self.control_net_lllites: + control_net.set_cond_image(None) + + each_control_net_enabled = [self.control_net_enabled] * len(self.control_net_lllites) + + enable_gradual_latent = False + if self.gradual_latent: + if not hasattr(self.scheduler, "set_gradual_latent_params"): + logger.warning("gradual_latent is not supported for this scheduler. Ignoring.") + logger.warning(f"{self.scheduler.__class__.__name__}") + else: + enable_gradual_latent = True + step_elapsed = 1000 + current_ratio = self.gradual_latent.ratio + + # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする + height, width = latents.shape[-2:] + org_dtype = latents.dtype + if org_dtype == torch.bfloat16: + latents = latents.float() + latents = torch.nn.functional.interpolate( + latents, scale_factor=current_ratio, mode="bicubic", align_corners=False + ).to(org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if self.gradual_latent.gaussian_blur_ksize: + latents = self.gradual_latent.apply_unshark_mask(latents) + + for i, t in enumerate(tqdm(timesteps)): + resized_size = None + if enable_gradual_latent: + # gradually upscale the latents / latentsを徐々にアップスケールする + if ( + t < self.gradual_latent.start_timesteps + and current_ratio < 1.0 + and step_elapsed >= self.gradual_latent.every_n_steps + ): + current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0) + # make divisible by 8 because size of latents must be divisible at bottom of UNet + h = int(height * current_ratio) // 8 * 8 + w = int(width * current_ratio) // 8 * 8 + resized_size = (h, w) + self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent) + step_elapsed = 0 + else: + self.scheduler.set_gradual_latent_params(None, None) + step_elapsed += 1 + + # expand the latents if we are doing classifier free guidance + latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # disable ControlNet-LLLite if ratio is set. 
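To make the gradual-latent schedule concrete, here is an illustrative trace with made-up parameter values: the latent starts at a reduced ratio and is upscaled by `ratio_step` every `every_n_steps` steps, each size snapped down to a multiple of 8.

```python
height, width = 128, 128  # latent size for a 1024x1024 image
current, ratio_step, every_n_steps = 0.5, 0.125, 3

for step in range(0, 15, every_n_steps):
    h = int(height * current) // 8 * 8
    w = int(width * current) // 8 * 8
    print(f"step {step}: latent {h}x{w} (ratio {current:.3f})")
    current = min(current + ratio_step, 1.0)
# step 0: 64x64, step 3: 80x80, step 6: 96x96, step 9: 112x112, step 12: 128x128
```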
ControlNet is disabled in ControlNetInfo + if self.control_net_lllites: + for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_net_lllites, each_control_net_enabled)): + if not enabled or ratio >= 1.0: + continue + if ratio < i / len(timesteps): + logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + control_net.set_cond_image(None) + each_control_net_enabled[j] = False + + # predict the noise residual + if self.control_nets and self.control_net_enabled: + if regional_network: + num_sub_and_neg_prompts = len(text_embeddings) // batch_size + text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt + else: + text_emb_last = text_embeddings + + noise_pred = original_control_net.call_unet_and_control_net( + i, + num_latent_input, + self.unet, + self.control_nets, + guided_hints, + i / len(timesteps), + latent_model_input, + t, + text_embeddings, + text_emb_last, + ).sample + elif self.is_sdxl: + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) + else: + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + if negative_scale is None: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( + num_latent_input + ) # uncond is real uncond + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + - negative_scale * (noise_pred_negative - noise_pred_uncond) + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + if return_latents: + return latents + + latents = 1 / (sdxl_model_util.VAE_SCALE_FACTOR if self.is_sdxl else 0.18215) * latents + if vae_batch_size >= batch_size: + image = self.vae.decode(latents.to(self.vae.dtype)).sample + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + images = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + images.append( + self.vae.decode( + (latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).to(self.vae.dtype) + ).sample + ) + image = torch.cat(images) + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + if output_type == "pil": + # image = self.numpy_to_pil(image) + image = (image * 255).round().astype("uint8") + image = [Image.fromarray(im) for im in image] + + return image + + # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + 
re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + # keep break as separate token + text = text.replace("BREAK", "\\BREAK\\") + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK": + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + No padding, starting or ending token is included. 
+ """ + tokens = [] + weights = [] + truncated = False + + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + if word.strip() == "BREAK": + # pad until next multiple of tokenizer's max token length + pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) + logger.info(f"BREAK pad_len: {pad_len}") + for i in range(pad_len): + # v2のときEOSをつけるべきかどうかわからないぜ + # if i == 0: + # text_token.append(tokenizer.eos_token_id) + # else: + text_token.append(tokenizer.pad_token_id) + text_weight.append(1.0) + continue + + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + + token = token_replacer(token) # for Textual Inversion + + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + logger.warning("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + is_sdxl: bool, + text_encoder: CLIPTextModel, + text_input: torch.Tensor, + chunk_length: int, + clip_skip: int, + eos: int, + pad: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. 
+ """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + pool = None + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos + + # in sdxl, value of clip_skip is same for Text Encoder 1 and 2 + enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) + text_embedding = enc_out["hidden_states"][-clip_skip] + if not is_sdxl: # SD 1.5 requires final_layer_norm + text_embedding = text_encoder.text_model.final_layer_norm(text_embedding) + if pool is None: + pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided + if pool is not None: + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos) + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-clip_skip] + if not is_sdxl: # SD 1.5 requires final_layer_norm + text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings) + pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this + if pool is not None: + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input, eos) + return text_embeddings, pool + + +def get_weighted_text_embeddings( + is_sdxl: bool, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + prompt: Union[str, List[str]], + uncond_prompt: Optional[Union[str, List[str]]] = None, + max_embeddings_multiples: Optional[int] = 1, + no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, + clip_skip: int = 1, + token_replacer=None, + device=None, + emb_normalize_mode: Optional[str] = "original", # "original", "abs", "none" + **kwargs, +): + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + # split the prompts with "AND". 
each prompt must have the same number of splits + new_prompts = [] + for p in prompt: + new_prompts.extend(p.split(" AND ")) + prompt = new_prompts + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, token_replacer, prompt, max_length - 2) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(tokenizer, token_replacer, uncond_prompt, max_length - 2) + else: + prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [token[1:-1] for token in tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = tokenizer.bos_token_id + eos = tokenizer.eos_token_id + pad = tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device) + + # get the embeddings + text_embeddings, text_pool = get_unweighted_text_embeddings( + is_sdxl, + text_encoder, + prompt_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) + if uncond_prompt is not None: + uncond_embeddings, uncond_pool = get_unweighted_text_embeddings( + is_sdxl, + text_encoder, + uncond_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=device) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? 
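The weighting-and-normalization branches that follow share one idea; this minimal sketch shows the "original" mode, which rescales the weighted embedding so its overall mean is unchanged:

```python
import torch

emb = torch.randn(1, 77, 768)  # stand-in text embedding
weights = torch.ones(1, 77)
weights[0, 5:10] = 1.1         # an (emphasized) span

previous_mean = emb.float().mean(dim=[-2, -1])
emb = emb * weights.unsqueeze(-1)                    # apply token weights
current_mean = emb.float().mean(dim=[-2, -1])
emb = emb * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
print(torch.allclose(emb.float().mean(dim=[-2, -1]), previous_mean, atol=1e-5))  # True
```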
+    # → normalizing over the whole embedding (rather than per chunk) should be fine

+    if (not skip_parsing) and (not skip_weighting):
+        if emb_normalize_mode == "abs":
+            previous_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+            text_embeddings *= prompt_weights.unsqueeze(-1)
+            current_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+            text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+            if uncond_prompt is not None:
+                previous_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+                uncond_embeddings *= uncond_weights.unsqueeze(-1)
+                current_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+                uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+
+        elif emb_normalize_mode == "none":
+            text_embeddings *= prompt_weights.unsqueeze(-1)
+            if uncond_prompt is not None:
+                uncond_embeddings *= uncond_weights.unsqueeze(-1)
+
+        else:  # "original"
+            previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+            text_embeddings *= prompt_weights.unsqueeze(-1)
+            current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+            text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+            if uncond_prompt is not None:
+                previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+                uncond_embeddings *= uncond_weights.unsqueeze(-1)
+                current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+                uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+
+    if uncond_prompt is not None:
+        return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens
+    return text_embeddings, text_pool, None, None, prompt_tokens
+
+
+def preprocess_image(image):
+    w, h = image.size
+    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+    return 2.0 * image - 1.0
+
+
+def preprocess_mask(mask):
+    mask = mask.convert("L")
+    w, h = mask.size
+    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR)  # BILINEAR instead of LANCZOS
+    mask = np.array(mask).astype(np.float32) / 255.0
+    mask = np.tile(mask, (4, 1, 1))
+    mask = mask[None].transpose(0, 1, 2, 3)  # add the batch axis; the transpose is an identity permutation
+    mask = 1 - mask  # repaint white, keep black
+    mask = torch.from_numpy(mask)
+    return mask
+
+
+# regular expression for dynamic prompt:
+# starts and ends with "{" and "}"
+# contains at least one variant divided by "|"
+# optional fragments divided by "$$" at start
+# if the second fragment is a number or two numbers, pick that many variants (or a count from that range)
+# if the first fragment is "E" or "e", enumerate all variants
+# if the third fragment is a string, use it as a separator
+
+RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}")
+
+
+def handle_dynamic_prompt_variants(prompt, repeat_count):
+    founds = list(RE_DYNAMIC_PROMPT.finditer(prompt))
+    if not founds:
+        return [prompt]
+
+    # make each replacement for each variant
+    enumerating = False
+    replacers = []
+    for found in founds:
+        # if "e$$" is found, enumerate all variants
+        found_enumerating = found.group(2) is not None
+        enumerating = enumerating or found_enumerating
+
+        separator = ", " if found.group(6) is None else found.group(6)
+        variants = found.group(7).split("|")
+
+        # parse count range
+        count_range = found.group(4)
+        if count_range is None:
+            count_range = [1, 1]
+        else:
+            count_range = count_range.split("-")
+            if len(count_range) == 1:
+                count_range = [int(count_range[0]), int(count_range[0])]
+            elif len(count_range) == 2:
+                count_range = [int(count_range[0]), int(count_range[1])]
+            else:
+                logger.warning(f"invalid count range: {count_range}")
+                count_range = [1, 1]
+        if count_range[0] > count_range[1]:
+            count_range = [count_range[1], count_range[0]]
+        if count_range[0] < 0:
+            count_range[0] = 0
+        if count_range[1] > len(variants):
+            count_range[1] = len(variants)
+
+        if found_enumerating:
+            # make function to enumerate all combinations
+            def make_replacer_enum(vari, cr, sep):
+                def replacer():
+                    values = []
+                    for count in range(cr[0], cr[1] + 1):
+                        for comb in itertools.combinations(vari, count):
+                            values.append(sep.join(comb))
+                    return values
+
+                return replacer
+
+            replacers.append(make_replacer_enum(variants, count_range, separator))
+        else:
+            # make function to choose random combinations
+            def make_replacer_single(vari, cr, sep):
+                def replacer():
+                    count = random.randint(cr[0], cr[1])
+                    comb = random.sample(vari, count)
+                    return [sep.join(comb)]
+
+                return replacer
+
+            replacers.append(make_replacer_single(variants, count_range, separator))
+
+    # make each prompt
+    if not enumerating:
+        # if not enumerating, repeat the prompt, replace each variant randomly
+        prompts = []
+        for _ in range(repeat_count):
+            current = prompt
+            for found, replacer in zip(founds, replacers):
+                current = current.replace(found.group(0), replacer()[0], 1)
+            prompts.append(current)
+    else:
+        # if enumerating, iterate all combinations for previous prompts
+        prompts = [prompt]
+
+        for found, replacer in zip(founds, replacers):
+            if found.group(2) is not None:
+                # make all combinations for existing prompts
+                new_prompts = []
+                for current in prompts:
+                    replacements = replacer()
+                    for replacement in replacements:
+                        new_prompts.append(current.replace(found.group(0), replacement, 1))
+                prompts = new_prompts
+
+        for found, replacer in zip(founds, replacers):
+            # make random selection for existing prompts
+            if found.group(2) is None:
+                for i in range(len(prompts)):
+                    prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1)
+
+    return prompts
+
+
+# endregion
+
+# def load_clip_l14_336(dtype):
+#     print(f"loading CLIP: {CLIP_ID_L14_336}")
+#     text_encoder = 
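Some illustrative expansions of the dynamic prompt grammar above, `{[e$$][count$$][separator$$]variant|variant|...}` (the random forms vary run to run):

```python
import itertools
import random

variants = ["red", "blue", "green"]

# "{red|blue|green} car"    -> one random variant
# "{2$$red|blue|green} car" -> two random variants joined by ", "
# "{e$$1-2$$red|blue} car"  -> enumerated combinations of size 1..2
random.seed(0)
print(random.sample(variants, 2))  # non-enumerating path: random choice
print([", ".join(c) for n in (1, 2) for c in itertools.combinations(["red", "blue"], n)])
# ['red', 'blue', 'red, blue']     # enumerating (e$$) path
```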
CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) +# return text_encoder + + +class BatchDataBase(NamedTuple): + # バッチ分割が必要ないデータ + step: int + prompt: str + negative_prompt: str + seed: int + init_image: Any + mask_image: Any + clip_prompt: str + guide_image: Any + raw_prompt: str + + +class BatchDataExt(NamedTuple): + # バッチ分割が必要なデータ + width: int + height: int + original_width: int + original_height: int + original_width_negative: int + original_height_negative: int + crop_left: int + crop_top: int + steps: int + scale: float + negative_scale: float + strength: float + network_muls: Tuple[float] + num_sub_prompts: int + + +class BatchData(NamedTuple): + return_latents: bool + base: BatchDataBase + ext: BatchDataExt + + +class ListPrompter: + def __init__(self, prompts: List[str]): + self.prompts = prompts + self.index = 0 + + def shuffle(self): + random.shuffle(self.prompts) + + def __len__(self): + return len(self.prompts) + + def __call__(self, *args, **kwargs): + if self.index >= len(self.prompts): + self.index = 0 # reset + return None + + prompt = self.prompts[self.index] + self.index += 1 + return prompt + + +def main(args): + if args.fp16: + dtype = torch.float16 + elif args.bf16: + dtype = torch.bfloat16 + else: + dtype = torch.float32 + + highres_fix = args.highres_fix_scale is not None + # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" + + if args.v_parameterization and not args.v2: + logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + if args.v2 and args.clip_skip is not None: + logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + + # モデルを読み込む + if not os.path.exists(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う + files = glob.glob(args.ckpt) + if len(files) == 1: + args.ckpt = files[0] + + name_or_path = os.readlink(args.ckpt) if os.path.islink(args.ckpt) else args.ckpt + use_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers + + # SDXLかどうかを判定する + is_sdxl = args.sdxl + if not is_sdxl and not args.v1 and not args.v2: # どれも指定されていない場合は自動で判定する + if use_stable_diffusion_format: + # if file size > 5.5GB, sdxl + is_sdxl = os.path.getsize(name_or_path) > 5.5 * 1024**3 + else: + # if `text_encoder_2` subdirectory exists, sdxl + is_sdxl = os.path.isdir(os.path.join(name_or_path, "text_encoder_2")) + logger.info(f"SDXL: {is_sdxl}") + + if is_sdxl: + if args.clip_skip is None: + args.clip_skip = 2 + + (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype + ) + unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) + text_encoders = [text_encoder1, text_encoder2] + else: + if args.clip_skip is None: + args.clip_skip = 2 if args.v2 else 1 + + if use_stable_diffusion_format: + logger.info("load StableDiffusion checkpoint") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) + else: + logger.info("load Diffusers pretrained models") + loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) + text_encoder = loading_pipe.text_encoder + vae = loading_pipe.vae + unet = loading_pipe.unet + tokenizer = loading_pipe.tokenizer + del loading_pipe + + # Diffusers U-Net to original U-Net + original_unet = UNet2DConditionModel( + unet.config.sample_size, + 
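The auto-detection heuristic above, restated as a standalone helper with a hypothetical name: single-file checkpoints are classified by size, Diffusers directories by the presence of `text_encoder_2`.

```python
import os

def detect_sdxl(name_or_path: str) -> bool:
    if os.path.isfile(name_or_path):
        # single-file Stable Diffusion format: SDXL checkpoints exceed ~5.5 GB
        return os.path.getsize(name_or_path) > 5.5 * 1024**3
    # Diffusers format: SDXL ships a second text encoder subdirectory
    return os.path.isdir(os.path.join(name_or_path, "text_encoder_2"))
```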
unet.config.attention_head_dim, + unet.config.cross_attention_dim, + unet.config.use_linear_projection, + unet.config.upcast_attention, + ) + original_unet.load_state_dict(unet.state_dict()) + unet = original_unet + unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet) + text_encoders = [text_encoder] + + # VAEを読み込む + if args.vae is not None: + vae = model_util.load_vae(args.vae, dtype) + logger.info("additional VAE loaded") + + # xformers、Hypernetwork対応 + if not args.diffusers_xformers: + mem_eff = not (args.xformers or args.sdpa) + replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa) + replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) + + # tokenizerを読み込む + logger.info("loading tokenizer") + if is_sdxl: + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenizers = [tokenizer1, tokenizer2] + else: + if use_stable_diffusion_format: + tokenizer = train_util.load_tokenizer(args) + tokenizers = [tokenizer] + + # schedulerを用意する + sched_init_args = {} + has_steps_offset = True + has_clip_sample = True + scheduler_num_noises_per_step = 1 + + if args.sampler == "ddim": + scheduler_cls = DDIMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddim + elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddpm + elif args.sampler == "pndm": + scheduler_cls = PNDMScheduler + scheduler_module = diffusers.schedulers.scheduling_pndm + has_clip_sample = False + elif args.sampler == "lms" or args.sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_lms_discrete + has_clip_sample = False + elif args.sampler == "euler" or args.sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_discrete + has_clip_sample = False + elif args.sampler == "euler_a" or args.sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteSchedulerGL + scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete + has_clip_sample = False + elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = args.sampler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep + has_clip_sample = False + elif args.sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep + has_clip_sample = False + has_steps_offset = False + elif args.sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_heun_discrete + has_clip_sample = False + elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete + has_clip_sample = False + elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete + scheduler_num_noises_per_step = 2 + has_clip_sample = False + + if args.v_parameterization: + sched_init_args["prediction_type"] = "v_prediction" + + # 警告を出さないようにする + if has_steps_offset: + sched_init_args["steps_offset"] = 1 + if has_clip_sample: + sched_init_args["clip_sample"] = False + + # samplerの乱数をあらかじめ指定するための処理 + + # replace randn + class NoiseManager: + def __init__(self): + 
self.sampler_noises = None
+            self.sampler_noise_index = 0
+
+        def reset_sampler_noises(self, noises):
+            self.sampler_noise_index = 0
+            self.sampler_noises = noises
+
+        def randn(self, shape, device=None, dtype=None, layout=None, generator=None):
+            # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index)
+            if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises):
+                noise = self.sampler_noises[self.sampler_noise_index]
+                if shape != noise.shape:
+                    noise = None
+            else:
+                noise = None
+
+            if noise is None:
+                logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
+                noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)
+
+            self.sampler_noise_index += 1
+            return noise
+
+    class TorchRandReplacer:
+        def __init__(self, noise_manager):
+            self.noise_manager = noise_manager
+
+        def __getattr__(self, item):
+            if item == "randn":
+                return self.noise_manager.randn
+            if hasattr(torch, item):
+                return getattr(torch, item)
+            raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+
+    noise_manager = NoiseManager()
+    if scheduler_module is not None:
+        scheduler_module.torch = TorchRandReplacer(noise_manager)
+
+    scheduler = scheduler_cls(
+        num_train_timesteps=SCHEDULER_TIMESTEPS,
+        beta_start=SCHEDULER_LINEAR_START,
+        beta_end=SCHEDULER_LINEAR_END,
+        beta_schedule=SCHEDLER_SCHEDULE,
+        **sched_init_args,
+    )
+
+    # ↓ the following had no effect after all, because the pipeline sets it back to False
+    # # set clip_sample to True
+    # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
+    #     print("set clip_sample to True")
+    #     scheduler.config.clip_sample = True
+
+    # decide the device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # "mps" is not considered
+
+    # build the pipeline (copied and modified from the custom pipeline)
+    if args.vae_slices:
+        from library.slicing_vae import SlicingAutoencoderKL
+
+        sli_vae = SlicingAutoencoderKL(
+            act_fn="silu",
+            block_out_channels=(128, 256, 512, 512),
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
+            in_channels=3,
+            latent_channels=4,
+            layers_per_block=2,
+            norm_num_groups=32,
+            out_channels=3,
+            sample_size=512,
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
+            num_slices=args.vae_slices,
+        )
+        sli_vae.load_state_dict(vae.state_dict())  # copy the VAE parameters
+        vae = sli_vae
+        del sli_vae
+
+    vae_dtype = dtype
+    if args.no_half_vae:
+        logger.info("set vae_dtype to float32")
+        vae_dtype = torch.float32
+    vae.to(vae_dtype).to(device)
+    vae.eval()
+
+    for text_encoder in text_encoders:
+        text_encoder.to(dtype).to(device)
+        text_encoder.eval()
+    unet.to(dtype).to(device)
+    unet.eval()
+
+    # incorporate the networks
+    if args.network_module:
+        networks = []
+        network_default_muls = []
+        network_pre_calc = args.network_pre_calc
+
+        # consolidate the merge-related arguments
+        if args.network_merge:
+            network_merge = len(args.network_module)  # all networks are merged
+        elif args.network_merge_n_models:
+            network_merge = args.network_merge_n_models
+        else:
+            network_merge = 0
+        logger.info(f"network_merge: {network_merge}")
+
+        for i, network_module in enumerate(args.network_module):
+            logger.info(f"import network module: {network_module}")
+            imported_module = importlib.import_module(network_module)
+
+            network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
+
+            net_kwargs = {}
+            if args.network_args and i < len(args.network_args):
+                network_args = args.network_args[i]
+                # TODO escape special chars
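The `--network_args` strings parsed here use a plain `key=value;key=value` format; an equivalent standalone sketch (no escaping yet, per the TODO):

```python
def parse_network_args(arg_string: str) -> dict:
    net_kwargs = {}
    for item in arg_string.split(";"):
        key, value = item.split("=")
        net_kwargs[key] = value
    return net_kwargs

print(parse_network_args("conv_dim=4;conv_alpha=1"))  # {'conv_dim': '4', 'conv_alpha': '1'}
```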
network_args = network_args.split(";") + for net_arg in network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value + + if args.network_weights is None or len(args.network_weights) <= i: + raise ValueError("No weight. Weight is required.") + + network_weight = args.network_weights[i] + logger.info(f"load network weights from: {network_weight}") + + if model_util.is_safetensors(network_weight) and args.network_show_meta: + from safetensors.torch import safe_open + + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + logger.info(f"metadata for: {network_weight}: {metadata}") + + network, weights_sd = imported_module.create_network_from_weights( + network_mul, network_weight, vae, text_encoders, unet, for_inference=True, **net_kwargs + ) + if network is None: + return + + mergeable = network.is_mergeable() + if network_merge and not mergeable: + logger.warning("network is not mergeable. ignore merge option.") + + if not mergeable or i >= network_merge: + # not merging + network.apply_to(text_encoders, unet) + info = network.load_state_dict(weights_sd, False) # ideally this would use network.load_weights + logger.info(f"weights are loaded: {info}") + + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + + if network_pre_calc: + logger.info("backup original weights") + network.backup_weights() + + networks.append(network) + network_default_muls.append(network_mul) + else: + network.merge_to(text_encoders, unet, weights_sd, dtype, device) + + else: + networks = [] + + # load the upscaler if one is specified + upscaler = None + if args.highres_fix_upscaler: + logger.info(f"import upscaler module: {args.highres_fix_upscaler}") + imported_module = importlib.import_module(args.highres_fix_upscaler) + + us_kwargs = {} + if args.highres_fix_upscaler_args: + for net_arg in args.highres_fix_upscaler_args.split(";"): + key, value = net_arg.split("=") + us_kwargs[key] = value + + logger.info("create upscaler") + upscaler = imported_module.create_upscaler(**us_kwargs) + upscaler.to(dtype).to(device) + + # ControlNet handling + control_nets: List[ControlNetInfo] = [] + if args.control_net_models: + for i, model in enumerate(args.control_net_models): + prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) + prep = original_control_net.load_preprocess(prep_type) + control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + + control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] + if args.control_net_lllite_models: + for i, model_file in enumerate(args.control_net_lllite_models): + logger.info(f"loading ControlNet-LLLite: {model_file}") + + from safetensors.torch import load_file + + state_dict = load_file(model_file) + mlp_dim = None + cond_emb_dim = None + for key, value in state_dict.items(): + if mlp_dim is None and "down.0.weight" in key: + mlp_dim = value.shape[0] + elif cond_emb_dim is None and "conditioning1.0" in key: + cond_emb_dim = value.shape[0] * 2 + if mlp_dim is not None and cond_emb_dim is not None: + break + assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}"
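For reference, the shape sniffing above derives the two ControlNet-LLLite hyperparameters from the checkpoint itself rather than from CLI flags. A minimal sketch of that inference; the key names and tensor shapes here are hypothetical, and only the `down.0.weight` / `conditioning1.0` substrings and the `* 2` on the conditioning dim mirror the logic above:

```python
# Minimal sketch of the LLLite dimension inference above (hypothetical keys/shapes).
import torch

state_dict = {
    "lllite.block0.down.0.weight": torch.zeros(128, 320),              # hypothetical key
    "lllite.block0.conditioning1.0.weight": torch.zeros(16, 3, 3, 3),  # hypothetical key
}
mlp_dim = next(v.shape[0] for k, v in state_dict.items() if "down.0.weight" in k)
cond_emb_dim = next(v.shape[0] * 2 for k, v in state_dict.items() if "conditioning1.0" in k)
assert (mlp_dim, cond_emb_dim) == (128, 32)
```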
multiplier = ( + 1.0 + if not args.control_net_multipliers or len(args.control_net_multipliers) <= i + else args.control_net_multipliers[i] + ) + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + control_net_lllite = ControlNetLLLite(unet, cond_emb_dim, mlp_dim, multiplier=multiplier) + control_net_lllite.apply_to() + control_net_lllite.load_state_dict(state_dict) + control_net_lllite.to(dtype).to(device) + control_net_lllite.set_batch_cond_only(False, False) + control_net_lllites.append((control_net_lllite, ratio)) + assert ( + len(control_nets) == 0 or len(control_net_lllites) == 0 + ), "ControlNet and ControlNet-LLLite cannot be used at the same time" + + if args.opt_channels_last: + logger.info("set optimizing: channels last") + for text_encoder in text_encoders: + text_encoder.to(memory_format=torch.channels_last) + vae.to(memory_format=torch.channels_last) + unet.to(memory_format=torch.channels_last) + if networks: + for network in networks: + network.to(memory_format=torch.channels_last) + + for cn in control_nets: + cn.to(memory_format=torch.channels_last) + + for cn, _ in control_net_lllites: # each entry is a (module, ratio) tuple + cn.to(memory_format=torch.channels_last) + + pipe = PipelineLike( + is_sdxl, + device, + vae, + text_encoders, + tokenizers, + unet, + scheduler, + args.clip_skip, + ) + pipe.set_control_nets(control_nets) + pipe.set_control_net_lllites(control_net_lllites) + logger.info("pipeline is ready.") + + if args.diffusers_xformers: + pipe.enable_xformers_memory_efficient_attention() + + # Deep Shrink + if args.ds_depth_1 is not None: + unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + + # Gradual Latent + if args.gradual_latent_timesteps is not None: + if args.gradual_latent_unsharp_params: + us_params = args.gradual_latent_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] + us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + + gradual_latent = GradualLatent( + args.gradual_latent_ratio, + args.gradual_latent_timesteps, + args.gradual_latent_every_n_steps, + args.gradual_latent_ratio_step, + args.gradual_latent_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + + # handle Textual Inversion + if args.textual_inversion_embeddings: + token_ids_embeds1 = [] + token_ids_embeds2 = [] + for embeds_file in args.textual_inversion_embeddings: + if model_util.is_safetensors(embeds_file): + from safetensors.torch import load_file + + data = load_file(embeds_file) + else: + data = torch.load(embeds_file, map_location="cpu") + + if "string_to_param" in data: + data = data["string_to_param"] + if is_sdxl: + embeds1 = data["clip_l"] # text encoder 1 + embeds2 = data["clip_g"] # text encoder 2 + else: + embeds1 = next(iter(data.values())) + embeds2 = None + + num_vectors_per_token = embeds1.size()[0] + token_string = os.path.splitext(os.path.basename(embeds_file))[0] + + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # add new word to tokenizer, count is num_vectors_per_token + num_added_tokens1 = tokenizers[0].add_tokens(token_strings) + num_added_tokens2 = tokenizers[1].add_tokens(token_strings) if is_sdxl else 0 + assert num_added_tokens1 == num_vectors_per_token and (
num_added_tokens2 == 0 or num_added_tokens2 == num_vectors_per_token + ), ( + f"tokenizer has same word to token string (filename): {embeds_file}" + + f" / 指定した名前(ファイル名)のトークンが既に存在します: {embeds_file}" + ) + + token_ids1 = tokenizers[0].convert_tokens_to_ids(token_strings) + token_ids2 = tokenizers[1].convert_tokens_to_ids(token_strings) if is_sdxl else None + logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") + assert ( + min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 + ), f"token ids1 is not ordered" + assert not is_sdxl or ( + min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1 + ), f"token ids2 is not ordered" + assert len(tokenizers[0]) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizers[0])}" + assert ( + not is_sdxl or len(tokenizers[1]) - 1 == token_ids2[-1] + ), f"token ids 2 is not end of tokenize: {len(tokenizers[1])}" + + if num_vectors_per_token > 1: + pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ... + if is_sdxl: + pipe.add_token_replacement(1, token_ids2[0], token_ids2) + + token_ids_embeds1.append((token_ids1, embeds1)) + if is_sdxl: + token_ids_embeds2.append((token_ids2, embeds2)) + + text_encoders[0].resize_token_embeddings(len(tokenizers[0])) + token_embeds1 = text_encoders[0].get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds1: + for token_id, embed in zip(token_ids, embeds): + token_embeds1[token_id] = embed + + if is_sdxl: + text_encoders[1].resize_token_embeddings(len(tokenizers[1])) + token_embeds2 = text_encoders[1].get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds2: + for token_id, embed in zip(token_ids, embeds): + token_embeds2[token_id] = embed + + # promptを取得する + prompt_list = None + if args.from_file is not None: + logger.info(f"reading prompts from {args.from_file}") + with open(args.from_file, "r", encoding="utf-8") as f: + prompt_list = f.read().splitlines() + prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] + prompter = ListPrompter(prompt_list) + + elif args.from_module is not None: + + def load_module_from_path(module_name, file_path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Module '{module_name}' cannot be loaded from '{file_path}'") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + logger.info(f"reading prompts from module: {args.from_module}") + prompt_module = load_module_from_path("prompt_module", args.from_module) + + prompter = prompt_module.get_prompter(args, pipe, networks) + + elif args.prompt is not None: + prompter = ListPrompter([args.prompt]) + + else: + prompter = None # interactive mode + + if args.interactive: + args.n_iter = 1 + + # img2imgの前処理、画像の読み込みなど + def load_images(path): + if os.path.isfile(path): + paths = [path] + else: + paths = ( + glob.glob(os.path.join(path, "*.png")) + + glob.glob(os.path.join(path, "*.jpg")) + + glob.glob(os.path.join(path, "*.jpeg")) + + glob.glob(os.path.join(path, "*.webp")) + ) + paths.sort() + + images = [] + for p in paths: + image = Image.open(p) + if image.mode != "RGB": + logger.info(f"convert image to RGB from {image.mode}: {p}") + image = image.convert("RGB") + images.append(image) + + return images + + def resize_images(imgs, size): + 
resized = [] + for img in imgs: + r_img = img.resize(size, Image.Resampling.LANCZOS) + if hasattr(img, "filename"): # the filename attribute is apparently missing in some cases + r_img.filename = img.filename + resized.append(r_img) + return resized + + if args.image_path is not None: + logger.info(f"load image for img2img: {args.image_path}") + init_images = load_images(args.image_path) + assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" + logger.info(f"loaded {len(init_images)} images for img2img") + + # CLIP Vision + if args.clip_vision_strength is not None: + logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}") + vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) + vision_model.to(device, dtype) + processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) + + pipe.clip_vision_model = vision_model + pipe.clip_vision_processor = processor + pipe.clip_vision_strength = args.clip_vision_strength + logger.info("CLIP Vision model loaded.") + + else: + init_images = None + + if args.mask_path is not None: + logger.info(f"load mask for inpainting: {args.mask_path}") + mask_images = load_images(args.mask_path) + assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.mask_path}" + logger.info(f"loaded {len(mask_images)} mask images for inpainting") + else: + mask_images = None + + # when no prompt is given, read prompts from the images' PngInfo + if init_images is not None and prompter is None and not args.interactive: + logger.info("get prompts from images' metadata") + prompt_list = [] + for img in init_images: + if "prompt" in img.text: + prompt = img.text["prompt"] + if "negative-prompt" in img.text: + prompt += " --n " + img.text["negative-prompt"] + prompt_list.append(prompt) + prompter = ListPrompter(prompt_list) + + # repeat each image images_per_prompt times so prompts and images stay aligned + l = [] + for im in init_images: + l.extend([im] * args.images_per_prompt) + init_images = l + + if mask_images is not None: + l = [] + for im in mask_images: + l.extend([im] * args.images_per_prompt) + mask_images = l + + # resize when an image size is specified on the command line + if args.W is not None and args.H is not None: + # take highres fix into account + w, h = args.W, args.H + if highres_fix: + w = int(w * args.highres_fix_scale + 0.5) + h = int(h * args.highres_fix_scale + 0.5) + + if init_images is not None: + logger.info(f"resize img2img source images to {w}*{h}") + init_images = resize_images(init_images, (w, h)) + if mask_images is not None: + logger.info(f"resize img2img mask images to {w}*{h}") + mask_images = resize_images(mask_images, (w, h)) + + regional_network = False + if networks and mask_images: + # reuse the mask as region information; currently only one mask per command invocation is supported + regional_network = True + logger.info("use mask as region") + + size = None + for i, network in enumerate(networks): + if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes: + np_mask = np.array(mask_images[0]) + + if args.network_regional_mask_max_color_codes: + # select the mask region by color code + ch0 = (i + 1) & 1 + ch1 = ((i + 1) >> 1) & 1 + ch2 = ((i + 1) >> 2) & 1 + np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2) + np_mask = np_mask.astype(np.uint8) * 255 + else: + np_mask = np_mask[:, :, i] + size = np_mask.shape + else: + np_mask = np.full(size, 255, dtype=np.uint8) + mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0) + network.set_region(i, i == len(networks) - 1, mask) + mask_images = None + + prev_image = None # for VGG16 guided
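Reading the bit arithmetic above, the color-code mode appears to select the region for network `i` by the mask color whose R/G/B bits encode `i + 1`; the following small sketch restates that mapping (an inference from the code, not a documented spec):

```python
# Region color for network i under --network_regional_mask_max_color_codes,
# inferred from the ch0/ch1/ch2 bit arithmetic above: R/G/B bits encode i + 1.
def region_color(i: int) -> tuple:
    n = i + 1
    return ((n & 1) * 255, ((n >> 1) & 1) * 255, ((n >> 2) & 1) * 255)

assert region_color(0) == (255, 0, 0)    # network 0: pure red
assert region_color(1) == (0, 255, 0)    # network 1: pure green
assert region_color(2) == (255, 255, 0)  # network 2: yellow
```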
if args.guide_image_path is not None: + logger.info(f"load image for ControlNet guidance: {args.guide_image_path}") + guide_images = [] + for p in args.guide_image_path: + guide_images.extend(load_images(p)) + + logger.info(f"loaded {len(guide_images)} guide images for guidance") + if len(guide_images) == 0: + logger.warning( + f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.guide_image_path}" + ) + guide_images = None + else: + guide_images = None + + # create a new random number generator + if args.seed is not None: + if prompt_list and len(prompt_list) == 1 and args.images_per_prompt == 1: + # use the seed argument as-is + def fixed_seed(*_args, **_kwargs): # underscores avoid shadowing the outer args namespace + return args.seed + + seed_random = SimpleNamespace(randint=fixed_seed) + else: + seed_random = random.Random(args.seed) + else: + seed_random = random.Random() + + # set the default image size; for img2img these values are ignored (or the images were already resized to W*H) + if args.W is None: + args.W = 1024 if is_sdxl else 512 + if args.H is None: + args.H = 1024 if is_sdxl else 512 + + # image generation loop + os.makedirs(args.outdir, exist_ok=True) + max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples + + for gen_iter in range(args.n_iter): + logger.info(f"iteration {gen_iter+1}/{args.n_iter}") + if args.iter_same_seed: + iter_seed = seed_random.randint(0, 2**32 - 1) + else: + iter_seed = None + + # shuffle prompt list + if args.shuffle_prompts: + prompter.shuffle() + + # batch processing function + def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): + batch_size = len(batch) + + # highres fix handling + if highres_fix and not highres_1st: + # build the 1st-stage batch and run it at a reduced size + is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling + + logger.info("process 1st stage") + batch_1st = [] + for _, base, ext in batch: + + def scale_and_round(x): + if x is None: + return None + return int(x * args.highres_fix_scale + 0.5) + + width_1st = scale_and_round(ext.width) + height_1st = scale_and_round(ext.height) + width_1st = width_1st - width_1st % 32 + height_1st = height_1st - height_1st % 32 + + original_width_1st = scale_and_round(ext.original_width) + original_height_1st = scale_and_round(ext.original_height) + original_width_negative_1st = scale_and_round(ext.original_width_negative) + original_height_negative_1st = scale_and_round(ext.original_height_negative) + crop_left_1st = scale_and_round(ext.crop_left) + crop_top_1st = scale_and_round(ext.crop_top) + + strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength + + ext_1st = BatchDataExt( + width_1st, + height_1st, + original_width_1st, + original_height_1st, + original_width_negative_1st, + original_height_negative_1st, + crop_left_1st, + crop_top_1st, + args.highres_fix_steps, + ext.scale, + ext.negative_scale, + strength_1st, + ext.network_muls, + ext.num_sub_prompts, + ) + batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) + + pipe.set_enable_control_net(True) # enable ControlNet in the 1st stage + images_1st = process_batch(batch_1st, True, True) + + # build the 2nd-stage batch and process it below + logger.info("process 2nd stage") + width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height + + if upscaler: + # enlarge the images with the upscaler + lowreso_imgs = None if is_1st_latent else images_1st + lowreso_latents = None if not is_1st_latent else images_1st + + # the return value is either PIL.Image.Image or torch.Tensor latents + batch_size = len(images_1st) + vae_batch_size = ( + batch_size + if args.vae_batch_size is None + else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size) + )
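The `--vae_batch_size` expression just above treats values below 1.0 as a fraction of the generation batch and anything else as an absolute size; a worked restatement of that rule:

```python
# Worked restatement of the --vae_batch_size rule used above.
def effective_vae_batch(batch_size: int, vae_batch_size):
    if vae_batch_size is None:
        return batch_size                                # default: decode the whole batch at once
    if vae_batch_size < 1:
        return max(1, int(batch_size * vae_batch_size))  # fraction of the batch, at least 1
    return int(vae_batch_size)                           # absolute VAE batch size

assert effective_vae_batch(8, None) == 8
assert effective_vae_batch(8, 0.5) == 4
assert effective_vae_batch(8, 0.05) == 1
```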
+ vae_batch_size = int(vae_batch_size) + images_1st = upscaler.upscale( + vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size + ) + + elif args.highres_fix_latents_upscaling: + # latentを拡大する + org_dtype = images_1st.dtype + if images_1st.dtype == torch.bfloat16: + images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない + images_1st = torch.nn.functional.interpolate( + images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear" + ) # , antialias=True) + images_1st = images_1st.to(org_dtype) + + else: + # 画像をLANCZOSで拡大する + images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st] + + batch_2nd = [] + for i, (bd, image) in enumerate(zip(batch, images_1st)): + bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) + batch_2nd.append(bd_2nd) + batch = batch_2nd + + if args.highres_fix_disable_control_net: + pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする + + # このバッチの情報を取り出す + ( + return_latents, + (step_first, _, _, _, init_image, mask_image, _, guide_image, _), + ( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + network_muls, + num_sub_prompts, + ), + ) = batch[0] + noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) + + prompts = [] + negative_prompts = [] + raw_prompts = [] + start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + noises = [ + torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + for _ in range(steps * scheduler_num_noises_per_step) + ] + seeds = [] + clip_prompts = [] + + if init_image is not None: # img2img? + i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + init_images = [] + + if mask_image is not None: + mask_images = [] + else: + mask_images = None + else: + i2i_noises = None + init_images = None + mask_images = None + + if guide_image is not None: # CLIP image guided? 
+ guide_images = [] + else: + guide_images = None + + # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする + all_images_are_same = True + all_masks_are_same = True + all_guide_images_are_same = True + for i, ( + _, + (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + _, + ) in enumerate(batch): + prompts.append(prompt) + negative_prompts.append(negative_prompt) + seeds.append(seed) + clip_prompts.append(clip_prompt) + raw_prompts.append(raw_prompt) + + if init_image is not None: + init_images.append(init_image) + if i > 0 and all_images_are_same: + all_images_are_same = init_images[-2] is init_image + + if mask_image is not None: + mask_images.append(mask_image) + if i > 0 and all_masks_are_same: + all_masks_are_same = mask_images[-2] is mask_image + + if guide_image is not None: + if type(guide_image) is list: + guide_images.extend(guide_image) + all_guide_images_are_same = False + else: + guide_images.append(guide_image) + if i > 0 and all_guide_images_are_same: + all_guide_images_are_same = guide_images[-2] is guide_image + + # make start code + torch.manual_seed(seed) + start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + # make each noises + for j in range(steps * scheduler_num_noises_per_step): + noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) + + if i2i_noises is not None: # img2img noise + i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + noise_manager.reset_sampler_noises(noises) + + # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する + if init_images is not None and all_images_are_same: + init_images = init_images[0] + if mask_images is not None and all_masks_are_same: + mask_images = mask_images[0] + if guide_images is not None and all_guide_images_are_same: + guide_images = guide_images[0] + + # ControlNet使用時はguide imageをリサイズする + if control_nets or control_net_lllites: + # TODO resampleのメソッド + guide_images = guide_images if type(guide_images) == list else [guide_images] + guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] + if len(guide_images) == 1: + guide_images = guide_images[0] + + # generate + if networks: + # 追加ネットワークの処理 + shared = {} + for n, m in zip(networks, network_muls if network_muls else network_default_muls): + n.set_multiplier(m) + if regional_network: + # TODO バッチから ds_ratio を取り出すべき + n.set_current_generation(batch_size, num_sub_prompts, width, height, shared, unet.ds_ratio) + + if not regional_network and network_pre_calc: + for n in networks: + n.restore_weights() + for n in networks: + n.pre_calculation() + logger.info("pre-calculation... 
done") + + images = pipe( + prompts, + negative_prompts, + init_images, + mask_images, + height, + width, + original_height, + original_width, + original_height_negative, + original_width_negative, + crop_top, + crop_left, + steps, + scale, + negative_scale, + strength, + latents=start_code, + output_type="pil", + max_embeddings_multiples=max_embeddings_multiples, + img2img_noise=i2i_noises, + vae_batch_size=args.vae_batch_size, + return_latents=return_latents, + clip_prompts=clip_prompts, + clip_guide_images=guide_images, + emb_normalize_mode=args.emb_normalize_mode, + ) + if highres_1st and not args.highres_fix_save_1st: # return images or latents + return images + + # save image + highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) + ): + if highres_fix: + seed -= 1 # record original seed + metadata = PngInfo() + metadata.add_text("prompt", prompt) + metadata.add_text("seed", str(seed)) + metadata.add_text("sampler", args.sampler) + metadata.add_text("steps", str(steps)) + metadata.add_text("scale", str(scale)) + if negative_prompt is not None: + metadata.add_text("negative-prompt", negative_prompt) + if negative_scale is not None: + metadata.add_text("negative-scale", str(negative_scale)) + if clip_prompt is not None: + metadata.add_text("clip-prompt", clip_prompt) + if raw_prompt is not None: + metadata.add_text("raw-prompt", raw_prompt) + if is_sdxl: + metadata.add_text("original-height", str(original_height)) + metadata.add_text("original-width", str(original_width)) + metadata.add_text("original-height-negative", str(original_height_negative)) + metadata.add_text("original-width-negative", str(original_width_negative)) + metadata.add_text("crop-top", str(crop_top)) + metadata.add_text("crop-left", str(crop_left)) + + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + else: + fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + + if not args.no_preview and not highres_1st and args.interactive: + try: + import cv2 + + for prompt, image in zip(prompts, images): + cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ + cv2.waitKey() + cv2.destroyAllWindows() + except ImportError: + logger.warning( + "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" + ) + + return images + + # 画像生成のプロンプトが一周するまでのループ + prompt_index = 0 + global_step = 0 + batch_data = [] + while True: + if args.interactive: + # interactive + valid = False + while not valid: + logger.info("\nType prompt:") + try: + raw_prompt = input() + except EOFError: + break + + valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0 + if not valid: # EOF, end app + break + else: + raw_prompt = prompter(args, pipe, seed_random, iter_seed, prompt_index, global_step) + if raw_prompt is None: + break + + # sd-dynamic-prompts like variants: + # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) + raw_prompts = 
handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) + + # repeat prompt + for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): + raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] + + if pi == 0 or len(raw_prompts) > 1: + # parse prompt: if prompt is not changed, skip parsing + width = args.W + height = args.H + original_width = args.original_width + original_height = args.original_height + original_width_negative = args.original_width_negative + original_height_negative = args.original_height_negative + crop_top = args.crop_top + crop_left = args.crop_left + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seed = None + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None + + # Deep Shrink + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio + + # Gradual Latent + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params + + prompt_args = raw_prompt.strip().split(" --") + prompt = prompt_args[0] + length = len(prompter) if hasattr(prompter, "__len__") else 0 + logger.info(f"prompt {prompt_index+1}/{length}: {prompt}") + + for parg in prompt_args[1:]: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + logger.info(f"width: {width}") + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + logger.info(f"height: {height}") + continue + + m = re.match(r"ow (\d+)", parg, re.IGNORECASE) + if m: + original_width = int(m.group(1)) + logger.info(f"original width: {original_width}") + continue + + m = re.match(r"oh (\d+)", parg, re.IGNORECASE) + if m: + original_height = int(m.group(1)) + logger.info(f"original height: {original_height}") + continue + + m = re.match(r"nw (\d+)", parg, re.IGNORECASE) + if m: + original_width_negative = int(m.group(1)) + logger.info(f"original width negative: {original_width_negative}") + continue + + m = re.match(r"nh (\d+)", parg, re.IGNORECASE) + if m: + original_height_negative = int(m.group(1)) + logger.info(f"original height negative: {original_height_negative}") + continue + + m = re.match(r"ct (\d+)", parg, re.IGNORECASE) + if m: + crop_top = int(m.group(1)) + logger.info(f"crop top: {crop_top}") + continue + + m = re.match(r"cl (\d+)", parg, re.IGNORECASE) + if m: + crop_left = int(m.group(1)) + logger.info(f"crop left: {crop_left}") + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + logger.info(f"steps: {steps}") + continue + + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + logger.info(f"seeds: {seeds}") + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + logger.info(f"scale: {scale}") + continue + + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + logger.info(f"negative scale: {negative_scale}") + continue + + m 
= re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + logger.info(f"strength: {strength}") + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + logger.info(f"negative prompt: {negative_prompt}") + continue + + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + logger.info(f"clip prompt: {clip_prompt}") + continue + + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + logger.info(f"network mul: {network_muls}") + continue + + # Deep Shrink + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + logger.info(f"deep shrink depth 1: {ds_depth_1}") + continue + + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue + + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink depth 2: {ds_depth_2}") + continue + + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue + + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink ratio: {ds_ratio}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None 
else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + except ValueError as ex: + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(f"{ex}") + + # override Deep Shrink + if ds_depth_1 is not None: + if ds_depth_1 < 0: + ds_depth_1 = args.ds_depth_1 or 3 + unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + + # override Gradual Latent + if gl_timesteps is not None: + if gl_timesteps < 0: + gl_timesteps = args.gradual_latent_timesteps or 650 + if gl_unsharp_params is not None: + unsharp_params = gl_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + gradual_latent = GradualLatent( + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + + # prepare seed + if seeds is not None: # given in prompt + # num_images_per_promptが多い場合は足りなくなるので、足りない分は前のを使う + if len(seeds) > 0: + seed = seeds.pop(0) + else: + if args.iter_same_seed: + seed = iter_seed + else: + seed = None # 前のを消す + + if seed is None: + seed = seed_random.randint(0, 2**32 - 1) + if args.interactive: + logger.info(f"seed: {seed}") + + # prepare init image, guide image and mask + init_image = mask_image = guide_image = None + + # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する + if init_images is not None: + init_image = init_images[global_step % len(init_images)] + + # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する + # 32単位に丸めたやつにresizeされるので踏襲する + if not highres_fix: + width, height = init_image.size + width = width - width % 32 + height = height - height % 32 + if width != init_image.size[0] or height != init_image.size[1]: + logger.warning( + f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" + ) + + if mask_images is not None: + mask_image = mask_images[global_step % len(mask_images)] + + if guide_images is not None: + if control_nets or control_net_lllites: # 複数件の場合あり + c = max(len(control_nets), len(control_net_lllites)) + p = global_step % (len(guide_images) // c) + guide_image = guide_images[p * c : p * c + c] + else: + guide_image = guide_images[global_step % len(guide_images)] + + if regional_network: + num_sub_prompts = len(prompt.split(" AND ")) + assert ( + len(networks) <= num_sub_prompts + ), "Number of networks must be less than or equal to number of sub prompts." + else: + num_sub_prompts = None + + b1 = BatchData( + False, + BatchDataBase( + global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + ), + BatchDataExt( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + tuple(network_muls) if network_muls else None, + num_sub_prompts, + ), + ) + if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? 
+ process_batch(batch_data, highres_fix) + batch_data.clear() + + batch_data.append(b1) + if len(batch_data) == args.batch_size: + prev_image = process_batch(batch_data, highres_fix)[0] + batch_data.clear() + + global_step += 1 + + prompt_index += 1 + + if len(batch_data) > 0: + process_batch(batch_data, highres_fix) + batch_data.clear() + + logger.info("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + + parser.add_argument( + "--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む" + ) + parser.add_argument( + "--v1", action="store_true", help="load Stable Diffusion v1.x model / Stable Diffusion 1.xのモデルを読み込む" + ) + parser.add_argument( + "--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む" + ) + parser.add_argument( + "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" + ) + + parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") + parser.add_argument( + "--from_file", + type=str, + default=None, + help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", + ) + parser.add_argument( + "--from_module", + type=str, + default=None, + help="if specified, load prompts from this module / 指定時はプロンプトをモジュールから読み込む", + ) + parser.add_argument( + "--prompter_module_args", type=str, default=None, help="args for prompter module / prompterモジュールの引数" + ) + parser.add_argument( + "--interactive", + action="store_true", + help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)", + ) + parser.add_argument( + "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" + ) + parser.add_argument( + "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像" + ) + parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") + parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") + parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") + parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") + parser.add_argument( + "--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする" + ) + parser.add_argument( + "--use_original_file_name", + action="store_true", + help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける", + ) + # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) + parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") + parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") + parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") + parser.add_argument( + "--original_height", + type=int, + default=None, + help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値", + ) + parser.add_argument( + "--original_width", + type=int, + default=None, + help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値", + ) + parser.add_argument( + "--original_height_negative", + type=int, + default=None, + help="original height for SDXL unconditioning / 
SDXLのネガティブ条件付けに用いるoriginal heightの値", + ) + parser.add_argument( + "--original_width_negative", + type=int, + default=None, + help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値", + ) + parser.add_argument( + "--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値" + ) + parser.add_argument( + "--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値" + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") + parser.add_argument( + "--vae_batch_size", + type=float, + default=None, + help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率", + ) + parser.add_argument( + "--vae_slices", + type=int, + default=None, + help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", + ) + parser.add_argument( + "--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない" + ) + parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") + parser.add_argument( + "--sampler", + type=str, + default="ddim", + choices=[ + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", + ], + help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", + ) + parser.add_argument( + "--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", + ) + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", + ) + # parser.add_argument("--replace_clip_l14_336", action='store_true', + # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") + parser.add_argument( + "--seed", + type=int, + default=None, + help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed", + ) + parser.add_argument( + "--iter_same_seed", + action="store_true", + help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", + ) + parser.add_argument( + "--shuffle_prompts", + action="store_true", + help="shuffle prompts in iteration / 繰り返し内のプロンプトをシャッフルする", + ) + parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") + parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") + parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") + parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa") + parser.add_argument( + "--diffusers_xformers", + action="store_true", + help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", + ) + parser.add_argument( + "--opt_channels_last", + 
action="store_true", + help="set channels last option to model / モデルにchannels lastを指定し最適化する", + ) + parser.add_argument( + "--network_module", + type=str, + default=None, + nargs="*", + help="additional network module to use / 追加ネットワークを使う時そのモジュール名", + ) + parser.add_argument( + "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" + ) + parser.add_argument( + "--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率" + ) + parser.add_argument( + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", + ) + parser.add_argument( + "--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する" + ) + parser.add_argument( + "--network_merge_n_models", + type=int, + default=None, + help="merge this number of networks / この数だけネットワークをマージする", + ) + parser.add_argument( + "--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする" + ) + parser.add_argument( + "--network_pre_calc", + action="store_true", + help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する", + ) + parser.add_argument( + "--network_regional_mask_max_color_codes", + type=int, + default=None, + help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)", + ) + parser.add_argument( + "--textual_inversion_embeddings", + type=str, + default=None, + nargs="*", + help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", + ) + parser.add_argument( + "--clip_skip", + type=int, + default=None, + help="layer number from bottom to use in CLIP, default is 1 for SD1/2, 2 for SDXL " + + "/ CLIPの後ろからn層目の出力を使う(デフォルトはSD1/2の場合1、SDXLの場合2)", + ) + parser.add_argument( + "--max_embeddings_multiples", + type=int, + default=None, + help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", + ) + parser.add_argument( + "--emb_normalize_mode", + type=str, + default="original", + choices=["original", "none", "abs"], + help="embedding normalization mode / embeddingの正規化モード", + ) + parser.add_argument( + "--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像" + ) + parser.add_argument( + "--highres_fix_scale", + type=float, + default=None, + help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", + ) + parser.add_argument( + "--highres_fix_steps", + type=int, + default=28, + help="1st stage steps for highres fix / highres fixの最初のステージのステップ数", + ) + parser.add_argument( + "--highres_fix_strength", + type=float, + default=None, + help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", + ) + parser.add_argument( + "--highres_fix_save_1st", + action="store_true", + help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する", + ) + parser.add_argument( + "--highres_fix_latents_upscaling", + action="store_true", + help="use latents upscaling for highres fix / highres fixでlatentで拡大する", + ) + parser.add_argument( + "--highres_fix_upscaler", + type=str, + default=None, + help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名", + ) + parser.add_argument( + "--highres_fix_upscaler_args", + type=str, + default=None, + help="additional arguments for upscaler (key=value) / upscalerへの追加の引数", + ) + 
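As a concrete reading of these highres-fix flags: `process_batch` above renders the 1st stage at `W * scale` x `H * scale`, rounded down to multiples of 32, before upscaling and refining at the requested size. A worked example of that sizing rule (the input values are illustrative):

```python
# Worked example of the 1st-stage sizing in process_batch (see above): the
# requested size is scaled by --highres_fix_scale, then floored to a multiple
# of 32 so the latents stay valid for the U-Net.
def first_stage_size(w: int, h: int, scale: float):
    w1 = int(w * scale + 0.5)
    h1 = int(h * scale + 0.5)
    return w1 - w1 % 32, h1 - h1 % 32

assert first_stage_size(1024, 1024, 0.5) == (512, 512)
assert first_stage_size(1000, 1000, 0.5) == (480, 480)  # 500 floored to a multiple of 32
```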
parser.add_argument( + "--highres_fix_disable_control_net", + action="store_true", + help="disable ControlNet for highres fix / highres fixでControlNetを使わない", + ) + + parser.add_argument( + "--negative_scale", + type=float, + default=None, + help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する", + ) + + parser.add_argument( + "--control_net_lllite_models", + type=str, + default=None, + nargs="*", + help="ControlNet models to use / 使用するControlNetのモデル名", + ) + parser.add_argument( + "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + ) + parser.add_argument( + "--control_net_preps", + type=str, + default=None, + nargs="*", + help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名", + ) + parser.add_argument( + "--control_net_multipliers", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率" + ) + parser.add_argument( + "--control_net_ratios", + type=float, + default=None, + nargs="*", + help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", + ) + parser.add_argument( + "--clip_vision_strength", + type=float, + default=None, + help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する", + ) + + # Deep Shrink + parser.add_argument( + "--ds_depth_1", + type=int, + default=None, + help="Enable Deep Shrink with this depth 1, valid values are 0 to 8 / Deep Shrinkをこのdepthで有効にする", + ) + parser.add_argument( + "--ds_timesteps_1", + type=int, + default=650, + help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps", + ) + parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2") + parser.add_argument( + "--ds_timesteps_2", + type=int, + default=650, + help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps", + ) + parser.add_argument( + "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" + ) + + # gradual latent + parser.add_argument( + "--gradual_latent_timesteps", + type=int, + default=None, + help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する", + ) + parser.add_argument( + "--gradual_latent_ratio", + type=float, + default=0.5, + help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する", + ) + parser.add_argument( + "--gradual_latent_ratio_step", + type=float, + default=0.125, + help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか", + ) + parser.add_argument( + "--gradual_latent_every_n_steps", + type=int, + default=3, + help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", + ) + parser.add_argument( + "--gradual_latent_s_noise", + type=float, + default=1.0, + help="s_noise for Gradual Latent / Gradual Latentのs_noise", + ) + parser.add_argument( + "--gradual_latent_unsharp_params", + type=str, + default=None, + help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" + + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. 
`3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", + ) + + # # parser.add_argument( + # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" + # ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + main(args) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index a596a049..2c40f1a0 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -64,17 +64,11 @@ import re import diffusers import numpy as np + import torch +from library.device_utils import init_ipex, clean_memory, get_preferred_device +init_ipex() -try: - import intel_extension_for_pytorch as ipex - - if torch.xpu.is_available(): - from library.ipex import ipex_init - - ipex_init() -except Exception: - pass import torchvision from diffusers import ( AutoencoderKL, @@ -105,10 +99,17 @@ import library.train_util as train_util from networks.lora import LoRANetwork import tools.original_control_net as original_control_net from tools.original_control_net import ControlNetInfo -from library.original_unet import UNet2DConditionModel +from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel from library.original_unet import FlashAttentionFunction +from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # scheduler: SCHEDULER_LINEAR_START = 0.00085 @@ -144,12 +145,12 @@ USE_CUTOUTS = False def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: - print("Enable memory efficient attention for U-Net") + logger.info("Enable memory efficient attention for U-Net") # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い unet.set_use_memory_efficient_attention(False, True) elif xformers: - print("Enable xformers for U-Net") + logger.info("Enable xformers for U-Net") try: import xformers.ops except ImportError: @@ -157,7 +158,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio unet.set_use_memory_efficient_attention(True, False) elif sdpa: - print("Enable SDPA for U-Net") + logger.info("Enable SDPA for U-Net") unet.set_use_memory_efficient_attention(False, False) unet.set_use_sdpa(True) @@ -173,7 +174,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform def replace_vae_attn_to_memory_efficient(): - print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction def forward_flash_attn(self, hidden_states, **kwargs): @@ -229,7 +230,7 @@ def replace_vae_attn_to_memory_efficient(): def replace_vae_attn_to_xformers(): - print("VAE: Attention.forward has been replaced to xformers") + logger.info("VAE: Attention.forward has been replaced to xformers") import xformers.ops def forward_xformers(self, hidden_states, **kwargs): @@ -285,7 +286,7 @@ def replace_vae_attn_to_xformers(): def replace_vae_attn_to_sdpa(): - print("VAE: Attention.forward has been replaced to sdpa") + logger.info("VAE: Attention.forward has been replaced to sdpa") def forward_sdpa(self, hidden_states, **kwargs): residual = hidden_states @@ -378,7 +379,7 @@ class PipelineLike: vae: AutoencoderKL, text_encoder: CLIPTextModel, 
tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, + unet: InferUNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], clip_skip: int, clip_model: CLIPModel, @@ -454,6 +455,8 @@ class PipelineLike: self.control_nets: List[ControlNetInfo] = [] self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない + self.gradual_latent: GradualLatent = None + # Textual Inversion def add_token_replacement(self, target_token_id, rep_token_ids): self.token_replacements[target_token_id] = rep_token_ids @@ -484,6 +487,14 @@ class PipelineLike: def set_control_nets(self, ctrl_nets): self.control_nets = ctrl_nets + def set_gradual_latent(self, gradual_latent): + if gradual_latent is None: + logger.info("gradual_latent is disabled") + self.gradual_latent = None + else: + logger.info(f"gradual_latent is enabled: {gradual_latent}") + self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) + # region xformersとか使う部分:独自に書き換えるので関係なし def enable_xformers_memory_efficient_attention(self): @@ -689,7 +700,7 @@ class PipelineLike: do_classifier_free_guidance = guidance_scale > 1.0 if not do_classifier_free_guidance and negative_scale is not None: - print(f"negative_scale is ignored if guidance scalle <= 1.0") + logger.warning(f"negative_scale is ignored if guidance scalle <= 1.0") negative_scale = None # get unconditional embeddings for classifier free guidance @@ -771,11 +782,11 @@ class PipelineLike: clip_text_input = prompt_tokens if clip_text_input.shape[1] > self.tokenizer.model_max_length: # TODO 75文字を超えたら警告を出す? - print("trim text input", clip_text_input.shape) + logger.info(f"trim text input {clip_text_input.shape}") clip_text_input = torch.cat( [clip_text_input[:, : self.tokenizer.model_max_length - 1], clip_text_input[:, -1].unsqueeze(1)], dim=1 ) - print("trimmed", clip_text_input.shape) + logger.info(f"trimmed {clip_text_input.shape}") for i, clip_prompt in enumerate(clip_prompts): if clip_prompt is not None: # clip_promptがあれば上書きする @@ -893,8 +904,7 @@ class PipelineLike: init_latent_dist = self.vae.encode(init_image).latent_dist init_latents = init_latent_dist.sample(generator=generator) else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() init_latents = [] for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): init_latent_dist = self.vae.encode( @@ -958,7 +968,49 @@ class PipelineLike: else: text_emb_last = text_embeddings + enable_gradual_latent = False + if self.gradual_latent: + if not hasattr(self.scheduler, "set_gradual_latent_params"): + logger.info("gradual_latent is not supported for this scheduler. 
Ignoring.") + logger.info(f'{self.scheduler.__class__.__name__}') + else: + enable_gradual_latent = True + step_elapsed = 1000 + current_ratio = self.gradual_latent.ratio + + # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする + height, width = latents.shape[-2:] + org_dtype = latents.dtype + if org_dtype == torch.bfloat16: + latents = latents.float() + latents = torch.nn.functional.interpolate( + latents, scale_factor=current_ratio, mode="bicubic", align_corners=False + ).to(org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if self.gradual_latent.gaussian_blur_ksize: + latents = self.gradual_latent.apply_unshark_mask(latents) + for i, t in enumerate(tqdm(timesteps)): + resized_size = None + if enable_gradual_latent: + # gradually upscale the latents / latentsを徐々にアップスケールする + if ( + t < self.gradual_latent.start_timesteps + and current_ratio < 1.0 + and step_elapsed >= self.gradual_latent.every_n_steps + ): + current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0) + # make divisible by 8 because size of latents must be divisible at bottom of UNet + h = int(height * current_ratio) // 8 * 8 + w = int(width * current_ratio) // 8 * 8 + resized_size = (h, w) + self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent) + step_elapsed = 0 + else: + self.scheduler.set_gradual_latent_params(None, None) + step_elapsed += 1 + # expand the latents if we are doing classifier free guidance latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1052,8 +1104,7 @@ class PipelineLike: if vae_batch_size >= batch_size: image = self.vae.decode(latents).sample else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() images = [] for i in tqdm(range(0, batch_size, vae_batch_size)): images.append( @@ -1540,7 +1591,9 @@ class PipelineLike: image_embeddings = self.vgg16_feat_model(image)["feat"] # バッチサイズが複数だと正しく動くかわからない - loss = ((image_embeddings - guide_embeddings) ** 2).mean() * guidance_scale # MSE style transferでコンテンツの損失はMSEなので + loss = ( + (image_embeddings - guide_embeddings) ** 2 + ).mean() * guidance_scale # MSE style transferでコンテンツの損失はMSEなので grads = -torch.autograd.grad(loss, latents)[0] if isinstance(self.scheduler, LMSDiscreteScheduler): @@ -1704,7 +1757,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: if word.strip() == "BREAK": # pad until next multiple of tokenizer's max token length pad_len = pipe.tokenizer.model_max_length - (len(text_token) % pipe.tokenizer.model_max_length) - print(f"BREAK pad_len: {pad_len}") + logger.info(f"BREAK pad_len: {pad_len}") for i in range(pad_len): # v2のときEOSをつけるべきかどうかわからないぜ # if i == 0: @@ -1734,7 +1787,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: tokens.append(text_token) weights.append(text_weight) if truncated: - print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + logger.warning("Prompt was truncated. 
Try to shorten the prompt or increase max_embeddings_multiples") return tokens, weights @@ -2046,7 +2099,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): elif len(count_range) == 2: count_range = [int(count_range[0]), int(count_range[1])] else: - print(f"invalid count range: {count_range}") + logger.warning(f"invalid count range: {count_range}") count_range = [1, 1] if count_range[0] > count_range[1]: count_range = [count_range[1], count_range[0]] @@ -2116,7 +2169,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): # def load_clip_l14_336(dtype): -# print(f"loading CLIP: {CLIP_ID_L14_336}") +# logger.info(f"loading CLIP: {CLIP_ID_L14_336}") # text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) # return text_encoder @@ -2131,6 +2184,7 @@ class BatchDataBase(NamedTuple): mask_image: Any clip_prompt: str guide_image: Any + raw_prompt: str class BatchDataExt(NamedTuple): @@ -2163,9 +2217,9 @@ def main(args): # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") # モデルを読み込む if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う @@ -2175,10 +2229,10 @@ def main(args): use_stable_diffusion_format = os.path.isfile(args.ckpt) if use_stable_diffusion_format: - print("load StableDiffusion checkpoint") + logger.info("load StableDiffusion checkpoint") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) else: - print("load Diffusers pretrained models") + logger.info("load Diffusers pretrained models") loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) text_encoder = loading_pipe.text_encoder vae = loading_pipe.vae @@ -2196,25 +2250,26 @@ def main(args): ) original_unet.load_state_dict(unet.state_dict()) unet = original_unet + unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet) # VAEを読み込む if args.vae is not None: vae = model_util.load_vae(args.vae, dtype) - print("additional VAE loaded") + logger.info("additional VAE loaded") # # 置換するCLIPを読み込む # if args.replace_clip_l14_336: # text_encoder = load_clip_l14_336(dtype) - # print(f"large clip {CLIP_ID_L14_336} is loaded") + # logger.info(f"large clip {CLIP_ID_L14_336} is loaded") if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale: - print("prepare clip model") + logger.info("prepare clip model") clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype) else: clip_model = None if args.vgg16_guidance_scale > 0.0: - print("prepare resnet model") + logger.info("prepare resnet model") vgg16_model = torchvision.models.vgg16(torchvision.models.VGG16_Weights.IMAGENET1K_V1) else: vgg16_model = None @@ -2226,7 +2281,7 @@ def main(args): replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) # tokenizerを読み込む - print("loading tokenizer") + logger.info("loading tokenizer") if use_stable_diffusion_format: tokenizer = train_util.load_tokenizer(args) @@ -2249,7 +2304,7 @@ def main(args): scheduler_cls = EulerDiscreteScheduler scheduler_module = 
diffusers.schedulers.scheduling_euler_discrete
    elif args.sampler == "euler_a" or args.sampler == "k_euler_a":
-        scheduler_cls = EulerAncestralDiscreteScheduler
+        scheduler_cls = EulerAncestralDiscreteSchedulerGL
        scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete
    elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++":
        scheduler_cls = DPMSolverMultistepScheduler
@@ -2285,7 +2340,7 @@ def main(args):
            self.sampler_noises = noises

        def randn(self, shape, device=None, dtype=None, layout=None, generator=None):
-            # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index)
+            # logger.info(f"replacing {shape} {len(self.sampler_noises)} {self.sampler_noise_index}")
            if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises):
                noise = self.sampler_noises[self.sampler_noise_index]
                if shape != noise.shape:
@@ -2294,7 +2349,7 @@ def main(args):
                noise = None

            if noise == None:
-                print(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
+                logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
                noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)

            self.sampler_noise_index += 1
@@ -2325,11 +2380,11 @@ def main(args):

    # clip_sample=Trueにする
    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
-        print("set clip_sample to True")
+        logger.info("set clip_sample to True")
        scheduler.config.clip_sample = True

    # deviceを決定する
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # "mps"を考量してない
+    device = get_preferred_device()

    # custom pipelineをコピったやつを生成する
    if args.vae_slices:
@@ -2352,13 +2407,20 @@ def main(args):
        vae = sli_vae
        del sli_vae
    vae.to(dtype).to(device)
+    vae.eval()

    text_encoder.to(dtype).to(device)
    unet.to(dtype).to(device)
+
+    text_encoder.eval()
+    unet.eval()
+
    if clip_model is not None:
        clip_model.to(dtype).to(device)
+        clip_model.eval()
    if vgg16_model is not None:
        vgg16_model.to(dtype).to(device)
+        vgg16_model.eval()

    # networkを組み込む
    if args.network_module:
@@ -2375,7 +2437,7 @@ def main(args):
            network_merge = 0

        for i, network_module in enumerate(args.network_module):
-            print("import network module:", network_module)
+            logger.info(f"import network module: {network_module}")
            imported_module = importlib.import_module(network_module)

            network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
@@ -2393,7 +2455,7 @@ def main(args):
                raise ValueError("No weight. Weight is required.")

            network_weight = args.network_weights[i]
-            print("load network weights from:", network_weight)
+            logger.info(f"load network weights from: {network_weight}")

            if model_util.is_safetensors(network_weight) and args.network_show_meta:
                from safetensors.torch import safe_open
@@ -2401,7 +2463,7 @@ def main(args):
                with safe_open(network_weight, framework="pt") as f:
                    metadata = f.metadata()
                if metadata is not None:
-                    print(f"metadata for: {network_weight}: {metadata}")
+                    logger.info(f"metadata for: {network_weight}: {metadata}")

            network, weights_sd = imported_module.create_network_from_weights(
                network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
            )
@@ -2411,20 +2473,20 @@ def main(args):
            mergeable = network.is_mergeable()
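+            # only the first `network_merge` networks are merged into the model weights;
+            # non-mergeable networks, and any beyond that count, are applied at inference time
            if network_merge and not mergeable:
-                print("network is not mergiable. ignore merge option.")
+                logger.warning("network is not mergeable. 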
ignore merge option.")

            if not mergeable or i >= network_merge:  # not merging
                network.apply_to(text_encoder, unet)
                info = network.load_state_dict(weights_sd, False)  # network.load_weightsを使うようにするとよい
-                print(f"weights are loaded: {info}")
+                logger.info(f"weights are loaded: {info}")

                if args.opt_channels_last:
                    network.to(memory_format=torch.channels_last)
                network.to(dtype).to(device)

                if network_pre_calc:
-                    print("backup original weights")
+                    logger.info("backup original weights")
                    network.backup_weights()

                networks.append(network)
@@ -2438,7 +2500,7 @@ def main(args):
    # upscalerの指定があれば取得する
    upscaler = None
    if args.highres_fix_upscaler:
-        print("import upscaler module:", args.highres_fix_upscaler)
+        logger.info(f"import upscaler module: {args.highres_fix_upscaler}")
        imported_module = importlib.import_module(args.highres_fix_upscaler)

        us_kwargs = {}
@@ -2447,7 +2509,7 @@ def main(args):
            for net_arg in args.highres_fix_upscaler_args.split(";"):
                key, value = net_arg.split("=")
                us_kwargs[key] = value

-        print("create upscaler")
+        logger.info("create upscaler")
        upscaler = imported_module.create_upscaler(**us_kwargs)
        upscaler.to(dtype).to(device)
@@ -2464,7 +2526,7 @@ def main(args):
            control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))

    if args.opt_channels_last:
-        print(f"set optimizing: channels last")
+        logger.info("set optimizing: channels last")
        text_encoder.to(memory_format=torch.channels_last)
        vae.to(memory_format=torch.channels_last)
        unet.to(memory_format=torch.channels_last)
@@ -2496,11 +2558,38 @@ def main(args):
        args.vgg16_guidance_layer,
    )
    pipe.set_control_nets(control_nets)
-    print("pipeline is ready.")
+    logger.info("pipeline is ready.")

    if args.diffusers_xformers:
        pipe.enable_xformers_memory_efficient_attention()

+    # Deep Shrink
+    if args.ds_depth_1 is not None:
+        unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio)
+
+    # Gradual Latent
+    if args.gradual_latent_timesteps is not None:
+        if args.gradual_latent_unsharp_params:
+            us_params = args.gradual_latent_unsharp_params.split(",")
+            us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]]
+            us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3]))
+            us_ksize = int(us_ksize)
+        else:
+            us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None
+
+        gradual_latent = GradualLatent(
+            args.gradual_latent_ratio,
+            args.gradual_latent_timesteps,
+            args.gradual_latent_every_n_steps,
+            args.gradual_latent_ratio_step,
+            args.gradual_latent_s_noise,
+            us_ksize,
+            us_sigma,
+            us_strength,
+            us_target_x,
+        )
+        pipe.set_gradual_latent(gradual_latent)
+
    # Extended Textual Inversion および Textual Inversionを処理する
    if args.XTI_embeddings:
        diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
@@ -2522,7 +2611,9 @@ def main(args):
            embeds = next(iter(data.values()))

        if type(embeds) != torch.Tensor:
-            raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}")
+            raise ValueError(
+                f"weight file does not contain Tensor / 重みファイルのデータがTensorではありません: {embeds_file}"
+            )

        num_vectors_per_token = embeds.size()[0]
        token_string = os.path.splitext(os.path.basename(embeds_file))[0]
@@ -2535,7 +2626,7 @@ def main(args):
        ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"

        token_ids = tokenizer.convert_tokens_to_ids(token_strings)
-        print(f"Textual Inversion embeddings `{token_string}` loaded. 
Tokens are added: {token_ids}") + logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}") assert ( min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1 ), f"token ids is not ordered" @@ -2594,7 +2685,7 @@ def main(args): ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}") + logger.info(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}") # if num_vectors_per_token > 1: pipe.add_token_replacement(token_ids[0], token_ids) @@ -2619,10 +2710,10 @@ def main(args): # promptを取得する if args.from_file is not None: - print(f"reading prompts from {args.from_file}") + logger.info(f"reading prompts from {args.from_file}") with open(args.from_file, "r", encoding="utf-8") as f: prompt_list = f.read().splitlines() - prompt_list = [d for d in prompt_list if len(d.strip()) > 0] + prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] elif args.prompt is not None: prompt_list = [args.prompt] else: @@ -2648,7 +2739,7 @@ def main(args): for p in paths: image = Image.open(p) if image.mode != "RGB": - print(f"convert image to RGB from {image.mode}: {p}") + logger.info(f"convert image to RGB from {image.mode}: {p}") image = image.convert("RGB") images.append(image) @@ -2664,24 +2755,24 @@ def main(args): return resized if args.image_path is not None: - print(f"load image for img2img: {args.image_path}") + logger.info(f"load image for img2img: {args.image_path}") init_images = load_images(args.image_path) assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - print(f"loaded {len(init_images)} images for img2img") + logger.info(f"loaded {len(init_images)} images for img2img") else: init_images = None if args.mask_path is not None: - print(f"load mask for inpainting: {args.mask_path}") + logger.info(f"load mask for inpainting: {args.mask_path}") mask_images = load_images(args.mask_path) assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - print(f"loaded {len(mask_images)} mask images for inpainting") + logger.info(f"loaded {len(mask_images)} mask images for inpainting") else: mask_images = None # promptがないとき、画像のPngInfoから取得する if init_images is not None and len(prompt_list) == 0 and not args.interactive: - print("get prompts from images' meta data") + logger.info("get prompts from images' meta data") for img in init_images: if "prompt" in img.text: prompt = img.text["prompt"] @@ -2710,17 +2801,17 @@ def main(args): h = int(h * args.highres_fix_scale + 0.5) if init_images is not None: - print(f"resize img2img source images to {w}*{h}") + logger.info(f"resize img2img source images to {w}*{h}") init_images = resize_images(init_images, (w, h)) if mask_images is not None: - print(f"resize img2img mask images to {w}*{h}") + logger.info(f"resize img2img mask images to {w}*{h}") mask_images = resize_images(mask_images, (w, h)) regional_network = False if networks and mask_images: # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 regional_network = True - print("use mask as region") + logger.info("use mask as region") size = None for i, network in enumerate(networks): @@ -2745,14 +2836,16 @@ def main(args): prev_image = None # for VGG16 guided if args.guide_image_path is not None: - print(f"load image for CLIP/VGG16/ControlNet guidance: 
{args.guide_image_path}") + logger.info(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") guide_images = [] for p in args.guide_image_path: guide_images.extend(load_images(p)) - print(f"loaded {len(guide_images)} guide images for guidance") + logger.info(f"loaded {len(guide_images)} guide images for guidance") if len(guide_images) == 0: - print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") + logger.info( + f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" + ) guide_images = None else: guide_images = None @@ -2778,7 +2871,7 @@ def main(args): max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples for gen_iter in range(args.n_iter): - print(f"iteration {gen_iter+1}/{args.n_iter}") + logger.info(f"iteration {gen_iter+1}/{args.n_iter}") iter_seed = random.randint(0, 0x7FFFFFFF) # shuffle prompt list @@ -2794,7 +2887,7 @@ def main(args): # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling - print("process 1st stage") + logger.info("process 1st stage") batch_1st = [] for _, base, ext in batch: width_1st = int(ext.width * args.highres_fix_scale + 0.5) @@ -2820,7 +2913,7 @@ def main(args): images_1st = process_batch(batch_1st, True, True) # 2nd stageのバッチを作成して以下処理する - print("process 2nd stage") + logger.info("process 2nd stage") width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height if upscaler: @@ -2866,13 +2959,14 @@ def main(args): # このバッチの情報を取り出す ( return_latents, - (step_first, _, _, _, init_image, mask_image, _, guide_image), + (step_first, _, _, _, init_image, mask_image, _, guide_image, _), (width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts), ) = batch[0] noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) prompts = [] negative_prompts = [] + raw_prompts = [] start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) noises = [ torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) @@ -2903,11 +2997,16 @@ def main(args): all_images_are_same = True all_masks_are_same = True all_guide_images_are_same = True - for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): + for i, ( + _, + (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + _, + ) in enumerate(batch): prompts.append(prompt) negative_prompts.append(negative_prompt) seeds.append(seed) clip_prompts.append(clip_prompt) + raw_prompts.append(raw_prompt) if init_image is not None: init_images.append(init_image) @@ -2971,7 +3070,7 @@ def main(args): n.restore_weights() for n in networks: n.pre_calculation() - print("pre-calculation... done") + logger.info("pre-calculation... 
done") images = pipe( prompts, @@ -2999,8 +3098,8 @@ def main(args): # save image highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts) + for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) ): if highres_fix: seed -= 1 # record original seed @@ -3016,6 +3115,8 @@ def main(args): metadata.add_text("negative-scale", str(negative_scale)) if clip_prompt is not None: metadata.add_text("clip-prompt", clip_prompt) + if raw_prompt is not None: + metadata.add_text("raw-prompt", raw_prompt) if args.use_original_file_name and init_images is not None: if type(init_images) is list: @@ -3038,7 +3139,9 @@ def main(args): cv2.waitKey() cv2.destroyAllWindows() except ImportError: - print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + logger.info( + "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" + ) return images @@ -3051,7 +3154,8 @@ def main(args): # interactive valid = False while not valid: - print("\nType prompt:") + logger.info("") + logger.info("Type prompt:") try: raw_prompt = input() except EOFError: @@ -3085,40 +3189,55 @@ def main(args): clip_prompt = None network_muls = None + # Deep Shrink + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio + + # Gradual Latent + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params + prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") for parg in prompt_args[1:]: try: m = re.match(r"w (\d+)", parg, re.IGNORECASE) if m: width = int(m.group(1)) - print(f"width: {width}") + logger.info(f"width: {width}") continue m = re.match(r"h (\d+)", parg, re.IGNORECASE) if m: height = int(m.group(1)) - print(f"height: {height}") + logger.info(f"height: {height}") continue m = re.match(r"s (\d+)", parg, re.IGNORECASE) if m: # steps steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") + logger.info(f"steps: {steps}") continue m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) if m: # seed seeds = [int(d) for d in m.group(1).split(",")] - print(f"seeds: {seeds}") + logger.info(f"seeds: {seeds}") continue m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) if m: # scale scale = float(m.group(1)) - print(f"scale: {scale}") + logger.info(f"scale: {scale}") continue m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) @@ -3127,25 +3246,25 @@ def main(args): negative_scale = None else: negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") + logger.info(f"negative scale: {negative_scale}") continue m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) if m: # strength strength = float(m.group(1)) - print(f"strength: {strength}") + logger.info(f"strength: {strength}") continue m = re.match(r"n (.+)", parg, re.IGNORECASE) if m: # negative prompt 
negative_prompt = m.group(1)
-                        print(f"negative prompt: {negative_prompt}")
+                        logger.info(f"negative prompt: {negative_prompt}")
                        continue

                    m = re.match(r"c (.+)", parg, re.IGNORECASE)
                    if m:  # clip prompt
                        clip_prompt = m.group(1)
-                        print(f"clip prompt: {clip_prompt}")
+                        logger.info(f"clip prompt: {clip_prompt}")
                        continue

                    m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
@@ -3153,12 +3272,120 @@ def main(args):
                        network_muls = [float(v) for v in m.group(1).split(",")]
                        while len(network_muls) < len(networks):
                            network_muls.append(network_muls[-1])
-                        print(f"network mul: {network_muls}")
+                        logger.info(f"network mul: {network_muls}")
+                        continue
+
+                    # Deep Shrink
+                    m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
+                    if m:  # deep shrink depth 1
+                        ds_depth_1 = int(m.group(1))
+                        logger.info(f"deep shrink depth 1: {ds_depth_1}")
+                        continue
+
+                    m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
+                    if m:  # deep shrink timesteps 1
+                        ds_timesteps_1 = int(m.group(1))
+                        ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1  # -1 means override
+                        logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}")
+                        continue
+
+                    m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
+                    if m:  # deep shrink depth 2
+                        ds_depth_2 = int(m.group(1))
+                        ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1  # -1 means override
+                        logger.info(f"deep shrink depth 2: {ds_depth_2}")
+                        continue
+
+                    m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
+                    if m:  # deep shrink timesteps 2
+                        ds_timesteps_2 = int(m.group(1))
+                        ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1  # -1 means override
+                        logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}")
+                        continue
+
+                    m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
+                    if m:  # deep shrink ratio
+                        ds_ratio = float(m.group(1))
+                        ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1  # -1 means override
+                        logger.info(f"deep shrink ratio: {ds_ratio}")
+                        continue
+
+                    # Gradual Latent
+                    m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
+                    if m:  # gradual latent timesteps
+                        gl_timesteps = int(m.group(1))
+                        logger.info(f"gradual latent timesteps: {gl_timesteps}")
+                        continue
+
+                    m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
+                    if m:  # gradual latent ratio
+                        gl_ratio = float(m.group(1))
+                        gl_timesteps = gl_timesteps if gl_timesteps is not None else -1  # -1 means override
+                        logger.info(f"gradual latent ratio: {gl_ratio}")
+                        continue
+
+                    m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
+                    if m:  # gradual latent every n steps
+                        gl_every_n_steps = int(m.group(1))
+                        gl_timesteps = gl_timesteps if gl_timesteps is not None else -1  # -1 means override
+                        logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
+                        continue
+
+                    m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
+                    if m:  # gradual latent ratio step
+                        gl_ratio_step = float(m.group(1))
+                        gl_timesteps = gl_timesteps if gl_timesteps is not None else -1  # -1 means override
+                        logger.info(f"gradual latent ratio step: {gl_ratio_step}")
+                        continue
+
+                    m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
+                    if m:  # gradual latent s noise
+                        gl_s_noise = float(m.group(1))
+                        gl_timesteps = gl_timesteps if gl_timesteps is not None else -1  # -1 means override
+                        logger.info(f"gradual latent s noise: {gl_s_noise}")
+                        continue
+
+                    m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
+                    if m:  # gradual latent unsharp params
+                        gl_unsharp_params = m.group(1)
+                        gl_timesteps = gl_timesteps if gl_timesteps is not None else -1  # -1 means override
+                        logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
                        continue

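+                    # NOTE: the Deep Shrink / Gradual Latent options above use -1 as a
+                    # sentinel in ds_depth_1 / gl_timesteps, meaning "enabled from the
+                    # prompt without an explicit value"; it is resolved to the command-line
+                    # value or a built-in default in the override code below.
                except 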
ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + logger.info(f"Exception in parsing / 解析エラー: {parg}") + logger.info(ex) + + # override Deep Shrink + if ds_depth_1 is not None: + if ds_depth_1 < 0: + ds_depth_1 = args.ds_depth_1 or 3 + unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + + # override Gradual Latent + if gl_timesteps is not None: + if gl_timesteps < 0: + gl_timesteps = args.gradual_latent_timesteps or 650 + if gl_unsharp_params is not None: + unsharp_params = gl_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] + logger.info(f'{unsharp_params}') + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + gradual_latent = GradualLatent( + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) # prepare seed if seeds is not None: # given in prompt @@ -3170,7 +3397,7 @@ def main(args): if len(predefined_seeds) > 0: seed = predefined_seeds.pop(0) else: - print("predefined seeds are exhausted") + logger.info("predefined seeds are exhausted") seed = None elif args.iter_same_seed: seed = iter_seed @@ -3180,7 +3407,7 @@ def main(args): if seed is None: seed = random.randint(0, 0x7FFFFFFF) if args.interactive: - print(f"seed: {seed}") + logger.info(f"seed: {seed}") # prepare init image, guide image and mask init_image = mask_image = guide_image = None @@ -3196,7 +3423,7 @@ def main(args): width = width - width % 32 height = height - height % 32 if width != init_image.size[0] or height != init_image.size[1]: - print( + logger.info( f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" ) @@ -3212,9 +3439,9 @@ def main(args): guide_image = guide_images[global_step % len(guide_images)] elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0: if prev_image is None: - print("Generate 1st image without guide image.") + logger.info("Generate 1st image without guide image.") else: - print("Use previous image as guide image.") + logger.info("Use previous image as guide image.") guide_image = prev_image if regional_network: @@ -3227,7 +3454,9 @@ def main(args): b1 = BatchData( False, - BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), + BatchDataBase( + global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + ), BatchDataExt( width, height, @@ -3256,22 +3485,31 @@ def main(args): process_batch(batch_data, highres_fix) batch_data.clear() - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() - parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む") + add_logging_arguments(parser) + + parser.add_argument( + "--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む" + ) parser.add_argument( "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" ) parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") parser.add_argument( - "--from_file", type=str, default=None, help="if specified, load prompts 
from this file / 指定時はプロンプトをファイルから読み込む" + "--from_file", + type=str, + default=None, + help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", ) parser.add_argument( - "--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)" + "--interactive", + action="store_true", + help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)", ) parser.add_argument( "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" @@ -3283,7 +3521,9 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") - parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする") + parser.add_argument( + "--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする" + ) parser.add_argument( "--use_original_file_name", action="store_true", @@ -3337,9 +3577,14 @@ def setup_parser() -> argparse.ArgumentParser: default=7.5, help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", ) - parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ") parser.add_argument( - "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + "--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", ) parser.add_argument( "--tokenizer_cache_dir", @@ -3375,25 +3620,46 @@ def setup_parser() -> argparse.ArgumentParser: help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", ) parser.add_argument( - "--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する" + "--opt_channels_last", + action="store_true", + help="set channels last option to model / モデルにchannels lastを指定し最適化する", ) parser.add_argument( - "--network_module", type=str, default=None, nargs="*", help="additional network module to use / 追加ネットワークを使う時そのモジュール名" + "--network_module", + type=str, + default=None, + nargs="*", + help="additional network module to use / 追加ネットワークを使う時そのモジュール名", ) parser.add_argument( "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" ) - parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率") parser.add_argument( - "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" + "--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率" ) - parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") parser.add_argument( - "--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする" + "--network_args", + type=str, 
+ default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", ) - parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument( - "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" + "--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する" + ) + parser.add_argument( + "--network_merge_n_models", + type=int, + default=None, + help="merge this number of networks / この数だけネットワークをマージする", + ) + parser.add_argument( + "--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする" + ) + parser.add_argument( + "--network_pre_calc", + action="store_true", + help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する", ) parser.add_argument( "--network_regional_mask_max_color_codes", @@ -3415,7 +3681,9 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings", ) - parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") + parser.add_argument( + "--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う" + ) parser.add_argument( "--max_embeddings_multiples", type=int, @@ -3456,7 +3724,10 @@ def setup_parser() -> argparse.ArgumentParser: help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", ) parser.add_argument( - "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数" + "--highres_fix_steps", + type=int, + default=28, + help="1st stage steps for highres fix / highres fixの最初のステージのステップ数", ) parser.add_argument( "--highres_fix_strength", @@ -3465,7 +3736,9 @@ def setup_parser() -> argparse.ArgumentParser: help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", ) parser.add_argument( - "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する" + "--highres_fix_save_1st", + action="store_true", + help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する", ) parser.add_argument( "--highres_fix_latents_upscaling", @@ -3473,7 +3746,10 @@ def setup_parser() -> argparse.ArgumentParser: help="use latents upscaling for highres fix / highres fixでlatentで拡大する", ) parser.add_argument( - "--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名" + "--highres_fix_upscaler", + type=str, + default=None, + help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名", ) parser.add_argument( "--highres_fix_upscaler_args", @@ -3488,14 +3764,21 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" + "--negative_scale", + type=float, + default=None, + help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する", ) parser.add_argument( "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" ) parser.add_argument( - "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" + 
"--control_net_preps", + type=str, + default=None, + nargs="*", + help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名", ) parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み") parser.add_argument( @@ -3509,6 +3792,69 @@ def setup_parser() -> argparse.ArgumentParser: # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # ) + # Deep Shrink + parser.add_argument( + "--ds_depth_1", + type=int, + default=None, + help="Enable Deep Shrink with this depth 1, valid values are 0 to 3 / Deep Shrinkをこのdepthで有効にする", + ) + parser.add_argument( + "--ds_timesteps_1", + type=int, + default=650, + help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps", + ) + parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2") + parser.add_argument( + "--ds_timesteps_2", + type=int, + default=650, + help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps", + ) + parser.add_argument( + "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" + ) + + # gradual latent + parser.add_argument( + "--gradual_latent_timesteps", + type=int, + default=None, + help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する", + ) + parser.add_argument( + "--gradual_latent_ratio", + type=float, + default=0.5, + help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する", + ) + parser.add_argument( + "--gradual_latent_ratio_step", + type=float, + default=0.125, + help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか", + ) + parser.add_argument( + "--gradual_latent_every_n_steps", + type=int, + default=3, + help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", + ) + parser.add_argument( + "--gradual_latent_s_noise", + type=float, + default=1.0, + help="s_noise for Gradual Latent / Gradual Latentのs_noise", + ) + parser.add_argument( + "--gradual_latent_unsharp_params", + type=str, + default=None, + help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" + + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", + ) + return parser @@ -3516,4 +3862,5 @@ if __name__ == "__main__": parser = setup_parser() args = parser.parse_args() + setup_logging(args, reset=True) main(args) diff --git a/library/config_util.py b/library/config_util.py index e8e0fda7..d75d03b0 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -1,465 +1,530 @@ import argparse from dataclasses import ( - asdict, - dataclass, + asdict, + dataclass, ) import functools import random from textwrap import dedent, indent import json from pathlib import Path + # from toolz import curry from typing import ( - List, - Optional, - Sequence, - Tuple, - Union, + List, + Optional, + Sequence, + Tuple, + Union, ) import toml import voluptuous from voluptuous import ( - Any, - ExactSequence, - MultipleInvalid, - Object, - Required, - Schema, + Any, + ExactSequence, + MultipleInvalid, + Object, + Required, + Schema, ) from transformers import CLIPTokenizer from . 
import train_util
from .train_util import (
-  DreamBoothSubset,
-  FineTuningSubset,
-  ControlNetSubset,
-  DreamBoothDataset,
-  FineTuningDataset,
-  ControlNetDataset,
-  DatasetGroup,
+    DreamBoothSubset,
+    FineTuningSubset,
+    ControlNetSubset,
+    DreamBoothDataset,
+    FineTuningDataset,
+    ControlNetDataset,
+    DatasetGroup,
)
+from .utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)


def add_config_arguments(parser: argparse.ArgumentParser):
-  parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
+    parser.add_argument(
+        "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
+    )
+
+# TODO: inherit Params class in Subset, Dataset

@dataclass
class BaseSubsetParams:
-  image_dir: Optional[str] = None
-  num_repeats: int = 1
-  shuffle_caption: bool = False
-  keep_tokens: int = 0
-  color_aug: bool = False
-  flip_aug: bool = False
-  face_crop_aug_range: Optional[Tuple[float, float]] = None
-  random_crop: bool = False
-  caption_prefix: Optional[str] = None
-  caption_suffix: Optional[str] = None
-  caption_dropout_rate: float = 0.0
-  caption_dropout_every_n_epochs: int = 0
-  caption_tag_dropout_rate: float = 0.0
-  token_warmup_min: int = 1
-  token_warmup_step: float = 0
+    image_dir: Optional[str] = None
+    num_repeats: int = 1
+    shuffle_caption: bool = False
+    caption_separator: str = ","
+    keep_tokens: int = 0
+    keep_tokens_separator: Optional[str] = None
+    secondary_separator: Optional[str] = None
+    enable_wildcard: bool = False
+    color_aug: bool = False
+    flip_aug: bool = False
+    face_crop_aug_range: Optional[Tuple[float, float]] = None
+    random_crop: bool = False
+    caption_prefix: Optional[str] = None
+    caption_suffix: Optional[str] = None
+    caption_dropout_rate: float = 0.0
+    caption_dropout_every_n_epochs: int = 0
+    caption_tag_dropout_rate: float = 0.0
+    token_warmup_min: int = 1
+    token_warmup_step: float = 0
+

@dataclass
class DreamBoothSubsetParams(BaseSubsetParams):
-  is_reg: bool = False
-  class_tokens: Optional[str] = None
-  caption_extension: str = ".caption"
+    is_reg: bool = False
+    class_tokens: Optional[str] = None
+    caption_extension: str = ".caption"
+    cache_info: bool = False
+

@dataclass
class FineTuningSubsetParams(BaseSubsetParams):
-  metadata_file: Optional[str] = None
+    metadata_file: Optional[str] = None
+

@dataclass
class ControlNetSubsetParams(BaseSubsetParams):
-  conditioning_data_dir: str = None
-  caption_extension: str = ".caption"
+    conditioning_data_dir: str = None
+    caption_extension: str = ".caption"
+    cache_info: bool = False
+

@dataclass
class BaseDatasetParams:
-  tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
-  max_token_length: int = None
-  resolution: Optional[Tuple[int, int]] = None
-  debug_dataset: bool = False
+    tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
+    max_token_length: int = None
+    resolution: Optional[Tuple[int, int]] = None
+    network_multiplier: float = 1.0
+    debug_dataset: bool = False
+

@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
-  batch_size: int = 1
-  enable_bucket: bool = False
-  min_bucket_reso: int = 256
-  max_bucket_reso: int = 1024
-  bucket_reso_steps: int = 64
-  bucket_no_upscale: bool = False
-  prior_loss_weight: float = 1.0
+    batch_size: int = 1
+    enable_bucket: bool = False
+    min_bucket_reso: int = 256
+    max_bucket_reso: int = 1024
+    bucket_reso_steps: int = 64
+    bucket_no_upscale: bool = False
+    prior_loss_weight: float = 1.0
+
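+# Illustrative dataset config (hypothetical paths and values) exercising the new
+# fields above, in the TOML layout this module accepts
+# ([general] / [[datasets]] / [[datasets.subsets]]):
+#
+#   [general]
+#   enable_bucket = true
+#   keep_tokens_separator = "|||"
+#   enable_wildcard = true
+#
+#   [[datasets]]
+#   resolution = 512
+#   network_multiplier = 1.0
+#
+#     [[datasets.subsets]]
+#     image_dir = "train/images"
+#     class_tokens = "sks girl"
+#     cache_info = true
+

@dataclass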
class FineTuningDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + @dataclass class ControlNetDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + @dataclass class SubsetBlueprint: - params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] + params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] + @dataclass class DatasetBlueprint: - is_dreambooth: bool - is_controlnet: bool - params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] - subsets: Sequence[SubsetBlueprint] + is_dreambooth: bool + is_controlnet: bool + params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] + subsets: Sequence[SubsetBlueprint] + @dataclass class DatasetGroupBlueprint: - datasets: Sequence[DatasetBlueprint] + datasets: Sequence[DatasetBlueprint] + + @dataclass class Blueprint: - dataset_group: DatasetGroupBlueprint + dataset_group: DatasetGroupBlueprint class ConfigSanitizer: - # @curry - @staticmethod - def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple: - Schema(ExactSequence([klass, klass]))(value) - return tuple(value) + # @curry + @staticmethod + def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple: + Schema(ExactSequence([klass, klass]))(value) + return tuple(value) - # @curry - @staticmethod - def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple: - Schema(Any(klass, ExactSequence([klass, klass])))(value) - try: - Schema(klass)(value) - return (value, value) - except: - return ConfigSanitizer.__validate_and_convert_twodim(klass, value) + # @curry + @staticmethod + def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple: + Schema(Any(klass, ExactSequence([klass, klass])))(value) + try: + Schema(klass)(value) + return (value, value) + except: + return ConfigSanitizer.__validate_and_convert_twodim(klass, value) - # subset schema - SUBSET_ASCENDABLE_SCHEMA = { - "color_aug": bool, - "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float), - "flip_aug": bool, - "num_repeats": int, - "random_crop": bool, - "shuffle_caption": bool, - "keep_tokens": int, - "token_warmup_min": int, - "token_warmup_step": Any(float,int), - "caption_prefix": str, - "caption_suffix": str, - } - # DO means DropOut - DO_SUBSET_ASCENDABLE_SCHEMA = { - "caption_dropout_every_n_epochs": int, - "caption_dropout_rate": Any(float, int), - "caption_tag_dropout_rate": Any(float, int), - } - # DB means DreamBooth - DB_SUBSET_ASCENDABLE_SCHEMA = { - "caption_extension": str, - "class_tokens": str, - } - DB_SUBSET_DISTINCT_SCHEMA = { - Required("image_dir"): str, - "is_reg": bool, - } - # FT means FineTuning - FT_SUBSET_DISTINCT_SCHEMA = { - Required("metadata_file"): str, - "image_dir": str, - } - CN_SUBSET_ASCENDABLE_SCHEMA = { - "caption_extension": str, - } - CN_SUBSET_DISTINCT_SCHEMA = { - Required("image_dir"): str, - 
Required("conditioning_data_dir"): str, - } + # subset schema + SUBSET_ASCENDABLE_SCHEMA = { + "color_aug": bool, + "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float), + "flip_aug": bool, + "num_repeats": int, + "random_crop": bool, + "shuffle_caption": bool, + "keep_tokens": int, + "keep_tokens_separator": str, + "secondary_separator": str, + "enable_wildcard": bool, + "token_warmup_min": int, + "token_warmup_step": Any(float, int), + "caption_prefix": str, + "caption_suffix": str, + } + # DO means DropOut + DO_SUBSET_ASCENDABLE_SCHEMA = { + "caption_dropout_every_n_epochs": int, + "caption_dropout_rate": Any(float, int), + "caption_tag_dropout_rate": Any(float, int), + } + # DB means DreamBooth + DB_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + "class_tokens": str, + "cache_info": bool, + } + DB_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + "is_reg": bool, + } + # FT means FineTuning + FT_SUBSET_DISTINCT_SCHEMA = { + Required("metadata_file"): str, + "image_dir": str, + } + CN_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + "cache_info": bool, + } + CN_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + Required("conditioning_data_dir"): str, + } - # datasets schema - DATASET_ASCENDABLE_SCHEMA = { - "batch_size": int, - "bucket_no_upscale": bool, - "bucket_reso_steps": int, - "enable_bucket": bool, - "max_bucket_reso": int, - "min_bucket_reso": int, - "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), - } + # datasets schema + DATASET_ASCENDABLE_SCHEMA = { + "batch_size": int, + "bucket_no_upscale": bool, + "bucket_reso_steps": int, + "enable_bucket": bool, + "max_bucket_reso": int, + "min_bucket_reso": int, + "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), + "network_multiplier": float, + } - # options handled by argparse but not handled by user config - ARGPARSE_SPECIFIC_SCHEMA = { - "debug_dataset": bool, - "max_token_length": Any(None, int), - "prior_loss_weight": Any(float, int), - } - # for handling default None value of argparse - ARGPARSE_NULLABLE_OPTNAMES = [ - "face_crop_aug_range", - "resolution", - ] - # prepare map because option name may differ among argparse and user config - ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { - "train_batch_size": "batch_size", - "dataset_repeats": "num_repeats", - } + # options handled by argparse but not handled by user config + ARGPARSE_SPECIFIC_SCHEMA = { + "debug_dataset": bool, + "max_token_length": Any(None, int), + "prior_loss_weight": Any(float, int), + } + # for handling default None value of argparse + ARGPARSE_NULLABLE_OPTNAMES = [ + "face_crop_aug_range", + "resolution", + ] + # prepare map because option name may differ among argparse and user config + ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { + "train_batch_size": "batch_size", + "dataset_repeats": "num_repeats", + } - def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: - assert support_dreambooth or support_finetuning or support_controlnet, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. 
/ DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" + def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: + assert support_dreambooth or support_finetuning or support_controlnet, ( + "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more." + + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。" + ) - self.db_subset_schema = self.__merge_dict( - self.SUBSET_ASCENDABLE_SCHEMA, - self.DB_SUBSET_DISTINCT_SCHEMA, - self.DB_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) + self.db_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_DISTINCT_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) - self.ft_subset_schema = self.__merge_dict( - self.SUBSET_ASCENDABLE_SCHEMA, - self.FT_SUBSET_DISTINCT_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) + self.ft_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.FT_SUBSET_DISTINCT_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) - self.cn_subset_schema = self.__merge_dict( - self.SUBSET_ASCENDABLE_SCHEMA, - self.CN_SUBSET_DISTINCT_SCHEMA, - self.CN_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) + self.cn_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_DISTINCT_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) - self.db_dataset_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.DB_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - {"subsets": [self.db_subset_schema]}, - ) + self.db_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.db_subset_schema]}, + ) - self.ft_dataset_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - {"subsets": [self.ft_subset_schema]}, - ) + self.ft_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.ft_subset_schema]}, + ) - self.cn_dataset_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.CN_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - {"subsets": [self.cn_subset_schema]}, - ) + self.cn_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.cn_subset_schema]}, + ) - if support_dreambooth and support_finetuning: - def validate_flex_dataset(dataset_config: dict): - subsets_config = dataset_config.get("subsets", []) + if support_dreambooth and support_finetuning: - if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]): - return Schema(self.cn_dataset_schema)(dataset_config) - # check dataset meets FT style - # NOTE: all 
FT subsets should have "metadata_file" - elif all(["metadata_file" in subset for subset in subsets_config]): - return Schema(self.ft_dataset_schema)(dataset_config) - # check dataset meets DB style - # NOTE: all DB subsets should have no "metadata_file" - elif all(["metadata_file" not in subset for subset in subsets_config]): - return Schema(self.db_dataset_schema)(dataset_config) - else: - raise voluptuous.Invalid("DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。") + def validate_flex_dataset(dataset_config: dict): + subsets_config = dataset_config.get("subsets", []) - self.dataset_schema = validate_flex_dataset - elif support_dreambooth: - self.dataset_schema = self.db_dataset_schema - elif support_finetuning: - self.dataset_schema = self.ft_dataset_schema - elif support_controlnet: - self.dataset_schema = self.cn_dataset_schema + if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]): + return Schema(self.cn_dataset_schema)(dataset_config) + # check dataset meets FT style + # NOTE: all FT subsets should have "metadata_file" + elif all(["metadata_file" in subset for subset in subsets_config]): + return Schema(self.ft_dataset_schema)(dataset_config) + # check dataset meets DB style + # NOTE: all DB subsets should have no "metadata_file" + elif all(["metadata_file" not in subset for subset in subsets_config]): + return Schema(self.db_dataset_schema)(dataset_config) + else: + raise voluptuous.Invalid( + "DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。" + ) - self.general_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, - self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {}, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) + self.dataset_schema = validate_flex_dataset + elif support_dreambooth: + if support_controlnet: + self.dataset_schema = self.cn_dataset_schema + else: + self.dataset_schema = self.db_dataset_schema + elif support_finetuning: + self.dataset_schema = self.ft_dataset_schema + elif support_controlnet: + self.dataset_schema = self.cn_dataset_schema - self.user_config_validator = Schema({ - "general": self.general_schema, - "datasets": [self.dataset_schema], - }) + self.general_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, + self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {}, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) - self.argparse_schema = self.__merge_dict( - self.general_schema, - self.ARGPARSE_SPECIFIC_SCHEMA, - {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES}, - {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()}, - ) + self.user_config_validator = Schema( + { + "general": self.general_schema, + "datasets": [self.dataset_schema], + } + ) - self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA) + self.argparse_schema = self.__merge_dict( + self.general_schema, + self.ARGPARSE_SPECIFIC_SCHEMA, + 
{optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES}, + {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()}, + ) - def sanitize_user_config(self, user_config: dict) -> dict: - try: - return self.user_config_validator(user_config) - except MultipleInvalid: - # TODO: エラー発生時のメッセージをわかりやすくする - print("Invalid user config / ユーザ設定の形式が正しくないようです") - raise + self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA) - # NOTE: In nature, argument parser result is not needed to be sanitize - # However this will help us to detect program bug - def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace: - try: - return self.argparse_config_validator(argparse_namespace) - except MultipleInvalid: - # XXX: this should be a bug - print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") - raise + def sanitize_user_config(self, user_config: dict) -> dict: + try: + return self.user_config_validator(user_config) + except MultipleInvalid: + # TODO: エラー発生時のメッセージをわかりやすくする + logger.error("Invalid user config / ユーザ設定の形式が正しくないようです") + raise - # NOTE: value would be overwritten by latter dict if there is already the same key - @staticmethod - def __merge_dict(*dict_list: dict) -> dict: - merged = {} - for schema in dict_list: - # merged |= schema - for k, v in schema.items(): - merged[k] = v - return merged + # NOTE: In nature, argument parser result is not needed to be sanitize + # However this will help us to detect program bug + def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace: + try: + return self.argparse_config_validator(argparse_namespace) + except MultipleInvalid: + # XXX: this should be a bug + logger.error( + "Invalid cmdline parsed arguments. This should be a bug. 
/ コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。" + ) + raise + + # NOTE: value would be overwritten by latter dict if there is already the same key + @staticmethod + def __merge_dict(*dict_list: dict) -> dict: + merged = {} + for schema in dict_list: + # merged |= schema + for k, v in schema.items(): + merged[k] = v + return merged class BlueprintGenerator: - BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = { - } + BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {} - def __init__(self, sanitizer: ConfigSanitizer): - self.sanitizer = sanitizer + def __init__(self, sanitizer: ConfigSanitizer): + self.sanitizer = sanitizer - # runtime_params is for parameters which is only configurable on runtime, such as tokenizer - def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint: - sanitized_user_config = self.sanitizer.sanitize_user_config(user_config) - sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace) + # runtime_params is for parameters which is only configurable on runtime, such as tokenizer + def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint: + sanitized_user_config = self.sanitizer.sanitize_user_config(user_config) + sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace) - # convert argparse namespace to dict like config - # NOTE: it is ok to have extra entries in dict - optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME - argparse_config = {optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()} + # convert argparse namespace to dict like config + # NOTE: it is ok to have extra entries in dict + optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME + argparse_config = { + optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items() + } - general_config = sanitized_user_config.get("general", {}) + general_config = sanitized_user_config.get("general", {}) - dataset_blueprints = [] - for dataset_config in sanitized_user_config.get("datasets", []): - # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets - subsets = dataset_config.get("subsets", []) - is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) - is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets]) - if is_controlnet: - subset_params_klass = ControlNetSubsetParams - dataset_params_klass = ControlNetDatasetParams - elif is_dreambooth: - subset_params_klass = DreamBoothSubsetParams - dataset_params_klass = DreamBoothDatasetParams - else: - subset_params_klass = FineTuningSubsetParams - dataset_params_klass = FineTuningDatasetParams + dataset_blueprints = [] + for dataset_config in sanitized_user_config.get("datasets", []): + # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets + subsets = dataset_config.get("subsets", []) + is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) + is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets]) + if is_controlnet: + subset_params_klass = ControlNetSubsetParams + dataset_params_klass = ControlNetDatasetParams + elif is_dreambooth: + subset_params_klass = DreamBoothSubsetParams + dataset_params_klass = DreamBoothDatasetParams + else: + subset_params_klass = FineTuningSubsetParams + dataset_params_klass = FineTuningDatasetParams - subset_blueprints = [] - for 
subset_config in subsets: - params = self.generate_params_by_fallbacks(subset_params_klass, - [subset_config, dataset_config, general_config, argparse_config, runtime_params]) - subset_blueprints.append(SubsetBlueprint(params)) + subset_blueprints = [] + for subset_config in subsets: + params = self.generate_params_by_fallbacks( + subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params] + ) + subset_blueprints.append(SubsetBlueprint(params)) - params = self.generate_params_by_fallbacks(dataset_params_klass, - [dataset_config, general_config, argparse_config, runtime_params]) - dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints)) + params = self.generate_params_by_fallbacks( + dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params] + ) + dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints)) - dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) + dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) - return Blueprint(dataset_group_blueprint) + return Blueprint(dataset_group_blueprint) - @staticmethod - def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]): - name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME - search_value = BlueprintGenerator.search_value - default_params = asdict(param_klass()) - param_names = default_params.keys() + @staticmethod + def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]): + name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME + search_value = BlueprintGenerator.search_value + default_params = asdict(param_klass()) + param_names = default_params.keys() - params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names} + params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names} - return param_klass(**params) + return param_klass(**params) - @staticmethod - def search_value(key: str, fallbacks: Sequence[dict], default_value = None): - for cand in fallbacks: - value = cand.get(key) - if value is not None: - return value + @staticmethod + def search_value(key: str, fallbacks: Sequence[dict], default_value=None): + for cand in fallbacks: + value = cand.get(key) + if value is not None: + return value - return default_value + return default_value def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): - datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] - for dataset_blueprint in dataset_group_blueprint.datasets: - if dataset_blueprint.is_controlnet: - subset_klass = ControlNetSubset - dataset_klass = ControlNetDataset - elif dataset_blueprint.is_dreambooth: - subset_klass = DreamBoothSubset - dataset_klass = DreamBoothDataset - else: - subset_klass = FineTuningSubset - dataset_klass = FineTuningDataset + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in 
dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) - datasets.append(dataset) + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + datasets.append(dataset) - # print info - info = "" - for i, dataset in enumerate(datasets): - is_dreambooth = isinstance(dataset, DreamBoothDataset) - is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent(f"""\ + # print info + info = "" + for i, dataset in enumerate(datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent( + f"""\ [Dataset {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} enable_bucket: {dataset.enable_bucket} - """) + network_multiplier: {dataset.network_multiplier} + """ + ) - if dataset.enable_bucket: - info += indent(dedent(f"""\ + if dataset.enable_bucket: + info += indent( + dedent( + f"""\ min_bucket_reso: {dataset.min_bucket_reso} max_bucket_reso: {dataset.max_bucket_reso} bucket_reso_steps: {dataset.bucket_reso_steps} bucket_no_upscale: {dataset.bucket_no_upscale} - \n"""), " ") - else: - info += "\n" + \n""" + ), + " ", + ) + else: + info += "\n" - for j, subset in enumerate(dataset.subsets): - info += indent(dedent(f"""\ + for j, subset in enumerate(dataset.subsets): + info += indent( + dedent( + f"""\ [Subset {j} of Dataset {i}] image_dir: "{subset.image_dir}" image_count: {subset.img_count} num_repeats: {subset.num_repeats} shuffle_caption: {subset.shuffle_caption} keep_tokens: {subset.keep_tokens} + keep_tokens_separator: {subset.keep_tokens_separator} + secondary_separator: {subset.secondary_separator} + enable_wildcard: {subset.enable_wildcard} caption_dropout_rate: {subset.caption_dropout_rate} caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} @@ -471,147 +536,179 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu random_crop: {subset.random_crop} token_warmup_min: {subset.token_warmup_min}, token_warmup_step: {subset.token_warmup_step}, - """), " ") + """ + ), + " ", + ) - if is_dreambooth: - info += indent(dedent(f"""\ + if is_dreambooth: + info += indent( + dedent( + f"""\ is_reg: {subset.is_reg} class_tokens: {subset.class_tokens} caption_extension: {subset.caption_extension} - \n"""), " ") - elif not is_controlnet: - info += indent(dedent(f"""\ + \n""" + ), + " ", + ) + elif not is_controlnet: + info += indent( + dedent( + f"""\ metadata_file: {subset.metadata_file} - \n"""), " ") + \n""" + ), + " ", + ) - print(info) + logger.info(f"{info}") - # make buckets first because it determines the length of dataset - # and set the same seed for all datasets - seed = random.randint(0, 2**31) # actual seed is seed + epoch_no - for i, dataset in enumerate(datasets): - print(f"[Dataset {i}]") - dataset.make_buckets() - dataset.set_seed(seed) + # make buckets first because it determines the length of dataset + # and set the same seed for all datasets + seed = random.randint(0, 2**31) # actual seed is seed + epoch_no + for i, dataset in enumerate(datasets): + logger.info(f"[Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) - return DatasetGroup(datasets) + return DatasetGroup(datasets) def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, 
reg_data_dir: Optional[str] = None): - def extract_dreambooth_params(name: str) -> Tuple[int, str]: - tokens = name.split('_') - try: - n_repeats = int(tokens[0]) - except ValueError as e: - print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") - return 0, "" - caption_by_folder = '_'.join(tokens[1:]) - return n_repeats, caption_by_folder + def extract_dreambooth_params(name: str) -> Tuple[int, str]: + tokens = name.split("_") + try: + n_repeats = int(tokens[0]) + except ValueError as e: + logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") + return 0, "" + caption_by_folder = "_".join(tokens[1:]) + return n_repeats, caption_by_folder - def generate(base_dir: Optional[str], is_reg: bool): - if base_dir is None: - return [] + def generate(base_dir: Optional[str], is_reg: bool): + if base_dir is None: + return [] - base_dir: Path = Path(base_dir) - if not base_dir.is_dir(): - return [] + base_dir: Path = Path(base_dir) + if not base_dir.is_dir(): + return [] + + subsets_config = [] + for subdir in base_dir.iterdir(): + if not subdir.is_dir(): + continue + + num_repeats, class_tokens = extract_dreambooth_params(subdir.name) + if num_repeats < 1: + continue + + subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens} + subsets_config.append(subset_config) + + return subsets_config subsets_config = [] - for subdir in base_dir.iterdir(): - if not subdir.is_dir(): - continue - - num_repeats, class_tokens = extract_dreambooth_params(subdir.name) - if num_repeats < 1: - continue - - subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens} - subsets_config.append(subset_config) + subsets_config += generate(train_data_dir, False) + subsets_config += generate(reg_data_dir, True) return subsets_config - subsets_config = [] - subsets_config += generate(train_data_dir, False) - subsets_config += generate(reg_data_dir, True) - return subsets_config +def generate_controlnet_subsets_config_by_subdirs( + train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt" +): + def generate(base_dir: Optional[str]): + if base_dir is None: + return [] + base_dir: Path = Path(base_dir) + if not base_dir.is_dir(): + return [] -def generate_controlnet_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"): - def generate(base_dir: Optional[str]): - if base_dir is None: - return [] + subsets_config = [] + subset_config = { + "image_dir": train_data_dir, + "conditioning_data_dir": conditioning_data_dir, + "caption_extension": caption_extension, + "num_repeats": 1, + } + subsets_config.append(subset_config) - base_dir: Path = Path(base_dir) - if not base_dir.is_dir(): - return [] + return subsets_config subsets_config = [] - subset_config = {"image_dir": train_data_dir, "conditioning_data_dir": conditioning_data_dir, "caption_extension": caption_extension, "num_repeats": 1} - subsets_config.append(subset_config) + subsets_config += generate(train_data_dir) return subsets_config - subsets_config = [] - subsets_config += generate(train_data_dir) - - return subsets_config - def load_user_config(file: str) -> dict: - file: Path = Path(file) - if not file.is_file(): - raise ValueError(f"file not found / ファイルが見つかりません: {file}") + file: Path = Path(file) + if not file.is_file(): + raise ValueError(f"file not found / 
ファイルが見つかりません: {file}") - if file.name.lower().endswith('.json'): - try: - with open(file, 'r') as f: - config = json.load(f) - except Exception: - print(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") - raise - elif file.name.lower().endswith('.toml'): - try: - config = toml.load(file) - except Exception: - print(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") - raise - else: - raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") + if file.name.lower().endswith(".json"): + try: + with open(file, "r") as f: + config = json.load(f) + except Exception: + logger.error( + f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) + raise + elif file.name.lower().endswith(".toml"): + try: + config = toml.load(file) + except Exception: + logger.error( + f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) + raise + else: + raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") + + return config - return config # for config test if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--support_dreambooth", action="store_true") - parser.add_argument("--support_finetuning", action="store_true") - parser.add_argument("--support_controlnet", action="store_true") - parser.add_argument("--support_dropout", action="store_true") - parser.add_argument("dataset_config") - config_args, remain = parser.parse_known_args() + parser = argparse.ArgumentParser() + parser.add_argument("--support_dreambooth", action="store_true") + parser.add_argument("--support_finetuning", action="store_true") + parser.add_argument("--support_controlnet", action="store_true") + parser.add_argument("--support_dropout", action="store_true") + parser.add_argument("dataset_config") + config_args, remain = parser.parse_known_args() - parser = argparse.ArgumentParser() - train_util.add_dataset_arguments(parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout) - train_util.add_training_arguments(parser, config_args.support_dreambooth) - argparse_namespace = parser.parse_args(remain) - train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) + parser = argparse.ArgumentParser() + train_util.add_dataset_arguments( + parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout + ) + train_util.add_training_arguments(parser, config_args.support_dreambooth) + argparse_namespace = parser.parse_args(remain) + train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) - print("[argparse_namespace]") - print(vars(argparse_namespace)) + logger.info("[argparse_namespace]") + logger.info(f"{vars(argparse_namespace)}") - user_config = load_user_config(config_args.dataset_config) + user_config = load_user_config(config_args.dataset_config) - print("\n[user_config]") - print(user_config) + logger.info("") + logger.info("[user_config]") + logger.info(f"{user_config}") - sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout) - sanitized_user_config = sanitizer.sanitize_user_config(user_config) + sanitizer = ConfigSanitizer( + config_args.support_dreambooth, 
config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout + ) + sanitized_user_config = sanitizer.sanitize_user_config(user_config) - print("\n[sanitized_user_config]") - print(sanitized_user_config) + logger.info("") + logger.info("[sanitized_user_config]") + logger.info(f"{sanitized_user_config}") - blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) + blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) - print("\n[blueprint]") - print(blueprint) + logger.info("") + logger.info("[blueprint]") + logger.info(f"{blueprint}") diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 28b625d3..406e0e36 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -3,6 +3,12 @@ import argparse import random import re from typing import List, Optional, Union +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) def prepare_scheduler_for_custom_training(noise_scheduler, device): @@ -21,7 +27,7 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device): def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): # fix beta: zero terminal SNR - print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891") + logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891") def enforce_zero_terminal_snr(betas): # Convert betas to alphas_bar_sqrt @@ -49,18 +55,21 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) - # print("original:", noise_scheduler.betas) - # print("fixed:", betas) + # logger.info(f"original: {noise_scheduler.betas}") + # logger.info(f"fixed: {betas}") noise_scheduler.betas = betas noise_scheduler.alphas = alphas noise_scheduler.alphas_cumprod = alphas_cumprod -def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): +def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False): snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) - gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) - snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper + min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) + if v_prediction: + snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device) + else: + snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device) loss = loss * snr_weight return loss @@ -76,23 +85,25 @@ def get_snr_scale(timesteps, noise_scheduler): snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 scale = snr_t / (snr_t + 1) # # show debug info - # print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}") + # logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}") return scale def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss): scale = get_snr_scale(timesteps, noise_scheduler) - # print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") + # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") loss = loss + loss / scale * v_pred_like_loss return loss + def apply_debiased_estimation(loss, timesteps, noise_scheduler): snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size 
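For reference, the reworked `apply_snr_weight` above implements Min-SNR-gamma weighting: each sample's loss is scaled by min(SNR_t, γ)/SNR_t for ε-prediction, or min(SNR_t, γ)/(SNR_t + 1) for v-prediction. A minimal standalone sketch of just the weight computation (the helper name is hypothetical, and the per-sample `snr` tensor is assumed to be precomputed, as `noise_scheduler.all_snr` is above):

```python
import torch

def min_snr_weight(snr: torch.Tensor, gamma: float, v_prediction: bool = False) -> torch.Tensor:
    # clamp each SNR at gamma, then normalize; the +1 in the denominator is the v-prediction variant
    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
    return min_snr_gamma / (snr + 1) if v_prediction else min_snr_gamma / snr

# e.g. SNR = 25, gamma = 5 -> weight 5/25 = 0.2 (eps-prediction) or 5/26 ~= 0.19 (v-prediction)
print(min_snr_weight(torch.tensor([25.0]), gamma=5.0))
```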
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 - weight = 1/torch.sqrt(snr_t) + weight = 1 / torch.sqrt(snr_t) loss = weight * loss return loss + # TODO train_utilと分散しているのでどちらかに寄せる @@ -265,7 +276,7 @@ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int): tokens.append(text_token) weights.append(text_weight) if truncated: - print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") return tokens, weights @@ -468,6 +479,17 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): return noise +def apply_masked_loss(loss, batch): + # mask image is -1 to 1. we need to convert it to 0 to 1 + mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel + + # resize to the same size as the loss + mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area") + mask_image = mask_image / 2 + 0.5 + loss = loss * mask_image + return loss + + """ ########################################## # Perlin Noise diff --git a/library/deepspeed_utils.py b/library/deepspeed_utils.py new file mode 100644 index 00000000..99a7b2b3 --- /dev/null +++ b/library/deepspeed_utils.py @@ -0,0 +1,139 @@ +import os +import argparse +import torch +from accelerate import DeepSpeedPlugin, Accelerator + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def add_deepspeed_arguments(parser: argparse.ArgumentParser): + # DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed + parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training") + parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.") + parser.add_argument( + "--offload_optimizer_device", + type=str, + default=None, + choices=[None, "cpu", "nvme"], + help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.", + ) + parser.add_argument( + "--offload_optimizer_nvme_path", + type=str, + default=None, + help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.", + ) + parser.add_argument( + "--offload_param_device", + type=str, + default=None, + choices=[None, "cpu", "nvme"], + help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.", + ) + parser.add_argument( + "--offload_param_nvme_path", + type=str, + default=None, + help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.", + ) + parser.add_argument( + "--zero3_init_flag", + action="store_true", + help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models. " + "Only applicable with ZeRO Stage-3.", + ) + parser.add_argument( + "--zero3_save_16bit_model", + action="store_true", + help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.", + ) + parser.add_argument( + "--fp16_master_weights_and_gradients", + action="store_true", + help="fp16_master_weights_and_gradients requires the optimizer to support keeping fp16 master weights and gradients while keeping the optimizer states in fp32.", + ) + + +def prepare_deepspeed_args(args: argparse.Namespace): + if not args.deepspeed: + return + + # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
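As a rough illustration of `apply_masked_loss` above: the R channel of the conditioning image (valued in [-1, 1]) is resized to the latent-space loss resolution, remapped to [0, 1], and multiplied into the per-element loss. A self-contained sketch with made-up shapes:

```python
import torch
import torch.nn.functional as F

loss = torch.randn(2, 4, 64, 64).abs()     # per-element loss over latents (assumed shape)
cond = torch.rand(2, 3, 512, 512) * 2 - 1  # mask images in [-1, 1] (assumed shape)

mask = cond[:, 0].unsqueeze(1)                                # use the R channel only
mask = F.interpolate(mask, size=loss.shape[2:], mode="area")  # match the latent resolution
mask = mask / 2 + 0.5                                         # [-1, 1] -> [0, 1]
masked_loss = loss * mask                                     # down-weights loss outside the mask
```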
+ args.max_data_loader_n_workers = 1 + + +def prepare_deepspeed_plugin(args: argparse.Namespace): + if not args.deepspeed: + return None + + try: + import deepspeed + except ImportError as e: + logger.error( + "deepspeed is not installed. Please install deepspeed in your environment with the following command: DS_BUILD_OPS=0 pip install deepspeed" + ) + exit(1) + + deepspeed_plugin = DeepSpeedPlugin( + zero_stage=args.zero_stage, + gradient_accumulation_steps=args.gradient_accumulation_steps, + gradient_clipping=args.max_grad_norm, + offload_optimizer_device=args.offload_optimizer_device, + offload_optimizer_nvme_path=args.offload_optimizer_nvme_path, + offload_param_device=args.offload_param_device, + offload_param_nvme_path=args.offload_param_nvme_path, + zero3_init_flag=args.zero3_init_flag, + zero3_save_16bit_model=args.zero3_save_16bit_model, + ) + deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size + deepspeed_plugin.deepspeed_config["train_batch_size"] = ( + args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"]) + ) + deepspeed_plugin.set_mixed_precision(args.mixed_precision) + if args.mixed_precision.lower() == "fp16": + deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow. + if args.full_fp16 or args.fp16_master_weights_and_gradients: + if args.offload_optimizer_device == "cpu" and args.zero_stage == 2: + deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True + logger.info("[DeepSpeed] full fp16 enabled.") + else: + logger.info( + "[DeepSpeed] full fp16 / fp16_master_weights_and_grads is currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO stage 2." + ) + + if args.offload_optimizer_device is not None: + logger.info("[DeepSpeed] start to manually build cpu_adam.") + deepspeed.ops.op_builder.CPUAdamBuilder().load() + logger.info("[DeepSpeed] building cpu_adam done.") + + return deepspeed_plugin + + +# The Accelerate library does not support multiple models for DeepSpeed, so we need to wrap multiple models into a single model.
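Before the wrapper definition that follows, one note on the batch-size bookkeeping in `prepare_deepspeed_plugin` above: DeepSpeed requires `train_batch_size` to equal the per-GPU micro-batch times the gradient-accumulation steps times the process count (taken from the `WORLD_SIZE` environment variable). A worked example with assumed values:

```python
# assumed example values, not defaults
train_batch_size_per_gpu = 4     # args.train_batch_size
gradient_accumulation_steps = 2  # args.gradient_accumulation_steps
world_size = 2                   # int(os.environ["WORLD_SIZE"])

effective_train_batch_size = train_batch_size_per_gpu * gradient_accumulation_steps * world_size
assert effective_train_batch_size == 16  # what DeepSpeed sees as "train_batch_size"
```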
+def prepare_deepspeed_model(args: argparse.Namespace, **models): + # remove None from models + models = {k: v for k, v in models.items() if v is not None} + + class DeepSpeedWrapper(torch.nn.Module): + def __init__(self, **kw_models) -> None: + super().__init__() + self.models = torch.nn.ModuleDict() + + for key, model in kw_models.items(): + if isinstance(model, list): + model = torch.nn.ModuleList(model) + assert isinstance( + model, torch.nn.Module + ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}" + self.models.update(torch.nn.ModuleDict({key: model})) + + def get_models(self): + return self.models + + ds_model = DeepSpeedWrapper(**models) + return ds_model diff --git a/library/device_utils.py b/library/device_utils.py new file mode 100644 index 00000000..8823c5d9 --- /dev/null +++ b/library/device_utils.py @@ -0,0 +1,84 @@ +import functools +import gc + +import torch + +try: + HAS_CUDA = torch.cuda.is_available() +except Exception: + HAS_CUDA = False + +try: + HAS_MPS = torch.backends.mps.is_available() +except Exception: + HAS_MPS = False + +try: + import intel_extension_for_pytorch as ipex # noqa + + HAS_XPU = torch.xpu.is_available() +except Exception: + HAS_XPU = False + + +def clean_memory(): + gc.collect() + if HAS_CUDA: + torch.cuda.empty_cache() + if HAS_XPU: + torch.xpu.empty_cache() + if HAS_MPS: + torch.mps.empty_cache() + + +def clean_memory_on_device(device: torch.device): + r""" + Clean memory on the specified device, will be called from training scripts. + """ + gc.collect() + + # device may "cuda" or "cuda:0", so we need to check the type of device + if device.type == "cuda": + torch.cuda.empty_cache() + if device.type == "xpu": + torch.xpu.empty_cache() + if device.type == "mps": + torch.mps.empty_cache() + + +@functools.lru_cache(maxsize=None) +def get_preferred_device() -> torch.device: + r""" + Do not call this function from training scripts. Use accelerator.device instead. + """ + if HAS_CUDA: + device = torch.device("cuda") + elif HAS_XPU: + device = torch.device("xpu") + elif HAS_MPS: + device = torch.device("mps") + else: + device = torch.device("cpu") + print(f"get_preferred_device() -> {device}") + return device + + +def init_ipex(): + """ + Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`. + + This function should run right after importing torch and before doing anything else. + + If IPEX is not available, this function does nothing. 
+ """ + try: + if HAS_XPU: + from library.ipex import ipex_init + + is_initialized, error_message = ipex_init() + if not is_initialized: + print("failed to initialize ipex:", error_message) + else: + return + except Exception as e: + print("failed to initialize ipex:", e) diff --git a/library/huggingface_util.py b/library/huggingface_util.py index 376fdb1e..57b19d98 100644 --- a/library/huggingface_util.py +++ b/library/huggingface_util.py @@ -4,7 +4,10 @@ from pathlib import Path import argparse import os from library.utils import fire_in_thread - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None): api = HfApi( @@ -33,9 +36,9 @@ def upload( try: api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので - print("===========================================") - print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") - print("===========================================") + logger.error("===========================================") + logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") + logger.error("===========================================") is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir()) @@ -56,9 +59,9 @@ def upload( path_in_repo=path_in_repo, ) except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので - print("===========================================") - print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") - print("===========================================") + logger.error("===========================================") + logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") + logger.error("===========================================") if args.async_upload and not force_sync_upload: fire_in_thread(uploader) diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index 43accd9f..e5aba693 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -4,172 +4,177 @@ import contextlib import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import from .hijacks import ipex_hijacks -from .attention import attention_init # pylint: disable=protected-access, missing-function-docstring, line-too-long def ipex_init(): # pylint: disable=too-many-statements try: - #Replace cuda with xpu: - torch.cuda.current_device = torch.xpu.current_device - torch.cuda.current_stream = torch.xpu.current_stream - torch.cuda.device = torch.xpu.device - torch.cuda.device_count = torch.xpu.device_count - torch.cuda.device_of = torch.xpu.device_of - torch.cuda.get_device_name = torch.xpu.get_device_name - torch.cuda.get_device_properties = torch.xpu.get_device_properties - torch.cuda.init = torch.xpu.init - torch.cuda.is_available = torch.xpu.is_available - torch.cuda.is_initialized = torch.xpu.is_initialized - torch.cuda.is_current_stream_capturing = lambda: False - torch.cuda.set_device = torch.xpu.set_device - torch.cuda.stream = torch.xpu.stream - torch.cuda.synchronize = torch.xpu.synchronize - torch.cuda.Event = torch.xpu.Event - torch.cuda.Stream = torch.xpu.Stream - torch.cuda.FloatTensor = torch.xpu.FloatTensor - torch.Tensor.cuda = torch.Tensor.xpu - torch.Tensor.is_cuda = torch.Tensor.is_xpu - torch.cuda._initialization_lock = 
torch.xpu.lazy_init._initialization_lock - torch.cuda._initialized = torch.xpu.lazy_init._initialized - torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker - torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls - torch.cuda._tls = torch.xpu.lazy_init._tls - torch.cuda.threading = torch.xpu.lazy_init.threading - torch.cuda.traceback = torch.xpu.lazy_init.traceback - torch.cuda.Optional = torch.xpu.Optional - torch.cuda.__cached__ = torch.xpu.__cached__ - torch.cuda.__loader__ = torch.xpu.__loader__ - torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage - torch.cuda.Tuple = torch.xpu.Tuple - torch.cuda.streams = torch.xpu.streams - torch.cuda._lazy_new = torch.xpu._lazy_new - torch.cuda.FloatStorage = torch.xpu.FloatStorage - torch.cuda.Any = torch.xpu.Any - torch.cuda.__doc__ = torch.xpu.__doc__ - torch.cuda.default_generators = torch.xpu.default_generators - torch.cuda.HalfTensor = torch.xpu.HalfTensor - torch.cuda._get_device_index = torch.xpu._get_device_index - torch.cuda.__path__ = torch.xpu.__path__ - torch.cuda.Device = torch.xpu.Device - torch.cuda.IntTensor = torch.xpu.IntTensor - torch.cuda.ByteStorage = torch.xpu.ByteStorage - torch.cuda.set_stream = torch.xpu.set_stream - torch.cuda.BoolStorage = torch.xpu.BoolStorage - torch.cuda.os = torch.xpu.os - torch.cuda.torch = torch.xpu.torch - torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage - torch.cuda.Union = torch.xpu.Union - torch.cuda.DoubleTensor = torch.xpu.DoubleTensor - torch.cuda.ShortTensor = torch.xpu.ShortTensor - torch.cuda.LongTensor = torch.xpu.LongTensor - torch.cuda.IntStorage = torch.xpu.IntStorage - torch.cuda.LongStorage = torch.xpu.LongStorage - torch.cuda.__annotations__ = torch.xpu.__annotations__ - torch.cuda.__package__ = torch.xpu.__package__ - torch.cuda.__builtins__ = torch.xpu.__builtins__ - torch.cuda.CharTensor = torch.xpu.CharTensor - torch.cuda.List = torch.xpu.List - torch.cuda._lazy_init = torch.xpu._lazy_init - torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor - torch.cuda.DoubleStorage = torch.xpu.DoubleStorage - torch.cuda.ByteTensor = torch.xpu.ByteTensor - torch.cuda.StreamContext = torch.xpu.StreamContext - torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage - torch.cuda.ShortStorage = torch.xpu.ShortStorage - torch.cuda._lazy_call = torch.xpu._lazy_call - torch.cuda.HalfStorage = torch.xpu.HalfStorage - torch.cuda.random = torch.xpu.random - torch.cuda._device = torch.xpu._device - torch.cuda.classproperty = torch.xpu.classproperty - torch.cuda.__name__ = torch.xpu.__name__ - torch.cuda._device_t = torch.xpu._device_t - torch.cuda.warnings = torch.xpu.warnings - torch.cuda.__spec__ = torch.xpu.__spec__ - torch.cuda.BoolTensor = torch.xpu.BoolTensor - torch.cuda.CharStorage = torch.xpu.CharStorage - torch.cuda.__file__ = torch.xpu.__file__ - torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork - #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked: + return True, "Skipping IPEX hijack" + else: + # Replace cuda with xpu: + torch.cuda.current_device = torch.xpu.current_device + torch.cuda.current_stream = torch.xpu.current_stream + torch.cuda.device = torch.xpu.device + torch.cuda.device_count = torch.xpu.device_count + torch.cuda.device_of = torch.xpu.device_of + torch.cuda.get_device_name = torch.xpu.get_device_name + torch.cuda.get_device_properties = torch.xpu.get_device_properties + 
torch.cuda.init = torch.xpu.init + torch.cuda.is_available = torch.xpu.is_available + torch.cuda.is_initialized = torch.xpu.is_initialized + torch.cuda.is_current_stream_capturing = lambda: False + torch.cuda.set_device = torch.xpu.set_device + torch.cuda.stream = torch.xpu.stream + torch.cuda.synchronize = torch.xpu.synchronize + torch.cuda.Event = torch.xpu.Event + torch.cuda.Stream = torch.xpu.Stream + torch.cuda.FloatTensor = torch.xpu.FloatTensor + torch.Tensor.cuda = torch.Tensor.xpu + torch.Tensor.is_cuda = torch.Tensor.is_xpu + torch.nn.Module.cuda = torch.nn.Module.xpu + torch.UntypedStorage.cuda = torch.UntypedStorage.xpu + torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock + torch.cuda._initialized = torch.xpu.lazy_init._initialized + torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls + torch.cuda._tls = torch.xpu.lazy_init._tls + torch.cuda.threading = torch.xpu.lazy_init.threading + torch.cuda.traceback = torch.xpu.lazy_init.traceback + torch.cuda.Optional = torch.xpu.Optional + torch.cuda.__cached__ = torch.xpu.__cached__ + torch.cuda.__loader__ = torch.xpu.__loader__ + torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage + torch.cuda.Tuple = torch.xpu.Tuple + torch.cuda.streams = torch.xpu.streams + torch.cuda._lazy_new = torch.xpu._lazy_new + torch.cuda.FloatStorage = torch.xpu.FloatStorage + torch.cuda.Any = torch.xpu.Any + torch.cuda.__doc__ = torch.xpu.__doc__ + torch.cuda.default_generators = torch.xpu.default_generators + torch.cuda.HalfTensor = torch.xpu.HalfTensor + torch.cuda._get_device_index = torch.xpu._get_device_index + torch.cuda.__path__ = torch.xpu.__path__ + torch.cuda.Device = torch.xpu.Device + torch.cuda.IntTensor = torch.xpu.IntTensor + torch.cuda.ByteStorage = torch.xpu.ByteStorage + torch.cuda.set_stream = torch.xpu.set_stream + torch.cuda.BoolStorage = torch.xpu.BoolStorage + torch.cuda.os = torch.xpu.os + torch.cuda.torch = torch.xpu.torch + torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage + torch.cuda.Union = torch.xpu.Union + torch.cuda.DoubleTensor = torch.xpu.DoubleTensor + torch.cuda.ShortTensor = torch.xpu.ShortTensor + torch.cuda.LongTensor = torch.xpu.LongTensor + torch.cuda.IntStorage = torch.xpu.IntStorage + torch.cuda.LongStorage = torch.xpu.LongStorage + torch.cuda.__annotations__ = torch.xpu.__annotations__ + torch.cuda.__package__ = torch.xpu.__package__ + torch.cuda.__builtins__ = torch.xpu.__builtins__ + torch.cuda.CharTensor = torch.xpu.CharTensor + torch.cuda.List = torch.xpu.List + torch.cuda._lazy_init = torch.xpu._lazy_init + torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor + torch.cuda.DoubleStorage = torch.xpu.DoubleStorage + torch.cuda.ByteTensor = torch.xpu.ByteTensor + torch.cuda.StreamContext = torch.xpu.StreamContext + torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage + torch.cuda.ShortStorage = torch.xpu.ShortStorage + torch.cuda._lazy_call = torch.xpu._lazy_call + torch.cuda.HalfStorage = torch.xpu.HalfStorage + torch.cuda.random = torch.xpu.random + torch.cuda._device = torch.xpu._device + torch.cuda.classproperty = torch.xpu.classproperty + torch.cuda.__name__ = torch.xpu.__name__ + torch.cuda._device_t = torch.xpu._device_t + torch.cuda.warnings = torch.xpu.warnings + torch.cuda.__spec__ = torch.xpu.__spec__ + torch.cuda.BoolTensor = torch.xpu.BoolTensor + torch.cuda.CharStorage = torch.xpu.CharStorage + torch.cuda.__file__ = torch.xpu.__file__ + torch.cuda._is_in_bad_fork = 
torch.xpu.lazy_init._is_in_bad_fork + # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing - #Memory: - torch.cuda.memory = torch.xpu.memory - if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): - torch.xpu.empty_cache = lambda: None - torch.cuda.empty_cache = torch.xpu.empty_cache - torch.cuda.memory_stats = torch.xpu.memory_stats - torch.cuda.memory_summary = torch.xpu.memory_summary - torch.cuda.memory_snapshot = torch.xpu.memory_snapshot - torch.cuda.memory_allocated = torch.xpu.memory_allocated - torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated - torch.cuda.memory_reserved = torch.xpu.memory_reserved - torch.cuda.memory_cached = torch.xpu.memory_reserved - torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved - torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved - torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats - torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats - torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats - torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict - torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats + # Memory: + torch.cuda.memory = torch.xpu.memory + if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): + torch.xpu.empty_cache = lambda: None + torch.cuda.empty_cache = torch.xpu.empty_cache + torch.cuda.memory_stats = torch.xpu.memory_stats + torch.cuda.memory_summary = torch.xpu.memory_summary + torch.cuda.memory_snapshot = torch.xpu.memory_snapshot + torch.cuda.memory_allocated = torch.xpu.memory_allocated + torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated + torch.cuda.memory_reserved = torch.xpu.memory_reserved + torch.cuda.memory_cached = torch.xpu.memory_reserved + torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved + torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved + torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats + torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict + torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats - #RNG: - torch.cuda.get_rng_state = torch.xpu.get_rng_state - torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all - torch.cuda.set_rng_state = torch.xpu.set_rng_state - torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all - torch.cuda.manual_seed = torch.xpu.manual_seed - torch.cuda.manual_seed_all = torch.xpu.manual_seed_all - torch.cuda.seed = torch.xpu.seed - torch.cuda.seed_all = torch.xpu.seed_all - torch.cuda.initial_seed = torch.xpu.initial_seed + # RNG: + torch.cuda.get_rng_state = torch.xpu.get_rng_state + torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all + torch.cuda.set_rng_state = torch.xpu.set_rng_state + torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all + torch.cuda.manual_seed = torch.xpu.manual_seed + torch.cuda.manual_seed_all = torch.xpu.manual_seed_all + torch.cuda.seed = torch.xpu.seed + torch.cuda.seed_all = torch.xpu.seed_all + torch.cuda.initial_seed = torch.xpu.initial_seed + + # AMP: + torch.cuda.amp = torch.xpu.amp + torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled + torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype + + if not hasattr(torch.cuda.amp, "common"): + 
torch.cuda.amp.common = contextlib.nullcontext() + torch.cuda.amp.common.amp_definitely_not_available = lambda: False - #AMP: - torch.cuda.amp = torch.xpu.amp - if not hasattr(torch.cuda.amp, "common"): - torch.cuda.amp.common = contextlib.nullcontext() - torch.cuda.amp.common.amp_definitely_not_available = lambda: False - try: - torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler - except Exception: # pylint: disable=broad-exception-caught try: - from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error - gradscaler_init() torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler except Exception: # pylint: disable=broad-exception-caught - torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + try: + from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error + gradscaler_init() + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler - #C - torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream - ipex._C._DeviceProperties.major = 2023 - ipex._C._DeviceProperties.minor = 2 + # C + torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream + ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count + ipex._C._DeviceProperties.major = 2024 + ipex._C._DeviceProperties.minor = 0 - #Fix functions with ipex: - torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] - torch._utils._get_available_device_type = lambda: "xpu" - torch.has_cuda = True - torch.cuda.has_half = True - torch.cuda.is_bf16_supported = lambda *args, **kwargs: True - torch.cuda.is_fp16_supported = lambda *args, **kwargs: True - torch.version.cuda = "11.7" - torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7] - torch.cuda.get_device_properties.major = 11 - torch.cuda.get_device_properties.minor = 7 - torch.cuda.ipc_collect = lambda *args, **kwargs: None - torch.cuda.utilization = lambda *args, **kwargs: 0 - if hasattr(torch.xpu, 'getDeviceIdListForCard'): - torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard - torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard - else: - torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card - torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card + # Fix functions with ipex: + torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] + torch._utils._get_available_device_type = lambda: "xpu" + torch.has_cuda = True + torch.cuda.has_half = True + torch.cuda.is_bf16_supported = lambda *args, **kwargs: True + torch.cuda.is_fp16_supported = lambda *args, **kwargs: True + torch.backends.cuda.is_built = lambda *args, **kwargs: True + torch.version.cuda = "12.1" + torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1] + torch.cuda.get_device_properties.major = 12 + torch.cuda.get_device_properties.minor = 1 + torch.cuda.ipc_collect = lambda *args, **kwargs: None + torch.cuda.utilization = lambda *args, **kwargs: 0 - ipex_hijacks() - attention_init() - try: - from .diffusers import ipex_diffusers - ipex_diffusers() - except Exception: # pylint: 
disable=broad-exception-caught - pass + ipex_hijacks() + if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None: + try: + from .diffusers import ipex_diffusers + ipex_diffusers() + except Exception: # pylint: disable=broad-exception-caught + pass + torch.cuda.is_xpu_hijacked = True except Exception as e: return False, e return True, None diff --git a/library/ipex/attention.py b/library/ipex/attention.py index 84848b6a..d989ad53 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -1,45 +1,98 @@ +import os import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +from functools import cache # pylint: disable=protected-access, missing-function-docstring, line-too-long -original_torch_bmm = torch.bmm -def torch_bmm(input, mat2, *, out=None): - if input.dtype != mat2.dtype: - mat2 = mat2.to(input.dtype) +# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers - #ARC GPUs can't allocate more than 4GB to a single block, Slice it: - batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2] - block_multiply = input.element_size() - slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply +sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4)) +attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) + +# Find something divisible with the input_tokens +@cache +def find_slice_size(slice_size, slice_block_size): + while (slice_size * slice_block_size) > attention_slice_rate: + slice_size = slice_size // 2 + if slice_size <= 1: + slice_size = 1 + break + return slice_size + +# Find slice sizes for SDPA +@cache +def find_sdpa_slice_sizes(query_shape, query_element_size): + if len(query_shape) == 3: + batch_size_attention, query_tokens, shape_three = query_shape + shape_four = 1 + else: + batch_size_attention, query_tokens, shape_three, shape_four = query_shape + + slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size block_size = batch_size_attention * slice_block_size split_slice_size = batch_size_attention - if block_size > 4: + split_2_slice_size = query_tokens + split_3_slice_size = shape_three + + do_split = False + do_split_2 = False + do_split_3 = False + + if block_size > sdpa_slice_trigger_rate: do_split = True - #Find something divisible with the input_tokens - while (split_slice_size * slice_block_size) > 4: - split_slice_size = split_slice_size // 2 - if split_slice_size <= 1: - split_slice_size = 1 - break - else: - do_split = False + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) + if split_2_slice_size * slice_2_block_size > attention_slice_rate: + slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + +# Find slice sizes for BMM +@cache +def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape): + batch_size_attention, input_tokens, mat2_atten_shape =
input_shape[0], input_shape[1], mat2_shape[2] + slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size + block_size = batch_size_attention * slice_block_size + + split_slice_size = batch_size_attention split_2_slice_size = input_tokens - if split_slice_size * slice_block_size > 4: - slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply - do_split_2 = True - #Find something divisible with the input_tokens - while (split_2_slice_size * slice_block_size2) > 4: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break - else: - do_split_2 = False + split_3_slice_size = mat2_atten_shape + do_split = False + do_split_2 = False + do_split_3 = False + + if block_size > attention_slice_rate: + do_split = True + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) + if split_2_slice_size * slice_2_block_size > attention_slice_rate: + slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + + +original_torch_bmm = torch.bmm +def torch_bmm_32_bit(input, mat2, *, out=None): + if input.device.type != "xpu": + return original_torch_bmm(input, mat2, out=out) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape) + + # Slice BMM if do_split: + batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2] hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) for i in range(batch_size_attention // split_slice_size): start_idx = i * split_slice_size @@ -48,62 +101,41 @@ def torch_bmm(input, mat2, *, out=None): for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name start_idx_2 = i2 * split_2_slice_size end_idx_2 = (i2 + 1) * split_2_slice_size - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( - input[start_idx:end_idx, start_idx_2:end_idx_2], - mat2[start_idx:end_idx, start_idx_2:end_idx_2], - out=out - ) + if do_split_3: + for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm( + input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + out=out + ) + else: + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( + input[start_idx:end_idx, start_idx_2:end_idx_2], + mat2[start_idx:end_idx, start_idx_2:end_idx_2], + out=out + ) else: hidden_states[start_idx:end_idx] = original_torch_bmm( input[start_idx:end_idx], mat2[start_idx:end_idx], out=out ) + torch.xpu.synchronize(input.device) else: return original_torch_bmm(input, mat2, out=out) return hidden_states original_scaled_dot_product_attention = 
torch.nn.functional.scaled_dot_product_attention -def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): - #ARC GPUs can't allocate more than 4GB to a single block, Slice it: - if len(query.shape) == 3: - batch_size_attention, query_tokens, shape_four = query.shape - shape_one = 1 - no_shape_one = True - else: - shape_one, batch_size_attention, query_tokens, shape_four = query.shape - no_shape_one = False - - block_multiply = query.element_size() - slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply - block_size = batch_size_attention * slice_block_size - - split_slice_size = batch_size_attention - if block_size > 4: - do_split = True - #Find something divisible with the shape_one - while (split_slice_size * slice_block_size) > 4: - split_slice_size = split_slice_size // 2 - if split_slice_size <= 1: - split_slice_size = 1 - break - else: - do_split = False - - split_2_slice_size = query_tokens - if split_slice_size * slice_block_size > 4: - slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply - do_split_2 = True - #Find something divisible with the batch_size_attention - while (split_2_slice_size * slice_block_size2) > 4: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break - else: - do_split_2 = False +def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs): + if query.device.type != "xpu": + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size()) + # Slice SDPA if do_split: + batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) for i in range(batch_size_attention // split_slice_size): start_idx = i * split_slice_size @@ -112,46 +144,34 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. 
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name start_idx_2 = i2 * split_2_slice_size end_idx_2 = (i2 + 1) * split_2_slice_size - if no_shape_one: + if do_split_3: + for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention( + query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal, **kwargs + ) + else: hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( query[start_idx:end_idx, start_idx_2:end_idx_2], key[start_idx:end_idx, start_idx_2:end_idx_2], value[start_idx:end_idx, start_idx_2:end_idx_2], attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal - ) - else: - hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( - query[:, start_idx:end_idx, start_idx_2:end_idx_2], - key[:, start_idx:end_idx, start_idx_2:end_idx_2], - value[:, start_idx:end_idx, start_idx_2:end_idx_2], - attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal + dropout_p=dropout_p, is_causal=is_causal, **kwargs ) else: - if no_shape_one: - hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( - query[start_idx:end_idx], - key[start_idx:end_idx], - value[start_idx:end_idx], - attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal - ) - else: - hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention( - query[:, start_idx:end_idx], - key[:, start_idx:end_idx], - value[:, start_idx:end_idx], - attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal - ) + hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal, **kwargs + ) + torch.xpu.synchronize(query.device) else: - return original_scaled_dot_product_attention( - query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal - ) + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) return hidden_states - -def attention_init(): - #ARC GPUs can't allocate more than 4GB to a single block: - torch.bmm = torch_bmm - torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py index 005ee49f..732a1856 100644 --- a/library/ipex/diffusers.py +++ b/library/ipex/diffusers.py @@ -1,10 +1,62 @@ +import os import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -import diffusers #0.21.1 # pylint: 
disable=import-error +import diffusers #0.24.0 # pylint: disable=import-error from diffusers.models.attention_processor import Attention +from diffusers.utils import USE_PEFT_BACKEND +from functools import cache # pylint: disable=protected-access, missing-function-docstring, line-too-long +attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) + +@cache +def find_slice_size(slice_size, slice_block_size): + while (slice_size * slice_block_size) > attention_slice_rate: + slice_size = slice_size // 2 + if slice_size <= 1: + slice_size = 1 + break + return slice_size + +@cache +def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None): + if len(query_shape) == 3: + batch_size_attention, query_tokens, shape_three = query_shape + shape_four = 1 + else: + batch_size_attention, query_tokens, shape_three, shape_four = query_shape + if slice_size is not None: + batch_size_attention = slice_size + + slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size + block_size = batch_size_attention * slice_block_size + + split_slice_size = batch_size_attention + split_2_slice_size = query_tokens + split_3_slice_size = shape_three + + do_split = False + do_split_2 = False + do_split_3 = False + + if query_device_type != "xpu": + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + + if block_size > attention_slice_rate: + do_split = True + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) + if split_2_slice_size * slice_2_block_size > attention_slice_rate: + slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + class SlicedAttnProcessor: # pylint: disable=too-few-public-methods r""" Processor for implementing sliced attention. 
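The two cached helpers above drive all of the slicing below: `find_slice_size` halves a candidate slice until its estimated footprint fits under `IPEX_ATTENTION_SLICE_RATE` (in MB, default 4), and `find_attention_slice_sizes` applies that search to up to three dimensions of the query. A minimal standalone sketch of the halving search, with purely illustrative shapes and numbers (not part of the patch):

```python
import os

# mirrors the find_slice_size helper above: halve the slice until its
# estimated footprint (in MB) no longer exceeds the slice rate
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))

def find_slice_size(slice_size, slice_block_size):
    while (slice_size * slice_block_size) > attention_slice_rate:
        slice_size = slice_size // 2
        if slice_size <= 1:
            slice_size = 1
            break
    return slice_size

# hypothetical query of shape (batch*heads, tokens, head_dim) = (16, 4096, 64) in fp16
batch, tokens, head_dim, element_size = 16, 4096, 64, 2
slice_block_size = tokens * head_dim / 1024 / 1024 * element_size  # MB per batch element
print(find_slice_size(batch, slice_block_size))  # 16 * 0.5 MB = 8 MB > 4 MB, so -> 8
```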
@@ -18,7 +70,9 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods def __init__(self, slice_size): self.slice_size = slice_size - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, + encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches + residual = hidden_states input_ndim = hidden_states.ndim @@ -54,49 +108,62 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype ) - #ARC GPUs can't allocate more than 4GB to a single block, Slice it: - block_multiply = query.element_size() - slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply - block_size = query_tokens * slice_block_size - split_2_slice_size = query_tokens - if block_size > 4: - do_split_2 = True - #Find something divisible with the query_tokens - while (split_2_slice_size * slice_block_size) > 4: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break - else: - do_split_2 = False - - for i in range(batch_size_attention // self.slice_size): - start_idx = i * self.slice_size - end_idx = (i + 1) * self.slice_size + #################################################################### + # ARC GPUs can't allocate more than 4GB to a single block, Slice it: + _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size) + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size if do_split_2: for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name start_idx_2 = i2 * split_2_slice_size end_idx_2 = (i2 + 1) * split_2_slice_size + if do_split_3: + for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice + del attn_slice + else: + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] + key_slice = 
key[start_idx:end_idx, start_idx_2:end_idx_2] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + del attn_slice + torch.xpu.synchronize(query.device) else: query_slice = query[start_idx:end_idx] key_slice = key[start_idx:end_idx] attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - + del query_slice + del key_slice + del attn_mask_slice attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) hidden_states[start_idx:end_idx] = attn_slice + del attn_slice + #################################################################### hidden_states = attn.batch_to_head_dim(hidden_states) @@ -115,6 +182,131 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods return hidden_states + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, + encoder_hidden_states=None, attention_mask=None, + temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches + + residual = hidden_states + + args = () if USE_PEFT_BACKEND else (scale,) + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states, *args) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, *args) + value = attn.to_v(encoder_hidden_states, *args) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + #################################################################### + # ARC GPUs can't allocate more than 4GB to a single block, Slice it: + batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] + hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type) + + if do_split: + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + if 
do_split_3: + for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice + del attn_slice + else: + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + del attn_slice + else: + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + del attn_slice + torch.xpu.synchronize(query.device) + else: + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + #################################################################### + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, *args) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + def ipex_diffusers(): #ARC GPUs can't allocate more than 4GB to a single block: diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor + diffusers.models.attention_processor.AttnProcessor = AttnProcessor diff --git a/library/ipex/gradscaler.py b/library/ipex/gradscaler.py index 53021210..6eb56bc2 100644 --- a/library/ipex/gradscaler.py +++ b/library/ipex/gradscaler.py @@ -5,6 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un # pylint: disable=protected-access, missing-function-docstring, line-too-long +device_supports_fp64 = torch.xpu.has_fp64_dtype() OptState = ipex.cpu.autocast._grad_scaler.OptState _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state @@ -96,7 +97,10 @@ def unscale_(self, optimizer): # FP32 division can be imprecise for certain compile options, so we carry 
out the reciprocal in FP64. assert self._scale is not None - inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) + if device_supports_fp64: + inv_scale = self._scale.double().reciprocal().float() + else: + inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) found_inf = torch.full( (1,), 0.0, dtype=torch.float32, device=self._scale.device ) diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index 77ed5419..d3cef827 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -1,67 +1,14 @@ -import contextlib -import importlib +import os +from functools import wraps +from contextlib import nullcontext import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +import numpy as np + +device_supports_fp64 = torch.xpu.has_fp64_dtype() # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return -class CondFunc: # pylint: disable=missing-class-docstring - def __new__(cls, orig_func, sub_func, cond_func): - self = super(CondFunc, cls).__new__(cls) - if isinstance(orig_func, str): - func_path = orig_func.split('.') - for i in range(len(func_path)-1, -1, -1): - try: - resolved_obj = importlib.import_module('.'.join(func_path[:i])) - break - except ImportError: - pass - for attr_name in func_path[i:-1]: - resolved_obj = getattr(resolved_obj, attr_name) - orig_func = getattr(resolved_obj, func_path[-1]) - setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) - self.__init__(orig_func, sub_func, cond_func) - return lambda *args, **kwargs: self(*args, **kwargs) - def __init__(self, orig_func, sub_func, cond_func): - self.__orig_func = orig_func - self.__sub_func = sub_func - self.__cond_func = cond_func - def __call__(self, *args, **kwargs): - if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): - return self.__sub_func(self.__orig_func, *args, **kwargs) - else: - return self.__orig_func(*args, **kwargs) - -_utils = torch.utils.data._utils -def _shutdown_workers(self): - if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None: - return - if hasattr(self, "_shutdown") and not self._shutdown: - self._shutdown = True - try: - if hasattr(self, '_pin_memory_thread'): - self._pin_memory_thread_done_event.set() - self._worker_result_queue.put((None, None)) - self._pin_memory_thread.join() - self._worker_result_queue.cancel_join_thread() - self._worker_result_queue.close() - self._workers_done_event.set() - for worker_id in range(len(self._workers)): - if self._persistent_workers or self._workers_status[worker_id]: - self._mark_worker_as_unavailable(worker_id, shutdown=True) - for w in self._workers: # pylint: disable=invalid-name - w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL) - for q in self._index_queues: # pylint: disable=invalid-name - q.cancel_join_thread() - q.close() - finally: - if self._worker_pids_set: - torch.utils.data._utils.signal_handling._remove_worker_pids(id(self)) - self._worker_pids_set = False - for w in self._workers: # pylint: disable=invalid-name - if w.is_alive(): - w.terminate() - class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument if isinstance(device_ids, list) and 
len(device_ids) > 1: @@ -69,7 +16,11 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr return module.to("xpu") def return_null_context(*args, **kwargs): # pylint: disable=unused-argument - return contextlib.nullcontext() + return nullcontext() + +@property +def is_cuda(self): + return self.device.type == 'xpu' or self.device.type == 'cuda' def check_device(device): return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) @@ -77,28 +28,21 @@ def check_device(device): def return_xpu(device): return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" -def ipex_no_cuda(orig_func, *args, **kwargs): - torch.cuda.is_available = lambda: False - orig_func(*args, **kwargs) - torch.cuda.is_available = torch.xpu.is_available -original_autocast = torch.autocast -def ipex_autocast(*args, **kwargs): - if len(args) > 0 and args[0] == "cuda": - return original_autocast("xpu", *args[1:], **kwargs) +# Autocast +original_autocast_init = torch.amp.autocast_mode.autocast.__init__ +@wraps(torch.amp.autocast_mode.autocast.__init__) +def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None): + if device_type == "cuda": + return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) else: - return original_autocast(*args, **kwargs) - -original_torch_cat = torch.cat -def torch_cat(tensor, *args, **kwargs): - if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): - return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) - else: - return original_torch_cat(tensor, *args, **kwargs) + return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) +# Latent Antialias CPU Offload: original_interpolate = torch.nn.functional.interpolate +@wraps(torch.nn.functional.interpolate) def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments - if antialias or align_corners is not None: + if antialias or align_corners is not None or mode == 'bicubic': return_device = tensor.device return_dtype = tensor.dtype return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode, @@ -107,90 +51,263 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) -original_linalg_solve = torch.linalg.solve -def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name - if A.device != torch.device("cpu") or B.device != torch.device("cpu"): - return_device = A.device - return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device) + +# Diffusers Float64 (Alchemist GPUs don't support 64-bit floats): +original_from_numpy = torch.from_numpy +@wraps(torch.from_numpy) +def from_numpy(ndarray): + if ndarray.dtype == float: + return original_from_numpy(ndarray.astype('float32')) else: - return original_linalg_solve(A, B, *args, **kwargs) + return
original_from_numpy(ndarray) +original_as_tensor = torch.as_tensor +@wraps(torch.as_tensor) +def as_tensor(data, dtype=None, device=None): + if check_device(device): + device = return_xpu(device) + if isinstance(data, np.ndarray) and data.dtype == float and not ( + (isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)): + return original_as_tensor(data, dtype=torch.float32, device=device) + else: + return original_as_tensor(data, dtype=dtype, device=device) + + +if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None: + original_torch_bmm = torch.bmm + original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention +else: + # 32 bit attention workarounds for Alchemist: + try: + from .attention import torch_bmm_32_bit as original_torch_bmm + from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention + except Exception: # pylint: disable=broad-exception-caught + original_torch_bmm = torch.bmm + original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention + + +# Data Type Errors: +@wraps(torch.bmm) +def torch_bmm(input, mat2, *, out=None): + if input.dtype != mat2.dtype: + mat2 = mat2.to(input.dtype) + return original_torch_bmm(input, mat2, out=out) + +@wraps(torch.nn.functional.scaled_dot_product_attention) +def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): + if query.dtype != key.dtype: + key = key.to(dtype=query.dtype) + if query.dtype != value.dtype: + value = value.to(dtype=query.dtype) + if attn_mask is not None and query.dtype != attn_mask.dtype: + attn_mask = attn_mask.to(dtype=query.dtype) + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + +# A1111 FP16 +original_functional_group_norm = torch.nn.functional.group_norm +@wraps(torch.nn.functional.group_norm) +def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05): + if weight is not None and input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps) + +# A1111 BF16 +original_functional_layer_norm = torch.nn.functional.layer_norm +@wraps(torch.nn.functional.layer_norm) +def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05): + if weight is not None and input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps) + +# Training +original_functional_linear = torch.nn.functional.linear +@wraps(torch.nn.functional.linear) +def functional_linear(input, weight, bias=None): + if input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_linear(input, weight, bias=bias) + +original_functional_conv2d = torch.nn.functional.conv2d +@wraps(torch.nn.functional.conv2d) +def functional_conv2d(input, weight, bias=None, stride=1, 
padding=0, dilation=1, groups=1): + if input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + +# A1111 Embedding BF16 +original_torch_cat = torch.cat +@wraps(torch.cat) +def torch_cat(tensor, *args, **kwargs): + if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): + return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) + else: + return original_torch_cat(tensor, *args, **kwargs) + +# SwinIR BF16: +original_functional_pad = torch.nn.functional.pad +@wraps(torch.nn.functional.pad) +def functional_pad(input, pad, mode='constant', value=None): + if mode == 'reflect' and input.dtype == torch.bfloat16: + return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16) + else: + return original_functional_pad(input, pad, mode=mode, value=value) + + +original_torch_tensor = torch.tensor +@wraps(torch.tensor) +def torch_tensor(data, *args, dtype=None, device=None, **kwargs): + if check_device(device): + device = return_xpu(device) + if not device_supports_fp64: + if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device): + if dtype == torch.float64: + dtype = torch.float32 + elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)): + dtype = torch.float32 + return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs) + +original_Tensor_to = torch.Tensor.to +@wraps(torch.Tensor.to) +def Tensor_to(self, device=None, *args, **kwargs): + if check_device(device): + return original_Tensor_to(self, return_xpu(device), *args, **kwargs) + else: + return original_Tensor_to(self, device, *args, **kwargs) + +original_Tensor_cuda = torch.Tensor.cuda +@wraps(torch.Tensor.cuda) +def Tensor_cuda(self, device=None, *args, **kwargs): + if check_device(device): + return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs) + else: + return original_Tensor_cuda(self, device, *args, **kwargs) + +original_Tensor_pin_memory = torch.Tensor.pin_memory +@wraps(torch.Tensor.pin_memory) +def Tensor_pin_memory(self, device=None, *args, **kwargs): + if device is None: + device = "xpu" + if check_device(device): + return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs) + else: + return original_Tensor_pin_memory(self, device, *args, **kwargs) + +original_UntypedStorage_init = torch.UntypedStorage.__init__ +@wraps(torch.UntypedStorage.__init__) +def UntypedStorage_init(*args, device=None, **kwargs): + if check_device(device): + return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs) + else: + return original_UntypedStorage_init(*args, device=device, **kwargs) + +original_UntypedStorage_cuda = torch.UntypedStorage.cuda +@wraps(torch.UntypedStorage.cuda) +def UntypedStorage_cuda(self, device=None, *args, **kwargs): + if check_device(device): + return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs) + else: + return original_UntypedStorage_cuda(self, device, *args, **kwargs) + +original_torch_empty = torch.empty +@wraps(torch.empty) +def torch_empty(*args, device=None, **kwargs): + if check_device(device): + return 
original_torch_empty(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_empty(*args, device=device, **kwargs) + +original_torch_randn = torch.randn +@wraps(torch.randn) +def torch_randn(*args, device=None, dtype=None, **kwargs): + if dtype == bytes: + dtype = None + if check_device(device): + return original_torch_randn(*args, dtype=dtype, device=return_xpu(device), **kwargs) + else: + return original_torch_randn(*args, dtype=dtype, device=device, **kwargs) + +original_torch_ones = torch.ones +@wraps(torch.ones) +def torch_ones(*args, device=None, **kwargs): + if check_device(device): + return original_torch_ones(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_ones(*args, device=device, **kwargs) + +original_torch_zeros = torch.zeros +@wraps(torch.zeros) +def torch_zeros(*args, device=None, **kwargs): + if check_device(device): + return original_torch_zeros(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_zeros(*args, device=device, **kwargs) + +original_torch_linspace = torch.linspace +@wraps(torch.linspace) +def torch_linspace(*args, device=None, **kwargs): + if check_device(device): + return original_torch_linspace(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_linspace(*args, device=device, **kwargs) + +original_torch_Generator = torch.Generator +@wraps(torch.Generator) +def torch_Generator(device=None): + if check_device(device): + return original_torch_Generator(return_xpu(device)) + else: + return original_torch_Generator(device) + +original_torch_load = torch.load +@wraps(torch.load) +def torch_load(f, map_location=None, *args, **kwargs): + if map_location is None: + map_location = "xpu" + if check_device(map_location): + return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs) + else: + return original_torch_load(f, *args, map_location=map_location, **kwargs) + + +# Hijack Functions: def ipex_hijacks(): - CondFunc('torch.Tensor.to', - lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), - lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) - CondFunc('torch.Tensor.cuda', - lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), - lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) - CondFunc('torch.empty', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.load', - lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs), - lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location)) - CondFunc('torch.randn', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.ones', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.zeros', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.tensor', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args,
device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.linspace', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) + torch.tensor = torch_tensor + torch.Tensor.to = Tensor_to + torch.Tensor.cuda = Tensor_cuda + torch.Tensor.pin_memory = Tensor_pin_memory + torch.UntypedStorage.__init__ = UntypedStorage_init + torch.UntypedStorage.cuda = UntypedStorage_cuda + torch.empty = torch_empty + torch.randn = torch_randn + torch.ones = torch_ones + torch.zeros = torch_zeros + torch.linspace = torch_linspace + torch.Generator = torch_Generator + torch.load = torch_load - CondFunc('torch.Generator', - lambda orig_func, device=None: torch.xpu.Generator(device), - lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") - - CondFunc('torch.batch_norm', - lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, - weight if weight is not None else torch.ones(input.size()[1], device=input.device), - bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), - lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) - CondFunc('torch.instance_norm', - lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, - weight if weight is not None else torch.ones(input.size()[1], device=input.device), - bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), - lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) - - #Functions with dtype errors: - CondFunc('torch.nn.modules.GroupNorm.forward', - lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), - lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - CondFunc('torch.nn.modules.linear.Linear.forward', - lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), - lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - CondFunc('torch.nn.modules.conv.Conv2d.forward', - lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), - lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - CondFunc('torch.nn.functional.layer_norm', - lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: - orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), - lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: - weight is not None and input.dtype != weight.data.dtype) - - #Diffusers Float64 (ARC GPUs doesn't support double or Float64): - if not torch.xpu.has_fp64_dtype(): - CondFunc('torch.from_numpy', - lambda orig_func, ndarray: orig_func(ndarray.astype('float32')), - lambda orig_func, ndarray: ndarray.dtype == float) - - #Broken functions when torch.cuda.is_available is True: - CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__', - lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs), - lambda orig_func, *args, **kwargs: True) - - #Functions that make compile mad with CondFunc: - torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers - torch.nn.DataParallel = DummyDataParallel - torch.autocast = ipex_autocast - torch.cat = torch_cat - torch.linalg.solve = linalg_solve - torch.nn.functional.interpolate 
= interpolate torch.backends.cuda.sdp_kernel = return_null_context + torch.nn.DataParallel = DummyDataParallel + torch.UntypedStorage.is_cuda = is_cuda + torch.amp.autocast_mode.autocast.__init__ = autocast_init + + torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention + torch.nn.functional.group_norm = functional_group_norm + torch.nn.functional.layer_norm = functional_layer_norm + torch.nn.functional.linear = functional_linear + torch.nn.functional.conv2d = functional_conv2d + torch.nn.functional.interpolate = interpolate + torch.nn.functional.pad = functional_pad + + torch.bmm = torch_bmm + torch.cat = torch_cat + if not device_supports_fp64: + torch.from_numpy = from_numpy + torch.as_tensor = as_tensor diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py index 9dce91a7..5717233d 100644 --- a/library/lpw_stable_diffusion.py +++ b/library/lpw_stable_diffusion.py @@ -9,7 +9,7 @@ import numpy as np import PIL.Image import torch from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection import diffusers from diffusers import SchedulerMixin, StableDiffusionPipeline @@ -17,7 +17,6 @@ from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker from diffusers.utils import logging - try: from diffusers.utils import PIL_INTERPOLATION except ImportError: @@ -520,6 +519,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, + image_encoder: CLIPVisionModelWithProjection = None, clip_skip: int = 1, ): super().__init__( @@ -531,32 +531,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, requires_safety_checker=requires_safety_checker, + image_encoder=image_encoder, ) - self.clip_skip = clip_skip + self.custom_clip_skip = clip_skip self.__init__additional__() - # else: - # def __init__( - # self, - # vae: AutoencoderKL, - # text_encoder: CLIPTextModel, - # tokenizer: CLIPTokenizer, - # unet: UNet2DConditionModel, - # scheduler: SchedulerMixin, - # safety_checker: StableDiffusionSafetyChecker, - # feature_extractor: CLIPFeatureExtractor, - # ): - # super().__init__( - # vae=vae, - # text_encoder=text_encoder, - # tokenizer=tokenizer, - # unet=unet, - # scheduler=scheduler, - # safety_checker=safety_checker, - # feature_extractor=feature_extractor, - # ) - # self.__init__additional__() - def __init__additional__(self): if not hasattr(self, "vae_scale_factor"): setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1)) @@ -624,7 +603,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): prompt=prompt, uncond_prompt=negative_prompt if do_classifier_free_guidance else None, max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, + clip_skip=self.custom_clip_skip, ) bs_embed, seq_len, _ = text_embeddings.shape text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) @@ -646,7 +625,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if height % 8 != 0 or 
width % 8 != 0: - print(height, width) + logger.info(f'{height} {width}') raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( diff --git a/library/model_util.py b/library/model_util.py index 00a3c049..be410a02 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -3,19 +3,20 @@ import math import os + import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.device_utils import init_ipex +init_ipex() + import diffusers from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel from safetensors.torch import load_file, save_file from library.original_unet import UNet2DConditionModel +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # DiffUsers版StableDiffusionのモデルパラメータ NUM_TRAIN_TIMESTEPS = 1000 @@ -571,9 +572,9 @@ def convert_ldm_clip_checkpoint_v1(checkpoint): if key.startswith("cond_stage_model.transformer"): text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] - # support checkpoint without position_ids (invalid checkpoint) - if "text_model.embeddings.position_ids" not in text_model_dict: - text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text + # remove position_ids for newer transformers versions, which otherwise raise an error + if "text_model.embeddings.position_ids" in text_model_dict: + text_model_dict.pop("text_model.embeddings.position_ids") return text_model_dict @@ -947,7 +948,7 @@ def convert_vae_state_dict(vae_state_dict): for k, v in new_state_dict.items(): for weight_name in weights_to_convert: if f"mid.attn_1.{weight_name}.weight" in k: - # print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1") + # logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1") new_state_dict[k] = reshape_weight_for_sd(v) return new_state_dict @@ -1005,7 +1006,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt unet = UNet2DConditionModel(**unet_config).to(device) info = unet.load_state_dict(converted_unet_checkpoint) - print("loading u-net:", info) + logger.info(f"loading u-net: {info}") # Convert the VAE model.
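+    # (note: create_vae_diffusers_config() builds the AutoencoderKL constructor kwargs, and
+    # convert_ldm_vae_checkpoint() remaps the checkpoint's "first_stage_model." keys into that layout)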
vae_config = create_vae_diffusers_config() @@ -1013,7 +1014,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt vae = AutoencoderKL(**vae_config).to(device) info = vae.load_state_dict(converted_vae_checkpoint) - print("loading vae:", info) + logger.info(f"loading vae: {info}") # convert text_model if v2: @@ -1047,7 +1048,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt # logging.set_verbosity_error() # don't show annoying warning # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) # logging.set_verbosity_warning() - # print(f"config: {text_model.config}") + # logger.info(f"config: {text_model.config}") cfg = CLIPTextConfig( vocab_size=49408, hidden_size=768, @@ -1070,7 +1071,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt ) text_model = CLIPTextModel._from_config(cfg) info = text_model.load_state_dict(converted_text_encoder_checkpoint) - print("loading text encoder:", info) + logger.info(f"loading text encoder: {info}") return text_model, vae, unet @@ -1145,7 +1146,7 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals # 最後の層などを捏造するか if make_dummy_weights: - print("make dummy weights for resblock.23, text_projection and logit scale.") + logger.info("make dummy weights for resblock.23, text_projection and logit scale.") keys = list(new_sd.keys()) for key in keys: if key.startswith("transformer.resblocks.22."): @@ -1242,8 +1243,13 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod if vae is None: vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") + # original U-Net cannot be saved, so we need to convert it to the Diffusers version + # TODO this consumes a lot of memory + diffusers_unet = diffusers.UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet") + diffusers_unet.load_state_dict(unet.state_dict()) + pipeline = StableDiffusionPipeline( - unet=unet, + unet=diffusers_unet, text_encoder=text_encoder, vae=vae, scheduler=scheduler, @@ -1259,14 +1265,14 @@ VAE_PREFIX = "first_stage_model." 
def load_vae(vae_id, dtype): - print(f"load VAE: {vae_id}") + logger.info(f"load VAE: {vae_id}") if os.path.isdir(vae_id) or not os.path.isfile(vae_id): # Diffusers local/remote try: vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) except EnvironmentError as e: - print(f"exception occurs in loading vae: {e}") - print("retry with subfolder='vae'") + logger.error(f"exception occurred while loading VAE: {e}") + logger.error("retrying with subfolder='vae'") vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) return vae @@ -1307,19 +1313,19 @@ def load_vae(vae_id, dtype): def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64): max_width, max_height = max_reso - max_area = (max_width // divisible) * (max_height // divisible) + max_area = max_width * max_height resos = set() - size = int(math.sqrt(max_area)) * divisible - resos.add((size, size)) + width = int(math.sqrt(max_area) // divisible) * divisible + resos.add((width, width)) - size = min_size - while size <= max_size: - width = size - height = min(max_size, (max_area // (width // divisible)) * divisible) - resos.add((width, height)) - resos.add((height, width)) + width = min_size + while width <= max_size: + height = min(max_size, int((max_area // width) // divisible) * divisible) + if height >= min_size: + resos.add((width, height)) + resos.add((height, width)) # # make additional resos # if width >= height and width - divisible >= min_size: @@ -1329,7 +1335,7 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64) # resos.add((width, height - divisible)) # resos.add((height - divisible, width)) - size += divisible + width += divisible resos = list(resos) resos.sort() @@ -1338,13 +1344,13 @@ if __name__ == "__main__": resos = make_bucket_resolutions((512, 768)) - print(len(resos)) - print(resos) + logger.info(f"{len(resos)}") + logger.info(f"{resos}") aspect_ratios = [w / h for w, h in resos] - print(aspect_ratios) + logger.info(f"{aspect_ratios}") ars = set() for ar in aspect_ratios: if ar in ars: - print("error! duplicate ar:", ar) + logger.error(f"error! duplicate ar: {ar}") ars.add(ar) diff --git a/library/original_unet.py b/library/original_unet.py index 240b8595..e944ff22 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -113,6 +113,10 @@ import torch from torch import nn from torch.nn import functional as F from einops import rearrange +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280) TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0] @@ -361,6 +365,23 @@ def get_timestep_embedding( return emb +# Deep Shrink: this helper is deliberately duplicated here rather than shared, to minimize dependencies.
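+# Illustrative usage (hypothetical shapes, not from this patch): given x with spatial size
+# 32x32 and target with spatial size 64x64, resize_like(x, target) returns x interpolated
+# to 64x64 (bicubic by default); bfloat16 inputs are round-tripped through float32 below,
+# presumably because the interpolation path does not handle bfloat16.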
+def resize_like(x, target, mode="bicubic", align_corners=False): + org_dtype = x.dtype + if org_dtype == torch.bfloat16: + x = x.to(torch.float32) + + if x.shape[-2:] != target.shape[-2:]: + if mode == "nearest": + x = F.interpolate(x, size=target.shape[-2:], mode=mode) + else: + x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners) + + if org_dtype == torch.bfloat16: + x = x.to(org_dtype) + return x + + class SampleOutput: def __init__(self, sample): self.sample = sample @@ -569,6 +590,9 @@ class CrossAttention(nn.Module): self.use_memory_efficient_attention_mem_eff = False self.use_sdpa = False + # Attention processor + self.processor = None + def set_use_memory_efficient_attention(self, xformers, mem_eff): self.use_memory_efficient_attention_xformers = xformers self.use_memory_efficient_attention_mem_eff = mem_eff @@ -590,7 +614,28 @@ class CrossAttention(nn.Module): tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) return tensor - def forward(self, hidden_states, context=None, mask=None): + def set_processor(self, processor): + # mirror the diffusers Attention API: actually store the processor + self.processor = processor + + def get_processor(self): + return self.processor + + def forward(self, hidden_states, context=None, mask=None, **kwargs): + if self.processor is not None: + ( + hidden_states, + context, + mask, + ) = translate_attention_names_from_diffusers( + hidden_states=hidden_states, context=context, mask=mask, **kwargs + ) + # drop the HF-named duplicates so they are not passed to the processor twice + kwargs.pop("encoder_hidden_states", None) + kwargs.pop("attention_mask", None) + return self.processor( + attn=self, + hidden_states=hidden_states, + encoder_hidden_states=context, + attention_mask=mask, + **kwargs + ) if self.use_memory_efficient_attention_xformers: return self.forward_memory_efficient_xformers(hidden_states, context, mask) if self.use_memory_efficient_attention_mem_eff: @@ -703,6 +748,21 @@ class CrossAttention(nn.Module): out = self.to_out[0](out) return out +def translate_attention_names_from_diffusers( + hidden_states: torch.FloatTensor, + context: Optional[torch.FloatTensor] = None, + mask: Optional[torch.FloatTensor] = None, + # HF naming + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None +): + # translate from hugging face diffusers + context = context if context is not None else encoder_hidden_states + + # same for the attention mask + mask = mask if mask is not None else attention_mask + + return hidden_states, context, mask # feedforward class GEGLU(nn.Module): @@ -1130,6 +1190,7 @@ class UpBlock2D(nn.Module): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: @@ -1205,9 +1266,9 @@ class CrossAttnUpBlock2D(nn.Module): for attn in self.attentions: attn.set_use_memory_efficient_attention(xformers, mem_eff) - def set_use_sdpa(self, spda): + def set_use_sdpa(self, sdpa): for attn in self.attentions: - attn.set_use_sdpa(spda) + attn.set_use_sdpa(sdpa) def forward( self, @@ -1221,6 +1282,7 @@ class CrossAttnUpBlock2D(nn.Module): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: @@ -1322,7 +1384,7 @@ class UNet2DConditionModel(nn.Module): ): super().__init__() assert sample_size is not None, "sample_size must be specified" - print( +
logger.info( f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}" ) @@ -1331,7 +1393,7 @@ class UNet2DConditionModel(nn.Module): self.out_channels = OUT_CHANNELS self.sample_size = sample_size - self.prepare_config() + self.prepare_config(sample_size=sample_size) # state_dictの書式が変わるのでmoduleの持ち方は変えられない @@ -1418,8 +1480,8 @@ class UNet2DConditionModel(nn.Module): self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1) # region diffusers compatibility - def prepare_config(self): - self.config = SimpleNamespace() + def prepare_config(self, *args, **kwargs): + self.config = SimpleNamespace(**kwargs) @property def dtype(self) -> torch.dtype: @@ -1456,7 +1518,7 @@ class UNet2DConditionModel(nn.Module): def set_gradient_checkpointing(self, value=False): modules = self.down_blocks + [self.mid_block] + self.up_blocks for module in modules: - print(module.__class__.__name__, module.gradient_checkpointing, "->", value) + logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}") module.gradient_checkpointing = value # endregion @@ -1519,7 +1581,6 @@ class UNet2DConditionModel(nn.Module): # 2. pre-process sample = self.conv_in(sample) - # 3. down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 @@ -1604,3 +1665,255 @@ class UNet2DConditionModel(nn.Module): timesteps = timesteps.expand(sample.shape[0]) return timesteps + + +class InferUNet2DConditionModel: + def __init__(self, original_unet: UNet2DConditionModel): + self.delegate = original_unet + + # override original model's forward method: because forward is not called by `__call__` + # overriding `__call__` is not enough, because nn.Module.forward has a special handling + self.delegate.forward = self.forward + + # override original model's up blocks' forward method + for up_block in self.delegate.up_blocks: + if up_block.__class__.__name__ == "UpBlock2D": + + def resnet_wrapper(func, block): + def forward(*args, **kwargs): + return func(block, *args, **kwargs) + + return forward + + up_block.forward = resnet_wrapper(self.up_block_forward, up_block) + + elif up_block.__class__.__name__ == "CrossAttnUpBlock2D": + + def cross_attn_up_wrapper(func, block): + def forward(*args, **kwargs): + return func(block, *args, **kwargs) + + return forward + + up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block) + + # Deep Shrink + self.ds_depth_1 = None + self.ds_depth_2 = None + self.ds_timesteps_1 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + + # call original model's methods + def __getattr__(self, name): + return getattr(self.delegate, name) + + def __call__(self, *args, **kwargs): + return self.delegate(*args, **kwargs) + + def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): + if ds_depth_1 is None: + logger.info("Deep Shrink is disabled.") + self.ds_depth_1 = None + self.ds_timesteps_1 = None + self.ds_depth_2 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + else: + logger.info( + f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" + ) + self.ds_depth_1 = ds_depth_1 + self.ds_timesteps_1 = ds_timesteps_1 + self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1 + self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 + 
self.ds_ratio = ds_ratio + + def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in _self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # Deep Shrink + if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]: + hidden_states = resize_like(hidden_states, res_hidden_states) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + + if _self.upsamplers is not None: + for upsampler in _self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + def cross_attn_up_block_forward( + self, + _self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + ): + for resnet, attn in zip(_self.resnets, _self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # Deep Shrink + if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]: + hidden_states = resize_like(hidden_states, res_hidden_states) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if _self.upsamplers is not None: + for upsampler in _self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + return_dict: bool = True, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + ) -> Union[Dict, Tuple]: + r""" + current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink. + """ + + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a dict instead of a plain tuple. + + Returns: + `SampleOutput` or `tuple`: + `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + + _self = self.delegate + + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある + # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する + # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い + default_overall_up_factor = 2**_self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + # 64で割り切れないときはupsamplerにサイズを伝える + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + # logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 1. 
time + timesteps = timestep + timesteps = _self.handle_unusual_timesteps(sample, timesteps) # only needed for unusual timestep inputs + + t_emb = _self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + # (casting inside time_proj might be the simpler option) + t_emb = t_emb.to(dtype=_self.dtype) + emb = _self.time_embedding(t_emb) + + # 2. pre-process + sample = _self.conv_in(sample) + + down_block_res_samples = (sample,) + for depth, downsample_block in enumerate(_self.down_blocks): + # Deep Shrink + if self.ds_depth_1 is not None: + if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or ( + self.ds_depth_2 is not None + and depth == self.ds_depth_2 + and timesteps[0] < self.ds_timesteps_1 + and timesteps[0] >= self.ds_timesteps_2 + ): + org_dtype = sample.dtype + if org_dtype == torch.bfloat16: + sample = sample.to(torch.float32) + sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype) + + # the down blocks could be made to always take encoder_hidden_states in forward, + # but this is perhaps easier to follow + if downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # add the ControlNet outputs to the skip connections + if down_block_additional_residuals is not None: + down_block_res_samples = list(down_block_res_samples) + for i in range(len(down_block_res_samples)): + down_block_res_samples[i] += down_block_additional_residuals[i] + down_block_res_samples = tuple(down_block_res_samples) + + # 4. mid + sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + + # add the ControlNet mid-block output + if mid_block_additional_residual is not None: + sample += mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(_self.up_blocks): + is_final_block = i == len(_self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection + + # if we have not reached the final block and need to forward the upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. 
post-process + sample = _self.conv_norm_out(sample) + sample = _self.conv_act(sample) + sample = _self.conv_out(sample) + + if not return_dict: + return (sample,) + + return SampleOutput(sample=sample) diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index 472686ba..a63bd82e 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -5,6 +5,10 @@ from io import BytesIO import os from typing import List, Optional, Tuple, Union import safetensors +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) r""" # Metadata Example @@ -231,7 +235,7 @@ def build_metadata( # # assert all values are filled # assert all([v is not None for v in metadata.values()]), metadata if not all([v is not None for v in metadata.values()]): - print(f"Internal error: some metadata values are None: {metadata}") + logger.error(f"Internal error: some metadata values are None: {metadata}") return metadata diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index e03ee405..03b18256 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -923,7 +923,11 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: if up1 is not None: uncond_pool = up1 - dtype = self.unet.dtype + unet_dtype = self.unet.dtype + dtype = unet_dtype + if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8 + dtype = torch.float16 + self.unet.to(dtype) # 4. Preprocess image and mask if isinstance(image, PIL.Image.Image): @@ -1028,6 +1032,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: if is_cancelled_callback is not None and is_cancelled_callback(): return None + self.unet.to(unet_dtype) return latents def latents_to_image(self, latents): diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 2f0154ca..f03f1bae 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -7,7 +7,10 @@ from typing import List from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel from library import model_util from library import sdxl_original_unet - +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) VAE_SCALE_FACTOR = 0.13025 MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0" @@ -100,7 +103,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length): key = key.replace(".ln_final", ".final_layer_norm") # ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids elif ".embeddings.position_ids" in key: - key = None # remove this key: make position_ids by ourselves + key = None # remove this key: position_ids is not used in newer transformers return key keys = list(checkpoint.keys()) @@ -126,13 +129,15 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length): new_sd[key_pfx + "k_proj" + key_suffix] = values[1] new_sd[key_pfx + "v_proj" + key_suffix] = values[2] - # original SD にはないので、position_idsを追加 - position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) - new_sd["text_model.embeddings.position_ids"] = position_ids - # logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None) + # temporary workaround for text_projection.weight.weight for Playground-v2 + if "text_projection.weight.weight" in new_sd: + logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight") + new_sd["text_projection.weight"] = 
new_sd["text_projection.weight.weight"] + del new_sd["text_projection.weight.weight"] + return new_sd, logit_scale @@ -184,20 +189,20 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty checkpoint = None # U-Net - print("building U-Net") + logger.info("building U-Net") with init_empty_weights(): unet = sdxl_original_unet.SdxlUNet2DConditionModel() - print("loading U-Net from checkpoint") + logger.info("loading U-Net from checkpoint") unet_sd = {} for k in list(state_dict.keys()): if k.startswith("model.diffusion_model."): unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype) - print("U-Net: ", info) + logger.info(f"U-Net: {info}") # Text Encoders - print("building text encoders") + logger.info("building text encoders") # Text Encoder 1 is same to Stability AI's SDXL text_model1_cfg = CLIPTextConfig( @@ -250,7 +255,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty with init_empty_weights(): text_model2 = CLIPTextModelWithProjection(text_model2_cfg) - print("loading text encoders from checkpoint") + logger.info("loading text encoders from checkpoint") te1_sd = {} te2_sd = {} for k in list(state_dict.keys()): @@ -258,28 +263,28 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k) elif k.startswith("conditioner.embedders.1.model."): te2_sd[k] = state_dict.pop(k) - - # 一部のposition_idsがないモデルへの対応 / add position_ids for some models - if "text_model.embeddings.position_ids" not in te1_sd: - te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) + + # 最新の transformers では position_ids を含むとエラーになるので削除 / remove position_ids for latest transformers + if "text_model.embeddings.position_ids" in te1_sd: + te1_sd.pop("text_model.embeddings.position_ids") info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32 - print("text encoder 1:", info1) + logger.info(f"text encoder 1: {info1}") converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77) info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32 - print("text encoder 2:", info2) + logger.info(f"text encoder 2: {info2}") # prepare vae - print("building VAE") + logger.info("building VAE") vae_config = model_util.create_vae_diffusers_config() with init_empty_weights(): vae = AutoencoderKL(**vae_config) - print("loading VAE from checkpoint") + logger.info("loading VAE from checkpoint") converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config) info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype) - print("VAE:", info) + logger.info(f"VAE: {info}") ckpt_info = (epoch, global_step) if epoch is not None else None return text_model1, text_model2, vae, unet, logit_scale, ckpt_info diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 26a0af31..17c345a8 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -24,13 +24,18 @@ import math from types import SimpleNamespace -from typing import Optional +from typing import Any, Optional import torch import torch.utils.checkpoint from torch import nn from torch.nn import functional as F from einops import rearrange +from .utils import setup_logging +setup_logging() +import logging + +logger = 
logging.getLogger(__name__) IN_CHANNELS: int = 4 OUT_CHANNELS: int = 4 @@ -266,6 +271,23 @@ def get_timestep_embedding( return emb +# Deep Shrink: this function is intentionally not shared with other modules, to minimize dependencies. +def resize_like(x, target, mode="bicubic", align_corners=False): + org_dtype = x.dtype + if org_dtype == torch.bfloat16: + x = x.to(torch.float32) + + if x.shape[-2:] != target.shape[-2:]: + if mode == "nearest": + x = F.interpolate(x, size=target.shape[-2:], mode=mode) + else: + x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners) + + if org_dtype == torch.bfloat16: + x = x.to(org_dtype) + return x + + class GroupNorm32(nn.GroupNorm): def forward(self, x): if self.weight.dtype != torch.float32: @@ -315,7 +337,7 @@ class ResnetBlock2D(nn.Module): def forward(self, x, emb): if self.training and self.gradient_checkpointing: - # print("ResnetBlock2D: gradient_checkpointing") + # logger.info("ResnetBlock2D: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -349,7 +371,7 @@ class Downsample2D(nn.Module): def forward(self, hidden_states): if self.training and self.gradient_checkpointing: - # print("Downsample2D: gradient_checkpointing") + # logger.info("Downsample2D: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -636,7 +658,7 @@ class BasicTransformerBlock(nn.Module): def forward(self, hidden_states, context=None, timestep=None): if self.training and self.gradient_checkpointing: - # print("BasicTransformerBlock: checkpointing") + # logger.info("BasicTransformerBlock: checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -779,7 +801,7 @@ class Upsample2D(nn.Module): def forward(self, hidden_states, output_size=None): if self.training and self.gradient_checkpointing: - # print("Upsample2D: gradient_checkpointing") + # logger.info("Upsample2D: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -1029,7 +1051,7 @@ class SdxlUNet2DConditionModel(nn.Module): for block in blocks: for module in block: if hasattr(module, "set_use_memory_efficient_attention"): - # print(module.__class__.__name__) + # logger.info(module.__class__.__name__) module.set_use_memory_efficient_attention(xformers, mem_eff) def set_use_sdpa(self, sdpa: bool) -> None: @@ -1044,7 +1066,7 @@ class SdxlUNet2DConditionModel(nn.Module): for block in blocks: for module in block.modules(): if hasattr(module, "gradient_checkpointing"): - # print(module.__class__.__name__, module.gradient_checkpointing, "->", value) + # logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}") module.gradient_checkpointing = value # endregion @@ -1054,7 +1076,7 @@ class SdxlUNet2DConditionModel(nn.Module): timesteps = timesteps.expand(x.shape[0]) hs = [] - t_emb = get_timestep_embedding(timesteps, self.model_channels) # , repeat_only=False) + t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) # , repeat_only=False) t_emb = t_emb.to(x.dtype) emb = self.time_embed(t_emb) @@ -1066,7 +1088,7 @@ class SdxlUNet2DConditionModel(nn.Module): def call_module(module, h, emb, context): x = h for layer in module: - # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None) + # logger.info(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None) if isinstance(layer, ResnetBlock2D): x = layer(x, emb) elif isinstance(layer, Transformer2DModel): @@
-1077,6 +1099,7 @@ class SdxlUNet2DConditionModel(nn.Module): # h = x.type(self.dtype) h = x + for module in self.input_blocks: h = call_module(module, h, emb, context) hs.append(h) @@ -1093,10 +1116,125 @@ return h + +class InferSdxlUNet2DConditionModel: + def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs): + self.delegate = original_unet + + # replace the original model's forward method, because callers may invoke `forward` directly; + # overriding `__call__` on this wrapper alone is not enough, since nn.Module gives `forward` special handling + self.delegate.forward = self.forward + + # Deep Shrink + self.ds_depth_1 = None + self.ds_depth_2 = None + self.ds_timesteps_1 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + + # call original model's methods + def __getattr__(self, name): + return getattr(self.delegate, name) + + def __call__(self, *args, **kwargs): + return self.delegate(*args, **kwargs) + + def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): + if ds_depth_1 is None: + logger.info("Deep Shrink is disabled.") + self.ds_depth_1 = None + self.ds_timesteps_1 = None + self.ds_depth_2 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + else: + logger.info( + f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" + ) + self.ds_depth_1 = ds_depth_1 + self.ds_timesteps_1 = ds_timesteps_1 + self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1 + self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 + self.ds_ratio = ds_ratio + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + r""" + current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
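
Note that `InferSdxlUNet2DConditionModel` is a plain object, not an `nn.Module`: it wraps the trained U-Net and re-points its `forward`. The idiom in isolation, as a sketch (names illustrative; the assumption is that other code holds references to the wrapped module and may call its `forward` directly):

```python
class ForwardingWrapper:
    def __init__(self, inner):
        self.delegate = inner
        # re-point the wrapped module's own forward, so code that kept a
        # reference to `inner` (e.g. a pipeline calling inner(...) or
        # inner.forward(...)) also goes through the wrapper's implementation
        inner.forward = self.forward

    def __getattr__(self, name):
        # only invoked when normal lookup fails -> fall through to the module
        return getattr(self.delegate, name)

    def __call__(self, *args, **kwargs):
        return self.delegate(*args, **kwargs)

    def forward(self, *args, **kwargs):
        ...  # custom logic, reusing self.delegate's submodules
```
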
+ """ + _self = self.delegate + + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + hs = [] + t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0) # , repeat_only=False) + t_emb = t_emb.to(x.dtype) + emb = _self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + # assert x.dtype == _self.dtype + emb = emb + _self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None) + if isinstance(layer, ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + # h = x.type(self.dtype) + h = x + + for depth, module in enumerate(_self.input_blocks): + # Deep Shrink + if self.ds_depth_1 is not None: + if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or ( + self.ds_depth_2 is not None + and depth == self.ds_depth_2 + and timesteps[0] < self.ds_timesteps_1 + and timesteps[0] >= self.ds_timesteps_2 + ): + # print("downsample", h.shape, self.ds_ratio) + org_dtype = h.dtype + if org_dtype == torch.bfloat16: + h = h.to(torch.float32) + h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype) + + h = call_module(module, h, emb, context) + hs.append(h) + + h = call_module(_self.middle_block, h, emb, context) + + for module in _self.output_blocks: + # Deep Shrink + if self.ds_depth_1 is not None: + if hs[-1].shape[-2:] != h.shape[-2:]: + # print("upsample", h.shape, hs[-1].shape) + h = resize_like(h, hs[-1]) + + h = torch.cat([h, hs.pop()], dim=1) + h = call_module(module, h, emb, context) + + # Deep Shrink: in case of depth 0 + if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]: + # print("upsample", h.shape, x.shape) + h = resize_like(h, x) + + h = h.type(x.dtype) + h = call_module(_self.out, h, emb, context) + + return h + + if __name__ == "__main__": import time - print("create unet") + logger.info("create unet") unet = SdxlUNet2DConditionModel() unet.to("cuda") @@ -1105,7 +1243,7 @@ if __name__ == "__main__": unet.train() # 使用メモリ量確認用の疑似学習ループ - print("preparing optimizer") + logger.info("preparing optimizer") # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working @@ -1120,12 +1258,12 @@ if __name__ == "__main__": scaler = torch.cuda.amp.GradScaler(enabled=True) - print("start training") + logger.info("start training") steps = 10 batch_size = 1 for step in range(steps): - print(f"step {step}") + logger.info(f"step {step}") if step == 1: time_start = time.perf_counter() @@ -1145,4 +1283,4 @@ if __name__ == "__main__": optimizer.zero_grad(set_to_none=True) time_end = time.perf_counter() - print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") + logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index f637d993..a29013e3 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -1,14 +1,21 @@ import argparse -import gc import math import os from typing import Optional + import torch +from library.device_utils import init_ipex, clean_memory_on_device +init_ipex() + from accelerate import init_empty_weights from tqdm import tqdm from 
transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) TOKENIZER1_PATH = "openai/clip-vit-large-patch14" TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" @@ -17,11 +24,10 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" def load_target_model(args, accelerator, model_version: str, weight_dtype): - # load models for each process model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: - print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") ( load_stable_diffusion_format, @@ -47,12 +53,9 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): unet.to(accelerator.device) vae.to(accelerator.device) - gc.collect() - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() - text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet]) - return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info @@ -64,7 +67,7 @@ def _load_target_model( load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers if load_stable_diffusion_format: - print(f"load StableDiffusion checkpoint: {name_or_path}") + logger.info(f"load StableDiffusion checkpoint: {name_or_path}") ( text_encoder1, text_encoder2, @@ -78,7 +81,7 @@ def _load_target_model( from diffusers import StableDiffusionXLPipeline variant = "fp16" if weight_dtype == torch.float16 else None - print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") + logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") try: try: pipe = StableDiffusionXLPipeline.from_pretrained( @@ -86,12 +89,12 @@ def _load_target_model( ) except EnvironmentError as ex: if variant is not None: - print("try to load fp32 model") + logger.info("try to load fp32 model") pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None) else: raise ex except EnvironmentError as ex: - print( + logger.error( f"model is not found as a file or in Hugging Face, perhaps file name is wrong? 
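
`load_target_model` above loads the checkpoint one rank at a time: each process takes its turn while the others wait at a barrier, which keeps peak host memory and disk thrash down on multi-GPU machines. The bare pattern, as a sketch (`load_checkpoint` is a placeholder; `clean_memory_on_device` is the helper introduced elsewhere in this diff):

```python
from accelerate import Accelerator
from library.device_utils import clean_memory_on_device

accelerator = Accelerator()
model = None
for pi in range(accelerator.state.num_processes):
    if pi == accelerator.state.local_process_index:
        model = load_checkpoint()                    # placeholder for the real loader
        clean_memory_on_device(accelerator.device)   # gc + free cached device memory
    accelerator.wait_for_everyone()                  # barrier: next rank starts loading
```
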
/ 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" ) raise ex @@ -114,7 +117,7 @@ def _load_target_model( with init_empty_weights(): unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype) - print("U-Net converted to original U-Net") + logger.info("U-Net converted to original U-Net") logit_scale = None ckpt_info = None @@ -122,13 +125,13 @@ def _load_target_model( # VAEを読み込む if vae_path is not None: vae = model_util.load_vae(vae_path, weight_dtype) - print("additional VAE loaded") + logger.info("additional VAE loaded") return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info def load_tokenizers(args: argparse.Namespace): - print("prepare tokenizers") + logger.info("prepare tokenizers") original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH] tokeniers = [] @@ -137,14 +140,14 @@ def load_tokenizers(args: argparse.Namespace): if args.tokenizer_cache_dir: local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) if os.path.exists(local_tokenizer_path): - print(f"load tokenizer from cache: {local_tokenizer_path}") + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) if tokenizer is None: tokenizer = CLIPTokenizer.from_pretrained(original_path) if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - print(f"save Tokenizer to cache: {local_tokenizer_path}") + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") tokenizer.save_pretrained(local_tokenizer_path) if i == 1: @@ -153,7 +156,7 @@ def load_tokenizers(args: argparse.Namespace): tokeniers.append(tokenizer) if hasattr(args, "max_token_length") and args.max_token_length is not None: - print(f"update token length: {args.max_token_length}") + logger.info(f"update token length: {args.max_token_length}") return tokeniers @@ -334,23 +337,23 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser): def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" if args.v_parameterization: - print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") + logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") if args.clip_skip is not None: - print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") + logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") # if args.multires_noise_iterations: - # print( + # logger.info( # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" # ) # else: # if args.noise_offset is None: # args.noise_offset = DEFAULT_NOISE_OFFSET # elif args.noise_offset != DEFAULT_NOISE_OFFSET: - # print( + # logger.info( # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" # ) - # print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") + # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") assert ( not hasattr(args, "weighted_captions") or not 
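
The tokenizer handling in `load_tokenizers` above is a simple read-through cache keyed by the Hugging Face repo id. Extracted as a sketch (the helper name is mine):

```python
import os
from transformers import CLIPTokenizer

def load_tokenizer_cached(name: str, cache_dir: str = None) -> CLIPTokenizer:
    local = os.path.join(cache_dir, name.replace("/", "_")) if cache_dir else None
    if local and os.path.exists(local):
        return CLIPTokenizer.from_pretrained(local)   # offline-friendly path
    tokenizer = CLIPTokenizer.from_pretrained(name)   # falls back to the hub
    if local and not os.path.exists(local):
        tokenizer.save_pretrained(local)              # populate the cache
    return tokenizer
```
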
args.weighted_captions @@ -359,7 +362,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin if supportTextEncoderCaching: if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: args.cache_text_encoder_outputs = True - print( + logger.warning( "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / " + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました" ) diff --git a/library/slicing_vae.py b/library/slicing_vae.py index 31b2bd0a..ea765342 100644 --- a/library/slicing_vae.py +++ b/library/slicing_vae.py @@ -26,7 +26,10 @@ from diffusers.models.modeling_utils import ModelMixin from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution from diffusers.models.autoencoder_kl import AutoencoderKLOutput - +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def slice_h(x, num_slices): # slice with pad 1 both sides: to eliminate side effect of padding of conv2d @@ -62,7 +65,7 @@ def cat_h(sliced): return x -def resblock_forward(_self, num_slices, input_tensor, temb): +def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs): assert _self.upsample is None and _self.downsample is None assert _self.norm1.num_groups == _self.norm2.num_groups assert temb is None @@ -89,7 +92,7 @@ def resblock_forward(_self, num_slices, input_tensor, temb): # sliced_tensor = torch.chunk(x, num_div, dim=1) # sliced_weight = torch.chunk(norm.weight, num_div, dim=0) # sliced_bias = torch.chunk(norm.bias, num_div, dim=0) - # print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape) + # logger.info(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape) # normed_tensor = [] # for i in range(num_div): # n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps) @@ -243,7 +246,7 @@ class SlicingEncoder(nn.Module): self.num_slices = num_slices div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす - # print(f"initial divisor: {div}") + # logger.info(f"initial divisor: {div}") if div >= 2: div = int(div) for resnet in self.mid_block.resnets: @@ -253,11 +256,11 @@ class SlicingEncoder(nn.Module): for i, down_block in enumerate(self.down_blocks[::-1]): if div >= 2: div = int(div) - # print(f"down block: {i} divisor: {div}") + # logger.info(f"down block: {i} divisor: {div}") for resnet in down_block.resnets: resnet.forward = wrapper(resblock_forward, resnet, div) if down_block.downsamplers is not None: - # print("has downsample") + # logger.info("has downsample") for downsample in down_block.downsamplers: downsample.forward = wrapper(self.downsample_forward, downsample, div * 2) div *= 2 @@ -307,7 +310,7 @@ class SlicingEncoder(nn.Module): def downsample_forward(self, _self, num_slices, hidden_states): assert hidden_states.shape[1] == _self.channels assert _self.use_conv and _self.padding == 0 - print("downsample forward", num_slices, hidden_states.shape) + logger.info(f"downsample forward {num_slices} {hidden_states.shape}") org_device = hidden_states.device cpu_device = torch.device("cpu") @@ -350,7 +353,7 @@ class SlicingEncoder(nn.Module): hidden_states = torch.cat([hidden_states, x], dim=2) hidden_states = hidden_states.to(org_device) - # print("downsample forward done", hidden_states.shape) + # 
logger.info(f"downsample forward done {hidden_states.shape}") return hidden_states @@ -426,7 +429,7 @@ class SlicingDecoder(nn.Module): self.num_slices = num_slices div = num_slices / (2 ** (len(self.up_blocks) - 1)) - print(f"initial divisor: {div}") + logger.info(f"initial divisor: {div}") if div >= 2: div = int(div) for resnet in self.mid_block.resnets: @@ -436,11 +439,11 @@ class SlicingDecoder(nn.Module): for i, up_block in enumerate(self.up_blocks): if div >= 2: div = int(div) - # print(f"up block: {i} divisor: {div}") + # logger.info(f"up block: {i} divisor: {div}") for resnet in up_block.resnets: resnet.forward = wrapper(resblock_forward, resnet, div) if up_block.upsamplers is not None: - # print("has upsample") + # logger.info("has upsample") for upsample in up_block.upsamplers: upsample.forward = wrapper(self.upsample_forward, upsample, div * 2) div *= 2 @@ -528,7 +531,7 @@ class SlicingDecoder(nn.Module): del x hidden_states = torch.cat(sliced, dim=2) - # print("us hidden_states", hidden_states.shape) + # logger.info(f"us hidden_states {hidden_states.shape}") del sliced hidden_states = hidden_states.to(org_device) diff --git a/library/train_util.py b/library/train_util.py index cc9ac455..0fec565d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6,6 +6,7 @@ import asyncio import datetime import importlib import json +import logging import pathlib import re import shutil @@ -19,8 +20,7 @@ from typing import ( Tuple, Union, ) -from accelerate import Accelerator, InitProcessGroupKwargs -import gc +from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState import glob import math import os @@ -31,7 +31,12 @@ from io import BytesIO import toml from tqdm import tqdm + import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torchvision import transforms @@ -58,13 +63,20 @@ from library.original_unet import UNet2DConditionModel from huggingface_hub import hf_hub_download import numpy as np from PIL import Image +import imagesize import cv2 import safetensors.torch from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec +import library.deepspeed_utils as deepspeed_utils +from library.utils import setup_logging +setup_logging() +import logging + +logger = logging.getLogger(__name__) # from library.attention_processors import FlashAttnProcessor # from library.hypernetwork import replace_attentions_for_hypernetwork from library.original_unet import UNet2DConditionModel @@ -73,6 +85,8 @@ from library.original_unet import UNet2DConditionModel TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ +HIGH_VRAM = False + # checkpointファイル名 EPOCH_STATE_NAME = "{}-{:06d}-state" EPOCH_FILE_NAME = "{}-{:06d}" @@ -211,7 +225,7 @@ class BucketManager: self.reso_to_id[reso] = bucket_id self.resos.append(reso) self.buckets.append([]) - # print(reso, bucket_id, len(self.buckets)) + # logger.info(reso, bucket_id, len(self.buckets)) def round_to_steps(self, x): x = int(x + 0.5) @@ -237,7 +251,7 @@ class BucketManager: scale = reso[0] / image_width resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5)) - # print("use 
predef", image_width, image_height, reso, resized_size) + # logger.info(f"use predef, {image_width}, {image_height}, {reso}, {resized_size}") else: # 縮小のみを行う if image_width * image_height > self.max_area: @@ -256,21 +270,21 @@ class BucketManager: b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio) ar_height_rounded = b_width_in_hr / b_height_rounded - # print(b_width_rounded, b_height_in_wr, ar_width_rounded) - # print(b_width_in_hr, b_height_rounded, ar_height_rounded) + # logger.info(b_width_rounded, b_height_in_wr, ar_width_rounded) + # logger.info(b_width_in_hr, b_height_rounded, ar_height_rounded) if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio): resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + 0.5)) else: resized_size = (int(b_height_rounded * aspect_ratio + 0.5), b_height_rounded) - # print(resized_size) + # logger.info(resized_size) else: resized_size = (image_width, image_height) # リサイズは不要 # 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする) bucket_width = resized_size[0] - resized_size[0] % self.reso_steps bucket_height = resized_size[1] - resized_size[1] % self.reso_steps - # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height) + # logger.info(f"use arbitrary {image_width}, {image_height}, {resized_size}, {bucket_width}, {bucket_height}") reso = (bucket_width, bucket_height) @@ -349,7 +363,11 @@ class BaseSubset: image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, + caption_separator: str, keep_tokens: int, + keep_tokens_separator: str, + secondary_separator: Optional[str], + enable_wildcard: bool, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], @@ -365,7 +383,11 @@ class BaseSubset: self.image_dir = image_dir self.num_repeats = num_repeats self.shuffle_caption = shuffle_caption + self.caption_separator = caption_separator self.keep_tokens = keep_tokens + self.keep_tokens_separator = keep_tokens_separator + self.secondary_separator = secondary_separator + self.enable_wildcard = enable_wildcard self.color_aug = color_aug self.flip_aug = flip_aug self.face_crop_aug_range = face_crop_aug_range @@ -389,9 +411,14 @@ class DreamBoothSubset(BaseSubset): is_reg: bool, class_tokens: Optional[str], caption_extension: str, + cache_info: bool, num_repeats, shuffle_caption, + caption_separator: str, keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, color_aug, flip_aug, face_crop_aug_range, @@ -410,7 +437,11 @@ class DreamBoothSubset(BaseSubset): image_dir, num_repeats, shuffle_caption, + caption_separator, keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, color_aug, flip_aug, face_crop_aug_range, @@ -429,6 +460,7 @@ class DreamBoothSubset(BaseSubset): self.caption_extension = caption_extension if self.caption_extension and not self.caption_extension.startswith("."): self.caption_extension = "." 
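
Stripped of the aspect-ratio rounding heuristics, the bucket selection above does two things: shrink (never enlarge) the image so its area fits `max_area`, then crop the result down to multiples of `reso_steps`. A worked sketch with assumed values:

```python
import math

def select_bucket(w: int, h: int, max_area: int = 1024 * 1024, steps: int = 64):
    if w * h > max_area:                   # downscale only, preserving aspect ratio
        scale = math.sqrt(max_area / (w * h))
        w, h = int(w * scale + 0.5), int(h * scale + 0.5)
    return (w - w % steps, h - h % steps)  # crop (never pad) to the bucket grid

print(select_bucket(1920, 1080))  # -> (1344, 768)
```
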
+ self.caption_extension + self.cache_info = cache_info def __eq__(self, other) -> bool: if not isinstance(other, DreamBoothSubset): @@ -443,7 +475,11 @@ class FineTuningSubset(BaseSubset): metadata_file: str, num_repeats, shuffle_caption, + caption_separator, keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, color_aug, flip_aug, face_crop_aug_range, @@ -462,7 +498,11 @@ class FineTuningSubset(BaseSubset): image_dir, num_repeats, shuffle_caption, + caption_separator, keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, color_aug, flip_aug, face_crop_aug_range, @@ -490,9 +530,14 @@ class ControlNetSubset(BaseSubset): image_dir: str, conditioning_data_dir: str, caption_extension: str, + cache_info: bool, num_repeats, shuffle_caption, + caption_separator, keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, color_aug, flip_aug, face_crop_aug_range, @@ -511,7 +556,11 @@ class ControlNetSubset(BaseSubset): image_dir, num_repeats, shuffle_caption, + caption_separator, keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, color_aug, flip_aug, face_crop_aug_range, @@ -529,6 +578,7 @@ class ControlNetSubset(BaseSubset): self.caption_extension = caption_extension if self.caption_extension and not self.caption_extension.startswith("."): self.caption_extension = "." + self.caption_extension + self.cache_info = cache_info def __eq__(self, other) -> bool: if not isinstance(other, ControlNetSubset): @@ -542,6 +592,7 @@ class BaseDataset(torch.utils.data.Dataset): tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]], max_token_length: int, resolution: Optional[Tuple[int, int]], + network_multiplier: float, debug_dataset: bool, ) -> None: super().__init__() @@ -551,6 +602,7 @@ class BaseDataset(torch.utils.data.Dataset): self.max_token_length = max_token_length # width/height is used when enable_bucket==False self.width, self.height = (None, None) if resolution is None else resolution + self.network_multiplier = network_multiplier self.debug_dataset = debug_dataset self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = [] @@ -645,16 +697,67 @@ class BaseDataset(torch.utils.data.Dataset): if is_drop_out: caption = "" else: + # process wildcards + if subset.enable_wildcard: + # if caption is multiline, random choice one line + if "\n" in caption: + caption = random.choice(caption.split("\n")) + + # wildcard is like '{aaa|bbb|ccc...}' + # escape the curly braces like {{ or }} + replacer1 = "⦅" + replacer2 = "⦆" + while replacer1 in caption or replacer2 in caption: + replacer1 += "⦅" + replacer2 += "⦆" + + caption = caption.replace("{{", replacer1).replace("}}", replacer2) + + # replace the wildcard + def replace_wildcard(match): + return random.choice(match.group(1).split("|")) + + caption = re.sub(r"\{([^}]+)\}", replace_wildcard, caption) + + # unescape the curly braces + caption = caption.replace(replacer1, "{").replace(replacer2, "}") + else: + # if caption is multiline, use the first line + caption = caption.split("\n")[0] + if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: - tokens = [t.strip() for t in caption.strip().split(",")] + fixed_tokens = [] + flex_tokens = [] + fixed_suffix_tokens = [] + if ( + hasattr(subset, "keep_tokens_separator") + and subset.keep_tokens_separator + and subset.keep_tokens_separator in caption + ): + fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1) + if subset.keep_tokens_separator in 
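
The wildcard handling above is compact but dense. As a standalone sketch, the expansion makes three passes: pick one line from a multiline caption, hide escaped `{{`/`}}` behind throwaway tokens that are guaranteed absent from the caption, then resolve each `{a|b|c}` group at random:

```python
import random
import re

def expand_wildcards(caption: str) -> str:
    if "\n" in caption:                  # multiline caption: choose one line at random
        caption = random.choice(caption.split("\n"))
    r1, r2 = "⦅", "⦆"                    # grow the tokens until absent from the caption
    while r1 in caption or r2 in caption:
        r1 += "⦅"
        r2 += "⦆"
    caption = caption.replace("{{", r1).replace("}}", r2)   # protect escaped braces
    caption = re.sub(r"\{([^}]+)\}",
                     lambda m: random.choice(m.group(1).split("|")), caption)
    return caption.replace(r1, "{").replace(r2, "}")        # restore literal braces

print(expand_wildcards("1girl, {red|blue|silver} hair, {{not a wildcard}}"))
```
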
flex_part: + flex_part, fixed_suffix_part = flex_part.split(subset.keep_tokens_separator, 1) + fixed_suffix_tokens = [t.strip() for t in fixed_suffix_part.split(subset.caption_separator) if t.strip()] + + fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()] + flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()] + else: + tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)] + flex_tokens = tokens[:] + if subset.keep_tokens > 0: + fixed_tokens = flex_tokens[: subset.keep_tokens] + flex_tokens = tokens[subset.keep_tokens :] + if subset.token_warmup_step < 1: # 初回に上書きする subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) if subset.token_warmup_step and self.current_step < subset.token_warmup_step: tokens_len = ( - math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step))) + math.floor( + (self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step)) + ) + subset.token_warmup_min ) - tokens = tokens[:tokens_len] + flex_tokens = flex_tokens[:tokens_len] def dropout_tags(tokens): if subset.caption_tag_dropout_rate <= 0: @@ -665,18 +768,16 @@ class BaseDataset(torch.utils.data.Dataset): l.append(token) return l - fixed_tokens = [] - flex_tokens = tokens[:] - if subset.keep_tokens > 0: - fixed_tokens = flex_tokens[: subset.keep_tokens] - flex_tokens = tokens[subset.keep_tokens :] - if subset.shuffle_caption: random.shuffle(flex_tokens) flex_tokens = dropout_tags(flex_tokens) - caption = ", ".join(fixed_tokens + flex_tokens) + caption = ", ".join(fixed_tokens + flex_tokens + fixed_suffix_tokens) + + # process secondary separator + if subset.secondary_separator: + caption = caption.replace(subset.secondary_separator, subset.caption_separator) # textual inversion対応 for str_from, str_to in self.replacements.items(): @@ -749,15 +850,15 @@ class BaseDataset(torch.utils.data.Dataset): bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) min_size and max_size are ignored when enable_bucket is False """ - print("loading image sizes.") + logger.info("loading image sizes.") for info in tqdm(self.image_data.values()): if info.image_size is None: info.image_size = self.get_image_size(info.absolute_path) if self.enable_bucket: - print("make buckets") + logger.info("make buckets") else: - print("prepare dataset") + logger.info("prepare dataset") # bucketを作成し、画像をbucketに振り分ける if self.enable_bucket: @@ -772,7 +873,7 @@ class BaseDataset(torch.utils.data.Dataset): if not self.bucket_no_upscale: self.bucket_manager.make_buckets() else: - print( + logger.warning( "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます" ) @@ -783,7 +884,7 @@ class BaseDataset(torch.utils.data.Dataset): image_width, image_height ) - # print(image_info.image_key, image_info.bucket_reso) + # logger.info(image_info.image_key, image_info.bucket_reso) img_ar_errors.append(abs(ar_error)) self.bucket_manager.sort() @@ -801,17 +902,17 @@ class BaseDataset(torch.utils.data.Dataset): # bucket情報を表示、格納する if self.enable_bucket: self.bucket_info = {"buckets": {}} - print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") + logger.info("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") for i, (reso, bucket) in 
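
With `keep_tokens_separator`, a caption can carry up to three regions: a fixed prefix, a shuffle/dropout-eligible middle, and a fixed suffix. The split logic above, as a sketch (`"|||"` is an assumed separator for illustration, not necessarily the project default):

```python
def split_caption(caption: str, sep: str = "|||", tag_sep: str = ","):
    tokens = lambda s: [t.strip() for t in s.split(tag_sep) if t.strip()]
    fixed, flex, suffix = [], [], []
    if sep and sep in caption:
        head, rest = caption.split(sep, 1)
        if sep in rest:                   # optional second separator -> fixed suffix
            rest, tail = rest.split(sep, 1)
            suffix = tokens(tail)
        fixed, flex = tokens(head), tokens(rest)
    else:
        flex = tokens(caption)
    return fixed, flex, suffix            # only `flex` gets shuffled / dropped

print(split_caption("1girl, solo ||| red hair, smile, outdoors ||| masterpiece"))
```
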
enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)): count = len(bucket) if count > 0: self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} - print(f"bucket {i}: resolution {reso}, count: {len(bucket)}") + logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}") img_ar_errors = np.array(img_ar_errors) mean_img_ar_error = np.mean(np.abs(img_ar_errors)) self.bucket_info["mean_img_ar_error"] = mean_img_ar_error - print(f"mean ar error (without repeats): {mean_img_ar_error}") + logger.info(f"mean ar error (without repeats): {mean_img_ar_error}") # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる self.buckets_indices: List(BucketBatchIndex) = [] @@ -831,7 +932,7 @@ class BaseDataset(torch.utils.data.Dataset): # num_of_image_types = len(set(bucket)) # bucket_batch_size = min(self.batch_size, num_of_image_types) # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) - # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count) + # # logger.info(bucket_index, num_of_image_types, bucket_batch_size, batch_count) # for batch_index in range(batch_count): # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) # ↑ここまで @@ -870,7 +971,7 @@ class BaseDataset(torch.utils.data.Dataset): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと - print("caching latents.") + logger.info("caching latents.") image_infos = list(self.image_data.values()) @@ -880,7 +981,7 @@ class BaseDataset(torch.utils.data.Dataset): # split by resolution batches = [] batch = [] - print("checking cache validity...") + logger.info("checking cache validity...") for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] @@ -917,7 +1018,7 @@ class BaseDataset(torch.utils.data.Dataset): return # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded - print("caching latents...") + logger.info("caching latents...") for batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop) @@ -931,10 +1032,10 @@ class BaseDataset(torch.utils.data.Dataset): # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと - print("caching text encoder outputs.") + logger.info("caching text encoder outputs.") image_infos = list(self.image_data.values()) - print("checking cache existence...") + logger.info("checking cache existence...") image_infos_to_cache = [] for info in tqdm(image_infos): # subset = self.image_to_subset[info.image_key] @@ -975,7 +1076,7 @@ class BaseDataset(torch.utils.data.Dataset): batches.append(batch) # iterate batches: call text encoder and cache outputs for memory or disk - print("caching text encoder outputs...") + logger.info("caching text encoder outputs...") for batch in tqdm(batches): infos, input_ids1, input_ids2 = zip(*batch) input_ids1 = torch.stack(input_ids1, dim=0) @@ -985,8 +1086,7 @@ class BaseDataset(torch.utils.data.Dataset): ) def get_image_size(self, image_path): - image = Image.open(image_path) - return image.size + return imagesize.get(image_path) def load_image_with_face_info(self, subset: BaseSubset, image_path: str): img = load_image(image_path) @@ -1078,7 +1178,9 @@ class BaseDataset(torch.utils.data.Dataset): for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = 
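
Swapping `PIL.Image.open` for `imagesize.get` in `get_image_size` above matters at scale: `imagesize` reads only the header bytes, so scanning tens of thousands of files for bucketing no longer decodes every image. Usage (hypothetical path):

```python
import imagesize

width, height = imagesize.get("/data/imgs/0001.png")  # header-only, no full decode
print(width, height)
```
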
self.image_to_subset[image_key] - loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + loss_weights.append( + self.prior_loss_weight if image_info.is_reg else 1.0 + ) # in case of fine tuning, is_reg is always False flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance @@ -1244,6 +1346,8 @@ class BaseDataset(torch.utils.data.Dataset): example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw]) example["flippeds"] = flippeds + example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions)) + if self.debug_dataset: example["image_keys"] = bucket[image_index : image_index + self.batch_size] return example @@ -1311,6 +1415,8 @@ class BaseDataset(torch.utils.data.Dataset): class DreamBoothDataset(BaseDataset): + IMAGE_INFO_CACHE_FILE = "metadata_cache.json" + def __init__( self, subsets: Sequence[DreamBoothSubset], @@ -1318,15 +1424,16 @@ class DreamBoothDataset(BaseDataset): tokenizer, max_token_length, resolution, + network_multiplier: float, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, - debug_dataset, + debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -1353,7 +1460,7 @@ class DreamBoothDataset(BaseDataset): self.bucket_reso_steps = None # この情報は使われない self.bucket_no_upscale = False - def read_caption(img_path, caption_extension): + def read_caption(img_path, caption_extension, enable_wildcard): # captionの候補ファイル名を作る base_name = os.path.splitext(img_path)[0] base_name_face_det = base_name @@ -1369,38 +1476,69 @@ class DreamBoothDataset(BaseDataset): try: lines = f.readlines() except UnicodeDecodeError as e: - print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") + logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") raise e assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" - caption = lines[0].strip() + if enable_wildcard: + caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結 + else: + caption = lines[0].strip() break return caption def load_dreambooth_dir(subset: DreamBoothSubset): if not os.path.isdir(subset.image_dir): - print(f"not directory: {subset.image_dir}") - return [], [] + logger.warning(f"not directory: {subset.image_dir}") + return [], [], [] - img_paths = glob_images(subset.image_dir, "*") - print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") - - # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う - captions = [] - missing_captions = [] - for img_path in img_paths: - cap_for_img = read_caption(img_path, subset.caption_extension) - if cap_for_img is None and subset.class_tokens is None: - print( - f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}" + info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE) + use_cached_info_for_subset = subset.cache_info + if use_cached_info_for_subset: + logger.info( + f"using cached image info for this subset / このサブセットで、キャッシュされた画像情報を使います: {info_cache_file}" + ) + if not os.path.isfile(info_cache_file): + logger.warning( + f"image info file not found. 
You can ignore this warning if this is the first run for this subset" + f" / キャッシュファイルが見つかりませんでした。初回実行時はこの警告を無視してください: {info_cache_file}" ) - captions.append("") - missing_captions.append(img_path) - else: - if cap_for_img is None: - captions.append(subset.class_tokens) + use_cached_info_for_subset = False + + if use_cached_info_for_subset: + # json: {`img_path`:{"caption": "caption...", "resolution": [width, height]}, ...} + with open(info_cache_file, "r", encoding="utf-8") as f: + metas = json.load(f) + img_paths = list(metas.keys()) + sizes = [meta["resolution"] for meta in metas.values()] + + # we may need to check image size and existence of image files, but it takes time, so user should check it before training + else: + img_paths = glob_images(subset.image_dir, "*") + sizes = [None] * len(img_paths) + + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") + + if use_cached_info_for_subset: + captions = [meta["caption"] for meta in metas.values()] + missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""] + else: + # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う + captions = [] + missing_captions = [] + for img_path in img_paths: + cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard) + if cap_for_img is None and subset.class_tokens is None: + logger.warning( + f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}" + ) + captions.append("") missing_captions.append(img_path) else: - captions.append(cap_for_img) + if cap_for_img is None: + captions.append(subset.class_tokens) + missing_captions.append(img_path) + else: + captions.append(cap_for_img) self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 @@ -1409,36 +1547,50 @@ number_of_missing_captions_to_show = 5 remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show - print( + logger.warning( f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。" ) for i, missing_caption in enumerate(missing_captions): if i >= number_of_missing_captions_to_show: - print(missing_caption + f"...
and {remaining_missing_captions} more") break - print(missing_caption) - return img_paths, captions + logger.warning(missing_caption) - print("prepare images.") + if not use_cached_info_for_subset and subset.cache_info: + logger.info(f"cache image info for / 画像情報をキャッシュします : {info_cache_file}") + sizes = [self.get_image_size(img_path) for img_path in tqdm(img_paths, desc="get image size")] + matas = {} + for img_path, caption, size in zip(img_paths, captions, sizes): + matas[img_path] = {"caption": caption, "resolution": list(size)} + with open(info_cache_file, "w", encoding="utf-8") as f: + json.dump(matas, f, ensure_ascii=False, indent=2) + logger.info(f"cache image info done for / 画像情報を出力しました : {info_cache_file}") + + # if sizes are not set, image size will be read in make_buckets + return img_paths, captions, sizes + + logger.info("prepare images.") num_train_images = 0 num_reg_images = 0 - reg_infos: List[ImageInfo] = [] + reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = [] for subset in subsets: if subset.num_repeats < 1: - print( + logger.warning( f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" ) continue if subset in self.subsets: - print( + logger.warning( f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" ) continue - img_paths, captions = load_dreambooth_dir(subset) + img_paths, captions, sizes = load_dreambooth_dir(subset) if len(img_paths) < 1: - print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します") + logger.warning( + f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します" + ) continue if subset.is_reg: @@ -1446,31 +1598,33 @@ class DreamBoothDataset(BaseDataset): else: num_train_images += subset.num_repeats * len(img_paths) - for img_path, caption in zip(img_paths, captions): + for img_path, caption, size in zip(img_paths, captions, sizes): info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + if size is not None: + info.image_size = size if subset.is_reg: - reg_infos.append(info) + reg_infos.append((info, subset)) else: self.register_image(info, subset) subset.img_count = len(img_paths) self.subsets.append(subset) - print(f"{num_train_images} train images with repeating.") + logger.info(f"{num_train_images} train images with repeating.") self.num_train_images = num_train_images - print(f"{num_reg_images} reg images.") + logger.info(f"{num_reg_images} reg images.") if num_train_images < num_reg_images: - print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") + logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") if num_reg_images == 0: - print("no regularization images / 正則化画像が見つかりませんでした") + logger.warning("no regularization images / 正則化画像が見つかりませんでした") else: # num_repeatsを計算する:どうせ大した数ではないのでループで処理する n = 0 first_loop = True while n < num_train_images: - for info in reg_infos: + for info, subset in reg_infos: if first_loop: self.register_image(info, subset) n += info.num_repeats @@ -1492,14 +1646,15 @@ class FineTuningDataset(BaseDataset): tokenizer, max_token_length, resolution, + network_multiplier: float, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, - debug_dataset, + debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + 
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) self.batch_size = batch_size @@ -1508,27 +1663,29 @@ class FineTuningDataset(BaseDataset): for subset in subsets: if subset.num_repeats < 1: - print( + logger.warning( f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" ) continue if subset in self.subsets: - print( + logger.warning( f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" ) continue # メタデータを読み込む if os.path.exists(subset.metadata_file): - print(f"loading existing metadata: {subset.metadata_file}") + logger.info(f"loading existing metadata: {subset.metadata_file}") with open(subset.metadata_file, "rt", encoding="utf-8") as f: metadata = json.load(f) else: raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}") if len(metadata) < 1: - print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します") + logger.warning( + f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します" + ) continue tags_list = [] @@ -1559,10 +1716,24 @@ class FineTuningDataset(BaseDataset): caption = img_md.get("caption") tags = img_md.get("tags") if caption is None: - caption = tags - elif tags is not None and len(tags) > 0: - caption = caption + ", " + tags - tags_list.append(tags) + caption = tags # could be multiline + tags = None + + if subset.enable_wildcard: + # tags must be single line + if tags is not None: + tags = tags.replace("\n", subset.caption_separator) + + # add tags to each line of caption + if caption is not None and tags is not None: + caption = "\n".join( + [f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""] + ) + else: + # use as is + if tags is not None and len(tags) > 0: + caption = caption + subset.caption_separator + tags + tags_list.append(tags) if caption is None: caption = "" @@ -1606,14 +1777,16 @@ class FineTuningDataset(BaseDataset): if not npz_any: use_npz_latents = False - print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") + logger.warning(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") elif not npz_all: use_npz_latents = False - print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") + logger.warning( + f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します" + ) if flip_aug_in_subset: - print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") + logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") # else: - # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") + # logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") # check min/max bucket size sizes = set() @@ -1629,7 +1802,9 @@ class FineTuningDataset(BaseDataset): if sizes is None: if use_npz_latents: use_npz_latents = False - print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します") + logger.warning( + f"npz files exist, but no bucket info in metadata. 
ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します" + ) assert ( resolution is not None @@ -1643,8 +1818,8 @@ self.bucket_no_upscale = bucket_no_upscale else: if not enable_bucket: - print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") - print("using bucket info in metadata / メタデータ内のbucket情報を使います") + logger.info("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") + logger.info("using bucket info in metadata / メタデータ内のbucket情報を使います") self.enable_bucket = True assert ( @@ -1696,25 +1871,34 @@ class ControlNetDataset(BaseDataset): tokenizer, max_token_length, resolution, + network_multiplier: float, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, - debug_dataset, + debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) db_subsets = [] for subset in subsets: + assert ( + not subset.random_crop + ), "random_crop is not supported in ControlNetDataset / random_cropはControlNetDatasetではサポートされていません" db_subset = DreamBoothSubset( subset.image_dir, False, None, subset.caption_extension, + subset.cache_info, subset.num_repeats, subset.shuffle_caption, + subset.caption_separator, subset.keep_tokens, + subset.keep_tokens_separator, + subset.secondary_separator, + subset.enable_wildcard, subset.color_aug, subset.flip_aug, subset.face_crop_aug_range, @@ -1735,6 +1919,7 @@ tokenizer, max_token_length, resolution, + network_multiplier, enable_bucket, min_bucket_reso, max_bucket_reso, @@ -1752,7 +1937,7 @@ # assert all conditioning data exists missing_imgs = [] - cond_imgs_with_img = set() + cond_imgs_with_pair = set() for image_key, info in self.dreambooth_dataset_delegate.image_data.items(): db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key] subset = None @@ -1763,26 +1948,32 @@ assert subset is not None, "internal error: subset not found" if not os.path.isdir(subset.conditioning_data_dir): - print(f"not directory: {subset.conditioning_data_dir}") + logger.warning(f"not directory: {subset.conditioning_data_dir}") continue - img_basename = os.path.basename(info.absolute_path) - ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename) - if not os.path.exists(ctrl_img_path): + img_basename = os.path.splitext(os.path.basename(info.absolute_path))[0] + ctrl_img_path = glob_images(subset.conditioning_data_dir, img_basename) + if len(ctrl_img_path) < 1: missing_imgs.append(img_basename) + continue + ctrl_img_path = ctrl_img_path[0] + ctrl_img_path = os.path.abspath(ctrl_img_path) # normalize path info.cond_img_path = ctrl_img_path - cond_imgs_with_img.add(ctrl_img_path) + cond_imgs_with_pair.add(os.path.splitext(ctrl_img_path)[0]) # remove extension because Windows is case insensitive extra_imgs = [] for subset in subsets: conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*") - extra_imgs.extend( - [cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img] - ) + conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path + extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair]) - assert len(missing_imgs) == 0,
f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" - assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" + assert ( + len(missing_imgs) == 0 + ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}" + assert ( + len(extra_imgs) == 0 + ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}" self.conditioning_image_transforms = IMAGE_TRANSFORMS @@ -1821,7 +2012,9 @@ class ControlNetDataset(BaseDataset): assert ( cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" - cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + cond_img = cv2.resize( + cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA + ) # INTER_AREAでやりたいのでcv2でリサイズ # TODO support random crop # 現在サポートしているcropはrandomではなく中央のみ @@ -1883,14 +2076,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): for i, dataset in enumerate(self.datasets): - print(f"[Dataset {i}]") + logger.info(f"[Dataset {i}]") dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True ): for i, dataset in enumerate(self.datasets): - print(f"[Dataset {i}]") + logger.info(f"[Dataset {i}]") dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) def set_caching_mode(self, caching_mode): @@ -1974,12 +2167,15 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, fli def debug_dataset(train_dataset, show_input_ids=False): - print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します") + logger.info(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") + logger.info( + "`S` for next step, `E` for next epoch no. , Escape for exit. 
@@ -1974,12 +2167,15 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, fli def debug_dataset(train_dataset, show_input_ids=False): - print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します") + logger.info(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") + logger.info( + "`S` for next step, `E` for next epoch, Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します" + ) epoch = 1 while True: - print(f"\nepoch: {epoch}") + logger.info("") + logger.info(f"epoch: {epoch}") steps = (epoch - 1) * len(train_dataset) + 1 indices = list(range(len(train_dataset))) @@ -1989,11 +2185,11 @@ def debug_dataset(train_dataset, show_input_ids=False): for i, idx in enumerate(indices): train_dataset.set_current_epoch(epoch) train_dataset.set_current_step(steps) - print(f"steps: {steps} ({i + 1}/{len(train_dataset)})") + logger.info(f"steps: {steps} ({i + 1}/{len(train_dataset)})") example = train_dataset[idx] if example["latents"] is not None: - print(f"sample has latents from npz file: {example['latents'].size()}") + logger.info(f"sample has latents from npz file: {example['latents'].size()}") for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate( zip( example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"], example["original_sizes_hw"], example["crop_top_lefts"], example["target_sizes_hw"], example["flippeds"], ) ): - print( + logger.info( f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}' ) + if "network_multipliers" in example: + logger.info(f"network multiplier: {example['network_multipliers'][j]}") if show_input_ids: - print(f"input ids: {iid}") + logger.info(f"input ids: {iid}") if "input_ids2" in example: - print(f"input ids2: {example['input_ids2'][j]}") + logger.info(f"input ids2: {example['input_ids2'][j]}") if example["images"] is not None: im = example["images"][j] - print(f"image size: {im.size()}") + logger.info(f"image size: {im.size()}") im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c im = im[:, :, ::-1] # RGB -> BGR (OpenCV) if "conditioning_images" in example: cond_img = example["conditioning_images"][j] - print(f"conditioning image size: {cond_img.size()}") + logger.info(f"conditioning image size: {cond_img.size()}") cond_img = ((cond_img.numpy() + 1.0) * 127.5).astype(np.uint8) cond_img = np.transpose(cond_img, (1, 2, 0)) cond_img = cond_img[:, :, ::-1] @@ -2075,8 +2273,8 @@ def glob_images_pathlib(dir_path, recursive): class MinimalDataset(BaseDataset): - def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False): - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False): + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) self.num_train_images = 0 # update in subclass self.num_reg_images = 0 # update in subclass @@ -2171,12 +2369,12 @@ def trim_and_resize_if_required( if image_width > reso[0]: trim_size = image_width - reso[0] p = trim_size // 2 if not random_crop else random.randint(0, trim_size) - # print("w", trim_size, p) + # logger.info(f"w {trim_size} {p}") image = image[:, p : p + reso[0]] if image_height > reso[1]: trim_size = image_height - reso[1] p = trim_size // 2 if not random_crop else random.randint(0, trim_size) - # print("h", trim_size, p) + # logger.info(f"h {trim_size} {p}") image = image[p : p + reso[1]] # random cropの場合のcropされた値をどうcrop left/topに反映するべきか全くアイデアがない @@ -2236,9 +2434,8 @@ def cache_batch_latents( if flip_aug: info.latents_flipped = flipped_latent - # FIXME this slows down caching a lot, specify this as an option - if torch.cuda.is_available(): - torch.cuda.empty_cache() + if not HIGH_VRAM: + clean_memory_on_device(vae.device) def 
cache_batch_text_encoder_outputs( @@ -2414,7 +2611,7 @@ def get_git_revision_hash() -> str: # def replace_unet_cross_attn_to_xformers(): -# print("CrossAttention.forward has been replaced to enable xformers.") +# logger.info("CrossAttention.forward has been replaced to enable xformers.") # try: # import xformers.ops # except ImportError: @@ -2457,10 +2654,10 @@ def get_git_revision_hash() -> str: # diffusers.models.attention.CrossAttention.forward = forward_xformers def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: - print("Enable memory efficient attention for U-Net") + logger.info("Enable memory efficient attention for U-Net") unet.set_use_memory_efficient_attention(False, True) elif xformers: - print("Enable xformers for U-Net") + logger.info("Enable xformers for U-Net") try: import xformers.ops except ImportError: @@ -2468,7 +2665,7 @@ def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdp unet.set_use_memory_efficient_attention(True, False) elif sdpa: - print("Enable SDPA for U-Net") + logger.info("Enable SDPA for U-Net") unet.set_use_sdpa(True) @@ -2479,17 +2676,17 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform replace_vae_attn_to_memory_efficient() elif xformers: # とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ - print("Use Diffusers xformers for VAE") + logger.info("Use Diffusers xformers for VAE") vae.encoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) vae.decoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) def replace_vae_attn_to_memory_efficient(): - print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") + logger.info("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction def forward_flash_attn(self, hidden_states): - print("forward_flash_attn") + logger.info("forward_flash_attn") q_bucket_size = 512 k_bucket_size = 1024 @@ -2634,7 +2831,9 @@ def get_sai_model_spec( def add_sd_models_arguments(parser: argparse.ArgumentParser): # for pretrained models - parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む") + parser.add_argument( + "--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む" + ) parser.add_argument( "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" ) @@ -2657,7 +2856,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): "--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", ) # backward compatibility @@ -2674,7 +2873,10 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm, 0 
for no clipping / 勾配正規化の最大norm、0でclippingを行わない" + "--max_grad_norm", + default=1.0, + type=float, + help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない", ) parser.add_argument( @@ -2721,13 +2923,23 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): - parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ") - parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名") parser.add_argument( - "--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名" + "--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ" ) parser.add_argument( - "--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類" + "--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名" + ) + parser.add_argument( + "--huggingface_repo_id", + type=str, + default=None, + help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名", + ) + parser.add_argument( + "--huggingface_repo_type", + type=str, + default=None, + help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類", ) parser.add_argument( "--huggingface_path_in_repo", @@ -2763,10 +2975,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="precision in saving / 保存時に精度を変更して保存する", ) parser.add_argument( - "--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する" + "--save_every_n_epochs", + type=int, + default=None, + help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する", ) parser.add_argument( - "--save_every_n_steps", type=int, default=None, help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する" + "--save_every_n_steps", + type=int, + default=None, + help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する", ) parser.add_argument( "--save_n_epoch_ratio", @@ -2801,7 +3019,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--save_state", action="store_true", - help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する", + help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する", + ) + parser.add_argument( + "--save_state_on_train_end", + action="store_true", + help="save training state (including optimizer states etc.) 
on train end / optimizerなど学習状態も含めたstateを学習完了時に保存する", ) parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") @@ -2818,6 +3041,19 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: action="store_true", help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う", ) + parser.add_argument( + "--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う" + ) + parser.add_argument( + "--dynamo_backend", + type=str, + default="inductor", + # available backends: + # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 + # https://pytorch.org/docs/stable/torch.compiler.html + choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], + help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)", + ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") parser.add_argument( "--sdpa", @@ -2825,7 +3061,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)", ) parser.add_argument( - "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", ) parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") @@ -2857,18 +3096,34 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数", ) parser.add_argument( - "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度" + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help="use mixed precision / 混合精度を使う場合、その精度", ) parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") parser.add_argument( "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する" ) # TODO move to SDXL training, because it is not supported by SD1/2 + parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う") + parser.add_argument( "--ddp_timeout", type=int, default=None, help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)", ) + parser.add_argument( + "--ddp_gradient_as_bucket_view", + action="store_true", + help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする", + ) + parser.add_argument( + "--ddp_static_graph", + action="store_true", + help="enable static_graph for DDP / DDPでstatic_graphを有効にする", + ) parser.add_argument( "--clip_skip", type=int, @@ -2888,13 +3143,21 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: choices=["tensorboard", "wandb", "all"], help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)", ) - parser.add_argument("--log_prefix", type=str, default=None, help="add prefix 
for each log directory / ログディレクトリ名の先頭に追加する文字列") + parser.add_argument( + "--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列" + ) parser.add_argument( "--log_tracker_name", type=str, default=None, help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名", ) + parser.add_argument( + "--wandb_run_name", + type=str, + default=None, + help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前", + ) parser.add_argument( "--log_tracker_config", type=str, @@ -2907,12 +3170,18 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)", ) + parser.add_argument( "--noise_offset", type=float, default=None, help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)", ) + parser.add_argument( + "--noise_offset_random_strength", + action="store_true", + help="use random strength between 0~noise_offset for noise offset. / noise offsetにおいて、0からnoise_offsetの間でランダムな強度を使用します。", + ) parser.add_argument( "--multires_noise_iterations", type=int, @@ -2926,6 +3195,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) " + "/ input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)", ) + parser.add_argument( + "--ip_noise_gamma_random_strength", + action="store_true", + help="Use random strength between 0~ip_noise_gamma for input perturbation noise." + + "/ input perturbation noiseにおいて、0からip_noise_gammaの間でランダムな強度を使用します。", + ) # parser.add_argument( # "--perlin_noise", # type=int, @@ -2961,15 +3236,48 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))", ) + parser.add_argument( + "--loss_type", + type=str, + default="l2", + choices=["l2", "huber", "smooth_l1"], + help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L2、Huber、またはsmooth L1)、デフォルトはL2", + ) + parser.add_argument( + "--huber_schedule", + type=str, + default="snr", + choices=["constant", "exponential", "snr"], + help="The scheduling method for Huber loss (constant, exponential, or SNR-based). Only used when loss_type is 'huber' or 'smooth_l1'. default is snr" + + " / Huber損失のスケジューリング方法(constant、exponential、またはSNRベース)。loss_typeが'huber'または'smooth_l1'の場合に有効、デフォルトは snr", + ) + parser.add_argument( + "--huber_c", + type=float, + default=0.1, + help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + ) parser.add_argument( "--lowram", action="store_true", - help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)", + help="enable low RAM optimization. e.g. 
load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込む等(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)", + ) + parser.add_argument( + "--highvram", + action="store_true", + help="disable low VRAM optimization. e.g. do not clear CUDA cache after each latent caching (for machines which have bigger VRAM) " + + "/ VRAMが少ない環境向け最適化を無効にする。たとえば各latentのキャッシュ後のCUDAキャッシュクリアを行わない等(VRAMが多い環境向け)", ) parser.add_argument( - "--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する" + "--sample_every_n_steps", + type=int, + default=None, + help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する", + ) + parser.add_argument( + "--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する" ) parser.add_argument( "--sample_every_n_epochs", @@ -2978,7 +3286,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)", ) parser.add_argument( - "--sample_prompts", type=str, default=None, help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル" + "--sample_prompts", + type=str, + default=None, + help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル", ) parser.add_argument( "--sample_sampler", @@ -3054,15 +3365,94 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: ) +def add_masked_loss_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + parser.add_argument( + "--masked_loss", + action="store_true", + help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要", + ) + + +# verify command line args for training +def verify_command_line_training_args(args: argparse.Namespace): + # if wandb is enabled, the command line is exposed to the public + # check whether sensitive options are included in the command line arguments + # if so, warn or inform the user to move them to the configuration file + # wandbが有効な場合、コマンドラインが公開される + # 学習用のコマンドライン引数に敏感なオプションが含まれているかどうかを確認し、 + # 含まれている場合は設定ファイルに移動するようにユーザーに警告または通知する + + wandb_enabled = args.log_with is not None and args.log_with != "tensorboard" # "all" or "wandb" + if not wandb_enabled: + return + + sensitive_args = ["wandb_api_key", "huggingface_token"] + sensitive_path_args = [ + "pretrained_model_name_or_path", + "vae", + "tokenizer_cache_dir", + "train_data_dir", + "conditioning_data_dir", + "reg_data_dir", + "output_dir", + "logging_dir", + ] + + for arg in sensitive_args: + if getattr(args, arg, None) is not None: + logger.warning( + f"wandb is enabled, but option `{arg}` is included in the command line. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file." + + f" / wandbが有効で、かつオプション `{arg}` がコマンドラインに含まれています。コマンドラインは公開されるため、`.toml`ファイルに移動することをお勧めします。" + ) + + # if path is absolute, it may include sensitive information + for arg in sensitive_path_args: + if getattr(args, arg, None) is not None and os.path.isabs(getattr(args, arg)): + logger.info( + f"wandb is enabled, but option `{arg}` is included in the command line and it is an absolute path. 
Because the command line is exposed to the public, it is recommended to move it to the `.toml` file or use relative path." + f" / wandbが有効で、かつオプション `{arg}` がコマンドラインに含まれており、絶対パスです。コマンドラインは公開されるため、`.toml`ファイルに移動するか、相対パスを使用することをお勧めします。" + ) + + if getattr(args, "config_file", None) is not None: + logger.info( + f"wandb is enabled, but option `config_file` is included in the command line. Because the command line is exposed to the public, please be careful about the information included in the path." + + f" / wandbが有効で、かつオプション `config_file` がコマンドラインに含まれています。コマンドラインは公開されるため、パスに含まれる情報にご注意ください。" + ) + + # other sensitive options + if args.huggingface_repo_id is not None and args.huggingface_repo_visibility != "public": + logger.info( + f"wandb is enabled, but option huggingface_repo_id is included in the command line and huggingface_repo_visibility is not 'public'. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file." + + f" / wandbが有効で、かつオプション huggingface_repo_id がコマンドラインに含まれており、huggingface_repo_visibility が 'public' ではありません。コマンドラインは公開されるため、`.toml`ファイルに移動することをお勧めします。" + ) + + def verify_training_args(args: argparse.Namespace): + r""" + Verify training arguments. Also reflect highvram option to global variable + 学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する + """ + if args.highvram: + logger.info("highvram is enabled / highvramが有効です") + global HIGH_VRAM + HIGH_VRAM = True + if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません") + logger.warning( + "v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません" + ) if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") if args.cache_latents_to_disk and not args.cache_latents: args.cache_latents = True - print( + logger.warning( "cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします" ) @@ -3093,20 +3483,41 @@ def verify_training_args(args: argparse.Namespace): if args.zero_terminal_snr and not args.v_parameterization: - print( + logger.warning( f"zero_terminal_snr is enabled, but v_parameterization is not enabled. 
training will be unexpected" + " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります" ) + if args.sample_every_n_epochs is not None and args.sample_every_n_epochs <= 0: + logger.warning( + "sample_every_n_epochs is less than or equal to 0, so it will be disabled / sample_every_n_epochsに0以下の値が指定されたため無効になります" + ) + args.sample_every_n_epochs = None + + if args.sample_every_n_steps is not None and args.sample_every_n_steps <= 0: + logger.warning( + "sample_every_n_steps is less than or equal to 0, so it will be disabled / sample_every_n_stepsに0以下の値が指定されたため無効になります" + ) + args.sample_every_n_steps = None + def add_dataset_arguments( parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool ): # dataset common - parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument( - "--shuffle_caption", action="store_true", help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする" + "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ" ) + parser.add_argument( + "--cache_info", + action="store_true", + help="cache meta information (caption and image size) for faster dataset loading. only available for DreamBooth" + + " / メタ情報(キャプションとサイズ)をキャッシュしてデータセット読み込みを高速化する。DreamBooth方式のみ有効", + ) + parser.add_argument( + "--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする" + ) + parser.add_argument("--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字") parser.add_argument( "--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子" ) @@ -3122,6 +3533,25 @@ def add_dataset_arguments( default=0, help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)", ) + parser.add_argument( + "--keep_tokens_separator", + type=str, + default="", + help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens." + + " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。", + ) + parser.add_argument( + "--secondary_separator", + type=str, + default=None, + help="a secondary separator for caption. This separator is replaced to caption_separator after dropping/shuffling caption" + + " / captionのセカンダリ区切り文字。この区切り文字はcaptionのドロップやシャッフル後にcaption_separatorに置き換えられる", + ) + parser.add_argument( + "--enable_wildcard", + action="store_true", + help="enable wildcard for caption (e.g. 
'{image|picture|rendition}') / captionのワイルドカードを有効にする(例:'{image|picture|rendition}')", + ) parser.add_argument( "--caption_prefix", type=str, @@ -3134,8 +3564,12 @@ def add_dataset_arguments( default=None, help="suffix for caption text / captionのテキストの末尾に付ける文字列", ) - parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") - parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") + parser.add_argument( + "--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする" + ) + parser.add_argument( + "--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする" + ) parser.add_argument( "--face_crop_aug_range", type=str, @@ -3148,7 +3582,9 @@ def add_dataset_arguments( help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)", ) parser.add_argument( - "--debug_dataset", action="store_true", help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)" + "--debug_dataset", + action="store_true", + help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)", ) parser.add_argument( "--resolution", @@ -3161,14 +3597,18 @@ def add_dataset_arguments( action="store_true", help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheする(augmentationは使用不可) ", ) - parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ") + parser.add_argument( + "--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ" + ) parser.add_argument( "--cache_latents_to_disk", action="store_true", help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)", ) parser.add_argument( - "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする" + "--enable_bucket", + action="store_true", + help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする", ) parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") @@ -3179,7 +3619,9 @@ def add_dataset_arguments( help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します", ) parser.add_argument( - "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" + "--bucket_no_upscale", + action="store_true", + help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します", ) parser.add_argument( @@ -3223,13 +3665,20 @@ def add_dataset_arguments( if support_dreambooth: # DreamBooth dataset - parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") + parser.add_argument( + "--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ" + ) if support_caption: # caption dataset - parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル") parser.add_argument( - "--dataset_repeats", type=int, default=1, 
help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数" + "--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル" + ) + parser.add_argument( + "--dataset_repeats", + type=int, + default=1, + help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数", ) @@ -3257,7 +3706,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar if args.output_config: # check if config file exists if os.path.exists(config_path): - print(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}") + logger.error(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}") exit(1) # convert args to dictionary @@ -3285,15 +3734,15 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar with open(config_path, "w") as f: toml.dump(args_dict, f) - print(f"Saved config file / 設定ファイルを保存しました: {config_path}") + logger.info(f"Saved config file / 設定ファイルを保存しました: {config_path}") exit(0) if not os.path.exists(config_path): - print(f"{config_path} not found.") + logger.info(f"{config_path} not found.") exit(1) - print(f"Loading settings from {config_path}...") - with open(config_path, "r") as f: + logger.info(f"Loading settings from {config_path}...") + with open(config_path, "r", encoding="utf-8") as f: config_dict = toml.load(f) # combine all sections into one @@ -3311,7 +3760,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar config_args = argparse.Namespace(**ignore_nesting_dict) args = parser.parse_args(namespace=config_args) args.config_file = os.path.splitext(args.config_file)[0] - print(args.config_file) + logger.info(args.config_file) return args @@ -3326,11 +3775,11 @@ def resume_from_local_or_hf_if_specified(accelerator, args): return if not args.resume_from_huggingface: - print(f"resume training from local state: {args.resume}") + logger.info(f"resume training from local state: {args.resume}") accelerator.load_state(args.resume) return - print(f"resume training from huggingface state: {args.resume}") + logger.info(f"resume training from huggingface state: {args.resume}") repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1] path_in_repo = "/".join(args.resume.split("/")[2:]) revision = None @@ -3342,7 +3791,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args): repo_type = "model" else: path_in_repo, revision, repo_type = divided - print(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}") + logger.info(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}") list_files = huggingface_util.list_dir( repo_id=repo_id, @@ -3367,13 +3816,15 @@ def resume_from_local_or_hf_if_specified(accelerator, args): loop = asyncio.get_event_loop() results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files])) if len(results) == 0: - raise ValueError("No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした") + raise ValueError( + "No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした" + ) dirname = os.path.dirname(results[0]) accelerator.load_state(dirname) def get_optimizer(args, trainable_params): - # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, 
DAdaptLion, DAdaptSGD, Adafactor" + # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" optimizer_type = args.optimizer_type if args.use_8bit_adam: @@ -3414,7 +3865,7 @@ def get_optimizer(args, trainable_params): # value = tuple(value) optimizer_kwargs[key] = value - # print("optkwargs:", optimizer_kwargs) + # logger.info(f"optkwargs: {optimizer_kwargs}") lr = args.learning_rate optimizer = None @@ -3424,7 +3875,7 @@ def get_optimizer(args, trainable_params): import lion_pytorch except ImportError: raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") - print(f"use Lion optimizer | {optimizer_kwargs}") + logger.info(f"use Lion optimizer | {optimizer_kwargs}") optimizer_class = lion_pytorch.Lion optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
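# For illustration: how "--optimizer_args" entries plausibly become the
# optimizer_kwargs dict used above. A sketch assuming key=value parsing with
# ast.literal_eval (parse_optimizer_args_sketch is hypothetical, not the
# repository's exact code):
import ast

def parse_optimizer_args_sketch(optimizer_args):
    kwargs = {}
    for arg in optimizer_args:
        key, value = arg.split("=", 1)
        try:
            value = ast.literal_eval(value)  # "0.01" -> 0.01, "(0.9, 0.99)" -> tuple
        except (ValueError, SyntaxError):
            pass  # keep non-literal values as plain strings
        kwargs[key] = value
    return kwargs

assert parse_optimizer_args_sketch(["weight_decay=0.01", "betas=(0.9, 0.99)"]) == {"weight_decay": 0.01, "betas": (0.9, 0.99)}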
@@ -3435,14 +3886,14 @@ def get_optimizer(args, trainable_params): raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") if optimizer_type == "AdamW8bit".lower(): - print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") optimizer_class = bnb.optim.AdamW8bit optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "SGDNesterov8bit".lower(): - print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") if "momentum" not in optimizer_kwargs: - print( + logger.warning( f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します" ) optimizer_kwargs["momentum"] = 0.9 @@ -3451,7 +3902,7 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) elif optimizer_type == "Lion8bit".lower(): - print(f"use 8-bit Lion optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit Lion optimizer | {optimizer_kwargs}") try: optimizer_class = bnb.optim.Lion8bit except AttributeError: raise AttributeError( "No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください" ) elif optimizer_type == "PagedAdamW8bit".lower(): - print(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}") try: optimizer_class = bnb.optim.PagedAdamW8bit except AttributeError: raise AttributeError( "No PagedAdamW8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) elif optimizer_type == "PagedLion8bit".lower(): - print(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}") try: optimizer_class = bnb.optim.PagedLion8bit except AttributeError: raise AttributeError( @@ -3477,8 +3928,22 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + elif optimizer_type == "PagedAdamW".lower(): + logger.info(f"use PagedAdamW optimizer | {optimizer_kwargs}") + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") + try: + optimizer_class = bnb.optim.PagedAdamW + except AttributeError: + raise AttributeError( + "No PagedAdamW. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamWが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + elif optimizer_type == "PagedAdamW32bit".lower(): - print(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}") + logger.info(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}") try: import bitsandbytes as bnb except ImportError: @@ -3492,16 +3957,18 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "SGDNesterov".lower(): - print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") + logger.info(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") if "momentum" not in optimizer_kwargs: - print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します") + logger.info( + f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します" + ) optimizer_kwargs["momentum"] = 0.9 optimizer_class = torch.optim.SGD optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower(): - # check lr and lr_count, and print warning + # check lr and lr_count, and log a warning actual_lr = lr lr_count = 1 if type(trainable_params) == list and type(trainable_params[0]) == dict: @@ -3512,12 +3979,12 @@ def get_optimizer(args, trainable_params): lr_count = len(lrs) if actual_lr <= 0.1: - print( + logger.warning( f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}" ) - print("recommend option: lr=1.0 / 推奨は1.0です") + logger.warning("recommend option: lr=1.0 / 推奨は1.0です") if lr_count > 1: - print( + logger.warning( f"when multiple learning rates are specified with dadaptation (e.g. 
for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" ) @@ -3533,25 +4000,25 @@ def get_optimizer(args, trainable_params): # set optimizer if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower(): optimizer_class = experimental.DAdaptAdamPreprint - print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdaGrad".lower(): optimizer_class = dadaptation.DAdaptAdaGrad - print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdam".lower(): optimizer_class = dadaptation.DAdaptAdam - print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdan".lower(): optimizer_class = dadaptation.DAdaptAdan - print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdanIP".lower(): optimizer_class = experimental.DAdaptAdanIP - print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptLion".lower(): optimizer_class = dadaptation.DAdaptLion - print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptSGD".lower(): optimizer_class = dadaptation.DAdaptSGD - print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") else: raise ValueError(f"Unknown optimizer type: {optimizer_type}") @@ -3564,7 +4031,7 @@ def get_optimizer(args, trainable_params): except ImportError: raise ImportError("No Prodigy / Prodigy がインストールされていないようです") - print(f"use Prodigy optimizer | {optimizer_kwargs}") + logger.info(f"use Prodigy optimizer | {optimizer_kwargs}") optimizer_class = prodigyopt.Prodigy optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) @@ -3573,14 +4040,16 @@ def get_optimizer(args, trainable_params): if "relative_step" not in optimizer_kwargs: optimizer_kwargs["relative_step"] = True # default if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False): - print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします") + logger.info( + f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします" + ) optimizer_kwargs["relative_step"] = True - print(f"use Adafactor optimizer | {optimizer_kwargs}") + logger.info(f"use Adafactor optimizer | {optimizer_kwargs}") if optimizer_kwargs["relative_step"]: - print(f"relative_step is true / relative_stepがtrueです") + logger.info(f"relative_step is true / relative_stepがtrueです") if lr != 0.0: - print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") + logger.warning(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") args.learning_rate = None # trainable_paramsがgroupだった時の処理:lrを削除する @@ -3592,37 +4061,37 @@ def get_optimizer(args, trainable_params): if has_group_lr: # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない - print(f"unet_lr and 
text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます") + logger.warning(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます") args.unet_lr = None args.text_encoder_lr = None if args.lr_scheduler != "adafactor": - print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") + logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど lr = None else: if args.max_grad_norm != 0.0: - print( + logger.warning( f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません" ) if args.lr_scheduler != "constant_with_warmup": - print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") + logger.warning(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0: - print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") + logger.warning(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") optimizer_class = transformers.optimization.Adafactor optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "AdamW".lower(): - print(f"use AdamW optimizer | {optimizer_kwargs}") + logger.info(f"use AdamW optimizer | {optimizer_kwargs}") optimizer_class = torch.optim.AdamW optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) if optimizer is None: # 任意のoptimizerを使う optimizer_type = args.optimizer_type # lowerでないやつ(微妙) - print(f"use {optimizer_type} | {optimizer_kwargs}") + logger.info(f"use {optimizer_type} | {optimizer_kwargs}") if "." not in optimizer_type: optimizer_module = torch.optim else: @@ -3668,7 +4137,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): # using any lr_scheduler from other library if args.lr_scheduler_type: lr_scheduler_type = args.lr_scheduler_type - print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler") + logger.info(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler") if "." 
not in lr_scheduler_type: # default to use torch.optim lr_scheduler_module = torch.optim.lr_scheduler else: @@ -3684,7 +4153,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): type(optimizer) == transformers.optimization.Adafactor ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください" initial_lr = float(name.split(":")[1]) - # print("adafactor scheduler init lr", initial_lr) + # logger.info(f"adafactor scheduler init lr {initial_lr}") return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) name = SchedulerType(name) @@ -3749,20 +4218,20 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): if support_metadata: if args.in_json is not None and (args.color_aug or args.random_crop): - print( + logger.warning( f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます" ) def load_tokenizer(args: argparse.Namespace): - print("prepare tokenizer") + logger.info("prepare tokenizer") original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH tokenizer: CLIPTokenizer = None if args.tokenizer_cache_dir: local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) if os.path.exists(local_tokenizer_path): - print(f"load tokenizer from cache: {local_tokenizer_path}") + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2 if tokenizer is None: @@ -3772,16 +4241,20 @@ def load_tokenizer(args: argparse.Namespace): tokenizer = CLIPTokenizer.from_pretrained(original_path) if hasattr(args, "max_token_length") and args.max_token_length is not None: - print(f"update token length: {args.max_token_length}") + logger.info(f"update token length: {args.max_token_length}") if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - print(f"save Tokenizer to cache: {local_tokenizer_path}") + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") tokenizer.save_pretrained(local_tokenizer_path) return tokenizer def prepare_accelerator(args: argparse.Namespace): + """ + this function also prepares deepspeed plugin + """ + if args.logging_dir is None: logging_dir = None else: @@ -3797,7 +4270,9 @@ def prepare_accelerator(args: argparse.Namespace): log_with = args.log_with if log_with in ["tensorboard", "all"]: if logging_dir is None: - raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください") + raise ValueError( + "logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください" + ) if log_with in ["wandb", "all"]: try: import wandb @@ -3809,16 +4284,34 @@ def prepare_accelerator(args: argparse.Namespace): if args.wandb_api_key is not None: wandb.login(key=args.wandb_api_key) + # torch.compile のオプション。 NO の場合は torch.compile は使わない + dynamo_backend = "NO" + if args.torch_compile: + dynamo_backend = args.dynamo_backend + kwargs_handlers = ( - None if args.ddp_timeout is None else [InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))] + InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None, + ( + DistributedDataParallelKwargs( + gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph + ) + if args.ddp_gradient_as_bucket_view or 
args.ddp_static_graph + else None + ), + ) + kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers)) + deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) + accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=log_with, project_dir=logging_dir, kwargs_handlers=kwargs_handlers, + dynamo_backend=dynamo_backend, + deepspeed_plugin=deepspeed_plugin, ) + logger.info(f"accelerator device: {accelerator.device}") return accelerator @@ -3845,17 +4338,17 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une name_or_path = os.path.realpath(name_or_path) if os.path.islink(name_or_path) else name_or_path load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers if load_stable_diffusion_format: - print(f"load StableDiffusion checkpoint: {name_or_path}") + logger.info(f"load StableDiffusion checkpoint: {name_or_path}") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint( args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2 ) else: # Diffusers model is loaded to CPU - print(f"load Diffusers pretrained models: {name_or_path}") + logger.info(f"load Diffusers pretrained models: {name_or_path}") try: pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None) except EnvironmentError as ex: - print( + logger.error( f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" ) raise ex @@ -3866,7 +4359,7 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une # Diffusers U-Net to original U-Net # TODO *.ckpt/*.safetensorsのv2と同じ形式にここで変換すると良さそう - # print(f"unet config: {unet.config}") + # logger.info(f"unet config: {unet.config}") original_unet = UNet2DConditionModel( unet.config.sample_size, unet.config.attention_head_dim, @@ -3876,32 +4369,20 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une ) original_unet.load_state_dict(unet.state_dict()) unet = original_unet - print("U-Net converted to original U-Net") + logger.info("U-Net converted to original U-Net") # VAEを読み込む if args.vae is not None: vae = model_util.load_vae(args.vae, weight_dtype) - print("additional VAE loaded") + logger.info("additional VAE loaded") return text_encoder, vae, unet, load_stable_diffusion_format -# TODO remove this function in the future -def transform_if_model_is_DDP(text_encoder, unet, network=None): - # Transform text_encoder, unet and network from DistributedDataParallel - return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None) - - -def transform_models_if_DDP(models): - # Transform text_encoder, unet and network from DistributedDataParallel - return [model.module if type(model) == DDP else model for model in models if model is not None] - -
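# Condensed sketch of the per-process loading pattern in load_target_model
# below (assuming an accelerate Accelerator): ranks take turns reading the
# checkpoint, with a barrier after each turn so only one rank loads at a time.
def load_sequentially_sketch(accelerator, load_fn):
    result = None
    for rank in range(accelerator.state.num_processes):
        if rank == accelerator.state.local_process_index:
            result = load_fn()  # only this rank touches the checkpoint now
        accelerator.wait_for_everyone()  # barrier before the next rank starts
    return result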
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False): - # load models for each process for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: - print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model( args, @@ -3909,19 +4390,14 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio accelerator.device if args.lowram else "cpu", unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2, ) - # work on low-ram device if args.lowram: text_encoder.to(accelerator.device) unet.to(accelerator.device) vae.to(accelerator.device) - gc.collect() - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() - - text_encoder, unet = transform_if_model_is_DDP(text_encoder, unet) - return text_encoder, vae, unet, load_stable_diffusion_format @@ -3970,7 +4446,9 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod # v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS> for i in range(1, args.max_token_length, tokenizer.model_max_length): - states_list.append(encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2]) # <BOS>の後から<EOS>の前まで + states_list.append( + encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2] + ) # <BOS>の後から<EOS>の前まで states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> encoder_hidden_states = torch.cat(states_list, dim=1) @@ -4033,6 +4511,7 @@ def get_hidden_states_sdxl( text_encoder1: CLIPTextModel, text_encoder2: CLIPTextModelWithProjection, weight_dtype: Optional[str] = None, + accelerator: Optional[Accelerator] = None, ): # input_ids: b,n,77 -> b*n, 77 b_size = input_ids1.size()[0] @@ -4048,7 +4527,8 @@ def get_hidden_states_sdxl( hidden_states2 = enc_out["hidden_states"][-2] # penultimate layer # pool2 = enc_out["text_embeds"] - pool2 = pool_workaround(text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id) + unwrapped_text_encoder2 = text_encoder2 if accelerator is None else accelerator.unwrap_model(text_encoder2) + pool2 = pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id) # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280 n_size = 1 if max_token_length is None else max_token_length // 75 @@ -4210,7 +4690,8 @@ def save_sd_model_on_epoch_end_or_stepwise_common( ckpt_name = get_step_ckpt_name(args, ext, global_step) ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"\nsaving checkpoint: {ckpt_file}") + logger.info("") + logger.info(f"saving checkpoint: {ckpt_file}") sd_saver(ckpt_file, epoch_no, global_step) if args.huggingface_repo_id is not None: @@ -4225,7 +4706,7 @@ def save_sd_model_on_epoch_end_or_stepwise_common( remove_ckpt_file = os.path.join(args.output_dir, remove_ckpt_name) if os.path.exists(remove_ckpt_file): - print(f"removing old checkpoint: {remove_ckpt_file}") + logger.info(f"removing old checkpoint: {remove_ckpt_file}") os.remove(remove_ckpt_file) else: @@ -4234,7 +4715,8 @@ def save_sd_model_on_epoch_end_or_stepwise_common( else: out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step)) - print(f"\nsaving model: {out_dir}") + logger.info("") + logger.info(f"saving model: {out_dir}") diffusers_saver(out_dir) if args.huggingface_repo_id is not None: @@ -4248,7 +4730,7 @@ def save_sd_model_on_epoch_end_or_stepwise_common( remove_out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, remove_no)) if os.path.exists(remove_out_dir): - print(f"removing old model: {remove_out_dir}") + logger.info(f"removing old model: {remove_out_dir}") shutil.rmtree(remove_out_dir) if args.save_state: @@ -4261,13 +4743,14 @@ def 
save_sd_model_on_epoch_end_or_stepwise_common( def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no): model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME) - print(f"\nsaving state at epoch {epoch_no}") + logger.info("") + logger.info(f"saving state at epoch {epoch_no}") os.makedirs(args.output_dir, exist_ok=True) state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: - print("uploading state to huggingface.") + logger.info("uploading state to huggingface.") huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no)) last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs @@ -4275,20 +4758,21 @@ def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, ep remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) if os.path.exists(state_dir_old): - print(f"removing old state: {state_dir_old}") + logger.info(f"removing old state: {state_dir_old}") shutil.rmtree(state_dir_old) def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no): model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME) - print(f"\nsaving state at step {step_no}") + logger.info("") + logger.info(f"saving state at step {step_no}") os.makedirs(args.output_dir, exist_ok=True) state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: - print("uploading state to huggingface.") + logger.info("uploading state to huggingface.") huggingface_util.upload(args, state_dir, "/" + STEP_STATE_NAME.format(model_name, step_no)) last_n_steps = args.save_last_n_steps_state if args.save_last_n_steps_state else args.save_last_n_steps @@ -4300,21 +4784,22 @@ def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_n if remove_step_no > 0: state_dir_old = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, remove_step_no)) if os.path.exists(state_dir_old): - print(f"removing old state: {state_dir_old}") + logger.info(f"removing old state: {state_dir_old}") shutil.rmtree(state_dir_old) def save_state_on_train_end(args: argparse.Namespace, accelerator): model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME) - print("\nsaving last state.") + logger.info("") + logger.info("saving last state.") os.makedirs(args.output_dir, exist_ok=True) state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: - print("uploading last state to huggingface.") + logger.info("uploading last state to huggingface.") huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name)) @@ -4363,7 +4848,7 @@ def save_sd_model_on_train_end_common( ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt") ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") + logger.info(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") sd_saver(ckpt_file, epoch, global_step) if args.huggingface_repo_id is not None: @@ -4372,18 +4857,54 @@ def save_sd_model_on_train_end_common( out_dir = os.path.join(args.output_dir, model_name) 
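The retention logic in `save_and_remove_state_on_epoch_end` / `save_and_remove_state_stepwise` above keeps only the last N saved states by computing which earlier save has fallen out of the window. A minimal sketch of that arithmetic, with illustrative values (not taken from this diff):

```python
# Sketch of the "keep last N states" window used above; values are illustrative.
save_every_n_epochs = 2
save_last_n_epochs_state = 3  # keep the three most recent state dirs

for epoch_no in range(save_every_n_epochs, 21, save_every_n_epochs):
    # the state saved save_last_n_epochs_state saves ago falls out of the window
    remove_epoch_no = epoch_no - save_every_n_epochs * save_last_n_epochs_state
    if remove_epoch_no > 0:
        print(f"epoch {epoch_no:2d}: save state, remove state of epoch {remove_epoch_no}")
    else:
        print(f"epoch {epoch_no:2d}: save state, nothing old enough to remove")
```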
os.makedirs(out_dir, exist_ok=True) - print(f"save trained model as Diffusers to {out_dir}") + logger.info(f"save trained model as Diffusers to {out_dir}") diffusers_saver(out_dir) if args.huggingface_repo_id is not None: huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) +def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): + + # TODO: if a huber loss is selected, it will use constant timesteps for each batch. + # In the future there may be a smarter way + + if args.loss_type == "huber" or args.loss_type == "smooth_l1": + timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu") + timestep = timesteps.item() + + if args.huber_schedule == "exponential": + alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps + huber_c = math.exp(-alpha * timestep) + elif args.huber_schedule == "snr": + alphas_cumprod = noise_scheduler.alphas_cumprod[timestep] + sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 + huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c + elif args.huber_schedule == "constant": + huber_c = args.huber_c + else: + raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") + + timesteps = timesteps.repeat(b_size).to(device) + elif args.loss_type == "l2": + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) + huber_c = 1 # may be anything, as it's not used + else: + raise NotImplementedError(f"Unknown loss type {args.loss_type}") + timesteps = timesteps.long() + + return timesteps, huber_c + + def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: - noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) + if args.noise_offset_random_strength: + noise_offset = torch.rand(1, device=latents.device) * args.noise_offset + else: + noise_offset = args.noise_offset + noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale) if args.multires_noise_iterations: noise = custom_train_functions.pyramid_noise_like( noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount @@ -4394,17 +4915,44 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): min_timestep = 0 if args.min_timestep is None else args.min_timestep max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device) - timesteps = timesteps.long() + timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.ip_noise_gamma: - noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), timesteps) + if args.ip_noise_gamma_random_strength: + strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma + else: + strength = args.ip_noise_gamma + noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps) else: noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - return noise, noisy_latents, timesteps + return noise,
noisy_latents, timesteps, huber_c + + +# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already +def conditional_loss( + model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1 +): + + if loss_type == "l2": + loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) + elif loss_type == "huber": + loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) + if reduction == "mean": + loss = torch.mean(loss) + elif reduction == "sum": + loss = torch.sum(loss) + elif loss_type == "smooth_l1": + loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) + if reduction == "mean": + loss = torch.mean(loss) + elif reduction == "sum": + loss = torch.sum(loss) + else: + raise NotImplementedError(f"Unsupported Loss Type {loss_type}") + return loss def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True): @@ -4437,13 +4985,119 @@ SCHEDULER_TIMESTEPS = 1000 SCHEDLER_SCHEDULE = "scaled_linear" +def get_my_scheduler( + *, + sample_sampler: str, + v_parameterization: bool, +): + sched_init_args = {} + if sample_sampler == "ddim": + scheduler_cls = DDIMScheduler + elif sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + elif sample_sampler == "pndm": + scheduler_cls = PNDMScheduler + elif sample_sampler == "lms" or sample_sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + elif sample_sampler == "euler" or sample_sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + elif sample_sampler == "euler_a" or sample_sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteScheduler + elif sample_sampler == "dpmsolver" or sample_sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = sample_sampler + elif sample_sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + elif sample_sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + elif sample_sampler == "dpm_2" or sample_sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + else: + scheduler_cls = DDIMScheduler + + if v_parameterization: + sched_init_args["prediction_type"] = "v_prediction" + + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) + + # clip_sample=Trueにする + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + # logger.info("set clip_sample to True") + scheduler.config.clip_sample = True + + return scheduler + + def sample_images(*args, **kwargs): return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs) +def line_to_prompt_dict(line: str) -> dict: + # subset of gen_img_diffusers + prompt_args = line.split(" --") + prompt_dict = {} + prompt_dict["prompt"] = prompt_args[0] + + for parg in prompt_args: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict["width"] = int(m.group(1)) + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict["height"] = int(m.group(1)) + continue + + m = re.match(r"d (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict["seed"] = int(m.group(1)) + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps 
+ prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1)))) + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + prompt_dict["scale"] = float(m.group(1)) + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + prompt_dict["negative_prompt"] = m.group(1) + continue + + m = re.match(r"ss (.+)", parg, re.IGNORECASE) + if m: + prompt_dict["sample_sampler"] = m.group(1) + continue + + m = re.match(r"cn (.+)", parg, re.IGNORECASE) + if m: + prompt_dict["controlnet_image"] = m.group(1) + continue + + except ValueError as ex: + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(ex) + + return prompt_dict + + def sample_images_common( pipe_class, - accelerator, + accelerator: Accelerator, args: argparse.Namespace, epoch, steps, @@ -4458,29 +5112,40 @@ def sample_images_common( """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した """ - if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: - return - if args.sample_every_n_epochs is not None: - # sample_every_n_steps は無視する - if epoch is None or epoch % args.sample_every_n_epochs != 0: + + if steps == 0: + if not args.sample_at_first: return else: - if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return - print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}") + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") if not os.path.isfile(args.sample_prompts): - print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return + distributed_state = PartialState() # for multi gpu distributed inference. 
this is a singleton, so it's safe to use it here + org_vae_device = vae.device # CPUにいるはず - vae.to(device) + vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device + + # unwrap unet and text_encoder(s) + unet = accelerator.unwrap_model(unet) + if isinstance(text_encoder, (list, tuple)): + text_encoder = [accelerator.unwrap_model(te) for te in text_encoder] + else: + text_encoder = accelerator.unwrap_model(text_encoder) # read prompts - - # with open(args.sample_prompts, "rt", encoding="utf-8") as f: - # prompts = f.readlines() - if args.sample_prompts.endswith(".txt"): with open(args.sample_prompts, "r", encoding="utf-8") as f: lines = f.readlines() @@ -4493,198 +5158,75 @@ def sample_images_common( with open(args.sample_prompts, "r", encoding="utf-8") as f: prompts = json.load(f) - # schedulerを用意する - sched_init_args = {} - if args.sample_sampler == "ddim": - scheduler_cls = DDIMScheduler - elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある - scheduler_cls = DDPMScheduler - elif args.sample_sampler == "pndm": - scheduler_cls = PNDMScheduler - elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms": - scheduler_cls = LMSDiscreteScheduler - elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler": - scheduler_cls = EulerDiscreteScheduler - elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a": - scheduler_cls = EulerAncestralDiscreteScheduler - elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++": - scheduler_cls = DPMSolverMultistepScheduler - sched_init_args["algorithm_type"] = args.sample_sampler - elif args.sample_sampler == "dpmsingle": - scheduler_cls = DPMSolverSinglestepScheduler - elif args.sample_sampler == "heun": - scheduler_cls = HeunDiscreteScheduler - elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2": - scheduler_cls = KDPM2DiscreteScheduler - elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a": - scheduler_cls = KDPM2AncestralDiscreteScheduler - else: - scheduler_cls = DDIMScheduler - - if args.v_parameterization: - sched_init_args["prediction_type"] = "v_prediction" - - scheduler = scheduler_cls( - num_train_timesteps=SCHEDULER_TIMESTEPS, - beta_start=SCHEDULER_LINEAR_START, - beta_end=SCHEDULER_LINEAR_END, - beta_schedule=SCHEDLER_SCHEDULE, - **sched_init_args, + # schedulers: dict = {} cannot find where this is used + default_scheduler = get_my_scheduler( + sample_sampler=args.sample_sampler, + v_parameterization=args.v_parameterization, ) - # clip_sample=Trueにする - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # print("set clip_sample to True") - scheduler.config.clip_sample = True - pipeline = pipe_class( text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer, - scheduler=scheduler, + scheduler=default_scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False, clip_skip=args.clip_skip, ) - pipeline.to(device) - + pipeline.to(distributed_state.device) save_dir = args.output_dir + "/sample" os.makedirs(save_dir, exist_ok=True) + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. 
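For reference, `line_to_prompt_dict` above maps one prompt line in the README's `--`-option format to the dict consumed by this preprocessing step. A worked example of the expected mapping (hand-derived from the regexes, not captured output):

```python
line = "1girl, upper body --n low quality --w 768 --h 768 --d 1 --l 7.5 --s 28"

# line_to_prompt_dict(line) should yield approximately:
expected = {
    "prompt": "1girl, upper body",
    "negative_prompt": "low quality",
    "width": 768,
    "height": 768,
    "seed": 1,
    "scale": 7.5,
    "sample_steps": 28,
}
```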
+ prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + # save random state to restore later rng_state = torch.get_rng_state() - cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass - with torch.no_grad(): - # with accelerator.autocast(): - for i, prompt in enumerate(prompts): - if not accelerator.is_main_process: - continue - - if isinstance(prompt, dict): - negative_prompt = prompt.get("negative_prompt") - sample_steps = prompt.get("sample_steps", 30) - width = prompt.get("width", 512) - height = prompt.get("height", 512) - scale = prompt.get("scale", 7.5) - seed = prompt.get("seed") - controlnet_image = prompt.get("controlnet_image") - prompt = prompt.get("prompt") - else: - # prompt = prompt.strip() - # if len(prompt) == 0 or prompt[0] == "#": - # continue - - # subset of gen_img_diffusers - prompt_args = prompt.split(" --") - prompt = prompt_args[0] - negative_prompt = None - sample_steps = 30 - width = height = 512 - scale = 7.5 - seed = None - controlnet_image = None - for parg in prompt_args: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - continue - - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - continue - - m = re.match(r"d (\d+)", parg, re.IGNORECASE) - if m: - seed = int(m.group(1)) - continue - - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - sample_steps = max(1, min(1000, int(m.group(1)))) - continue - - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - continue - - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - continue - - m = re.match(r"cn (.+)", parg, re.IGNORECASE) - if m: # negative prompt - controlnet_image = m.group(1) - continue - - except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) - - if seed is not None: - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - - if prompt_replacement is not None: - prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) - if negative_prompt is not None: - negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) - - if controlnet_image is not None: - controlnet_image = Image.open(controlnet_image).convert("RGB") - controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) - - height = max(64, height - height % 8) # round to divisible by 8 - width = max(64, width - width % 8) # round to divisible by 8 - print(f"prompt: {prompt}") - print(f"negative_prompt: {negative_prompt}") - print(f"height: {height}") - print(f"width: {width}") - print(f"sample_steps: {sample_steps}") - print(f"scale: {scale}") - with accelerator.autocast(): - latents = pipeline( - prompt=prompt, - height=height, - width=width, - num_inference_steps=sample_steps, - guidance_scale=scale, - negative_prompt=negative_prompt, - controlnet=controlnet, - controlnet_image=controlnet_image, + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. 
+ with torch.no_grad(): + for prompt_dict in prompts: + sample_image_inference( + accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) - image = pipeline.latents_to_image(latents)[0] - - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" - seed_suffix = "" if seed is None else f"_{seed}" - img_filename = ( - f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" - ) - - image.save(os.path.join(save_dir, img_filename)) - - # wandb有効時のみログを送信 - try: - wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet + ) # clear pipeline and cache to reduce vram usage del pipeline - torch.cuda.empty_cache() + + # I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here. 
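`clean_memory_on_device`, called just below, is defined elsewhere in `library/` and is not part of this excerpt; a plausible sketch of such a helper, assuming it wraps `gc.collect()` plus a per-backend cache clear (an assumption, not the repository's actual body):

```python
import gc

import torch

def clean_memory_on_device(device: torch.device):
    # Hypothetical sketch: drop Python garbage first, then release the
    # allocator cache for whichever backend the device belongs to.
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "mps" and hasattr(torch, "mps"):
        torch.mps.empty_cache()
```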
+ # with torch.cuda.device(torch.cuda.current_device()): + # torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) torch.set_rng_state(rng_state) if cuda_rng_state is not None: @@ -4692,8 +5234,105 @@ def sample_images_common( vae.to(org_vae_device) +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + pipeline, + save_dir, + prompt_dict, + epoch, + steps, + prompt_replacement, + controlnet=None, +): + assert isinstance(prompt_dict, dict) + negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 30) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 7.5) + seed = prompt_dict.get("seed") + controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + scheduler = get_my_scheduler( + sample_sampler=sampler_name, + v_parameterization=args.v_parameterization, + ) + pipeline.scheduler = scheduler + + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + logger.info(f"prompt: {prompt}") + logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + with accelerator.autocast(): + latents = pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=sample_steps, + guidance_scale=scale, + negative_prompt=negative_prompt, + controlnet=controlnet, + controlnet_image=controlnet_image, + ) + + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() + + image = pipeline.latents_to_image(latents)[0] + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # wandb有効時のみログを送信 + try: + wandb_tracker = accelerator.get_tracker("wandb") + try: + import wandb + except ImportError: # 事前に一度確認するのでここはエラー出ないはず + raise ImportError("No wandb / wandb がインストールされていないようです") + + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + except: # wandb 無効時 + pass + + # endregion + # region 前処理用 @@ -4712,7 +5351,7 @@ class ImageLoadingDataset(torch.utils.data.Dataset): # convert to tensor 
temporarily so dataloader will accept it tensor_pil = transforms.functional.pil_to_tensor(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") return None return (tensor_pil, img_path) diff --git a/library/utils.py b/library/utils.py index 7d801a67..3037c055 100644 --- a/library/utils.py +++ b/library/utils.py @@ -1,6 +1,266 @@ +import logging +import sys import threading +import torch +from torchvision import transforms from typing import * +from diffusers import EulerAncestralDiscreteScheduler +import diffusers.schedulers.scheduling_euler_ancestral_discrete +from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput def fire_in_thread(f, *args, **kwargs): - threading.Thread(target=f, args=args, kwargs=kwargs).start() \ No newline at end of file + threading.Thread(target=f, args=args, kwargs=kwargs).start() + + +def add_logging_arguments(parser): + parser.add_argument( + "--console_log_level", + type=str, + default=None, + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO", + ) + parser.add_argument( + "--console_log_file", + type=str, + default=None, + help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する", + ) + parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力") + + +def setup_logging(args=None, log_level=None, reset=False): + if logging.root.handlers: + if reset: + # remove all handlers + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + else: + return + + # log_level can be set by the caller or by the args, the caller has priority. 
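The consumer side of this logging setup appears verbatim in the per-script hunks later in this diff (`networks/check_lora_weights.py` and friends); in isolation the pattern is:

```python
from library.utils import setup_logging

setup_logging()  # configures the root logger once; later calls are no-ops

import logging

logger = logging.getLogger(__name__)
logger.info("goes through the rich handler when rich is installed")
```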
If not set, use INFO + if log_level is None and args is not None: + log_level = args.console_log_level + if log_level is None: + log_level = "INFO" + log_level = getattr(logging, log_level) + + msg_init = None + if args is not None and args.console_log_file: + handler = logging.FileHandler(args.console_log_file, mode="w") + else: + handler = None + if not args or not args.console_log_simple: + try: + from rich.logging import RichHandler + from rich.console import Console + + handler = RichHandler(console=Console(stderr=True)) + except ImportError: + # print("rich is not installed, using basic logging") + msg_init = "rich is not installed, using basic logging" + + if handler is None: + handler = logging.StreamHandler(sys.stdout) # same as print + handler.propagate = False + + formatter = logging.Formatter( + fmt="%(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + handler.setFormatter(formatter) + logging.root.setLevel(log_level) + logging.root.addHandler(handler) + + if msg_init is not None: + logger = logging.getLogger(__name__) + logger.info(msg_init) + + + +# TODO make inf_utils.py + + +# region Gradual Latent hires fix + + +class GradualLatent: + def __init__( + self, + ratio, + start_timesteps, + every_n_steps, + ratio_step, + s_noise=1.0, + gaussian_blur_ksize=None, + gaussian_blur_sigma=0.5, + gaussian_blur_strength=0.5, + unsharp_target_x=True, + ): + self.ratio = ratio + self.start_timesteps = start_timesteps + self.every_n_steps = every_n_steps + self.ratio_step = ratio_step + self.s_noise = s_noise + self.gaussian_blur_ksize = gaussian_blur_ksize + self.gaussian_blur_sigma = gaussian_blur_sigma + self.gaussian_blur_strength = gaussian_blur_strength + self.unsharp_target_x = unsharp_target_x + + def __str__(self) -> str: + return ( + f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, " + + f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, " + + f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, " + + f"unsharp_target_x={self.unsharp_target_x})" + ) + + def apply_unsharp_mask(self, x: torch.Tensor): + if self.gaussian_blur_ksize is None: + return x + blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma) + # mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength) + mask = (x - blurred) * self.gaussian_blur_strength + sharpened = x + mask + return sharpened + + def interpolate(self, x: torch.Tensor, resized_size, unsharp=True): + org_dtype = x.dtype + if org_dtype == torch.bfloat16: + x = x.float() + + x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if unsharp and self.gaussian_blur_ksize: + x = self.apply_unsharp_mask(x) + + return x + + +class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.resized_size = None + self.gradual_latent = None + + def set_gradual_latent_params(self, size, gradual_latent: GradualLatent): + self.resized_size = size + self.gradual_latent = gradual_latent + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) ->
Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + + """ + + if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + # logger.warning( + print( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`") + + sigma_from = self.sigmas[self.step_index] + sigma_to = self.sigmas[self.step_index + 1] + sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + + # 2. 
Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + + dt = sigma_down - sigma + + device = model_output.device + if self.resized_size is None: + prev_sample = sample + derivative * dt + + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( + model_output.shape, dtype=model_output.dtype, device=device, generator=generator + ) + s_noise = 1.0 + else: + print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape) + s_noise = self.gradual_latent.s_noise + + if self.gradual_latent.unsharp_target_x: + prev_sample = sample + derivative * dt + prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size) + else: + sample = self.gradual_latent.interpolate(sample, self.resized_size) + derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False) + prev_sample = sample + derivative * dt + + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( + (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]), + dtype=model_output.dtype, + device=device, + generator=generator, + ) + + prev_sample = prev_sample + noise * sigma_up * s_noise + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + +# endregion diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index 51f581b2..794659c9 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -2,10 +2,13 @@ import argparse import os import torch from safetensors.torch import load_file - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def main(file): - print(f"loading: {file}") + logger.info(f"loading: {file}") if os.path.splitext(file)[1] == ".safetensors": sd = load_file(file) else: diff --git a/networks/control_net_lllite.py b/networks/control_net_lllite.py index 4ebfef7a..c9377bee 100644 --- a/networks/control_net_lllite.py +++ b/networks/control_net_lllite.py @@ -2,7 +2,10 @@ import os from typing import Optional, List, Type import torch from library import sdxl_original_unet - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # input_blocksに適用するかどうか / if True, input_blocks are not applied SKIP_INPUT_BLOCKS = False @@ -125,7 +128,7 @@ class LLLiteModule(torch.nn.Module): return # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance - # print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}") + # logger.info(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}") cx = self.conditioning1(cond_image) if not self.is_conv2d: # reshape / b,c,h,w -> b,h*w,c @@ -155,7 +158,7 @@ class LLLiteModule(torch.nn.Module): cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) if self.use_zeros_for_batch_uncond: cx[0::2] = 0.0 # uncond is zero - # print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}") + # logger.info(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}") # downで入力の次元数を削減し、conditioning image embeddingと結合する # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している @@ -286,7 +289,7 @@ class ControlNetLLLite(torch.nn.Module): # create module instances self.unet_modules: List[LLLiteModule] = 
create_modules(unet, target_modules, LLLiteModule) - print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.") + logger.info(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.") def forward(self, x): return x # dummy @@ -319,7 +322,7 @@ class ControlNetLLLite(torch.nn.Module): return info def apply_to(self): - print("applying LLLite for U-Net...") + logger.info("applying LLLite for U-Net...") for module in self.unet_modules: module.apply_to() self.add_module(module.lllite_name, module) @@ -374,19 +377,19 @@ if __name__ == "__main__": # sdxl_original_unet.USE_REENTRANT = False # test shape etc - print("create unet") + logger.info("create unet") unet = sdxl_original_unet.SdxlUNet2DConditionModel() unet.to("cuda").to(torch.float16) - print("create ControlNet-LLLite") + logger.info("create ControlNet-LLLite") control_net = ControlNetLLLite(unet, 32, 64) control_net.apply_to() control_net.to("cuda") - print(control_net) + logger.info(control_net) - # print number of parameters - print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad)) + # logger.info number of parameters + logger.info(f"number of parameters {sum(p.numel() for p in control_net.parameters() if p.requires_grad)}") input() @@ -398,12 +401,12 @@ if __name__ == "__main__": # # visualize # import torchviz - # print("run visualize") + # logger.info("run visualize") # controlnet.set_control(conditioning_image) # output = unet(x, t, ctx, y) - # print("make_dot") + # logger.info("make_dot") # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) - # print("render") + # logger.info("render") # image.format = "svg" # "png" # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time # input() @@ -414,12 +417,12 @@ if __name__ == "__main__": scaler = torch.cuda.amp.GradScaler(enabled=True) - print("start training") + logger.info("start training") steps = 10 sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0] for step in range(steps): - print(f"step {step}") + logger.info(f"step {step}") batch_size = 1 conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0 @@ -439,7 +442,7 @@ if __name__ == "__main__": scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) - print(sample_param) + logger.info(f"{sample_param}") # from safetensors.torch import save_file diff --git a/networks/control_net_lllite_for_train.py b/networks/control_net_lllite_for_train.py index 02688001..65b3520c 100644 --- a/networks/control_net_lllite_for_train.py +++ b/networks/control_net_lllite_for_train.py @@ -6,7 +6,10 @@ import re from typing import Optional, List, Type import torch from library import sdxl_original_unet - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # input_blocksに適用するかどうか / if True, input_blocks are not applied SKIP_INPUT_BLOCKS = False @@ -270,7 +273,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond # create module instances self.lllite_modules = apply_to_modules(self, target_modules) - print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.") + logger.info(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.") # def prepare_optimizer_params(self): def prepare_params(self): @@ -281,8 +284,8 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond train_params.append(p) else: 
non_train_params.append(p) - print(f"count of trainable parameters: {len(train_params)}") - print(f"count of non-trainable parameters: {len(non_train_params)}") + logger.info(f"count of trainable parameters: {len(train_params)}") + logger.info(f"count of non-trainable parameters: {len(non_train_params)}") for p in non_train_params: p.requires_grad_(False) @@ -388,7 +391,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond matches = pattern.findall(module_name) if matches is not None: for m in matches: - print(module_name, m) + logger.info(f"{module_name} {m}") module_name = module_name.replace(m, m.replace("_", "@")) module_name = module_name.replace("_", ".") module_name = module_name.replace("@", "_") @@ -407,7 +410,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond def replace_unet_linear_and_conv2d(): - print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net") + logger.info("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net") sdxl_original_unet.torch.nn.Linear = LLLiteLinear sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d @@ -419,10 +422,10 @@ if __name__ == "__main__": replace_unet_linear_and_conv2d() # test shape etc - print("create unet") + logger.info("create unet") unet = SdxlUNet2DConditionModelControlNetLLLite() - print("enable ControlNet-LLLite") + logger.info("enable ControlNet-LLLite") unet.apply_lllite(32, 64, None, False, 1.0) unet.to("cuda") # .to(torch.float16) @@ -439,14 +442,14 @@ if __name__ == "__main__": # unet_sd[converted_key] = model_sd[key] # info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd) - # print(info) + # logger.info(info) - # print(unet) + # logger.info(unet) - # print number of parameters + # logger.info number of parameters params = unet.prepare_params() - print("number of parameters", sum(p.numel() for p in params)) - # print("type any key to continue") + logger.info(f"number of parameters {sum(p.numel() for p in params)}") + # logger.info("type any key to continue") # input() unet.set_use_memory_efficient_attention(True, False) @@ -455,12 +458,12 @@ if __name__ == "__main__": # # visualize # import torchviz - # print("run visualize") + # logger.info("run visualize") # controlnet.set_control(conditioning_image) # output = unet(x, t, ctx, y) - # print("make_dot") + # logger.info("make_dot") # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) - # print("render") + # logger.info("render") # image.format = "svg" # "png" # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time # input() @@ -471,13 +474,13 @@ if __name__ == "__main__": scaler = torch.cuda.amp.GradScaler(enabled=True) - print("start training") + logger.info("start training") steps = 10 batch_size = 1 sample_param = [p for p in unet.named_parameters() if ".lllite_up." 
in p[0]][0] for step in range(steps): - print(f"step {step}") + logger.info(f"step {step}") conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0 x = torch.randn(batch_size, 4, 128, 128).cuda() @@ -494,9 +497,9 @@ if __name__ == "__main__": scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) - print(sample_param) + logger.info(sample_param) # from safetensors.torch import save_file - # print("save weights") + # logger.info("save weights") # unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None) diff --git a/networks/dylora.py b/networks/dylora.py index e5a55d19..637f3345 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -12,10 +12,15 @@ import math import os import random -from typing import List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +from transformers import CLIPTextModel import torch from torch import nn - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class DyLoRAModule(torch.nn.Module): """ @@ -165,7 +170,15 @@ class DyLoRAModule(torch.nn.Module): super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: AutoencoderKL, + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], + unet, + **kwargs, +): if network_dim is None: network_dim = 4 # default if network_alpha is None: @@ -182,6 +195,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un conv_alpha = 1.0 else: conv_alpha = float(conv_alpha) + if unit is not None: unit = int(unit) else: @@ -223,7 +237,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(f"{lora_name} {value.size()} {dim}") # support old LoRA without alpha for key in modules_dim.keys(): @@ -267,11 +281,11 @@ class DyLoRANetwork(torch.nn.Module): self.apply_to_conv = apply_to_conv if modules_dim is not None: - print(f"create LoRA network from weights") + logger.info("create LoRA network from weights") else: - print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}") + logger.info(f"create LoRA network. 
base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}") if self.apply_to_conv: - print(f"apply LoRA to Conv2d with kernel size (3,3).") + logger.info("apply LoRA to Conv2d with kernel size (3,3).") # create module instances def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]: @@ -306,9 +320,23 @@ class DyLoRANetwork(torch.nn.Module): lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit) loras.append(lora) return loras + + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] + + self.text_encoder_loras = [] + for i, text_encoder in enumerate(text_encoders): + if len(text_encoders) > 1: + index = i + 1 + logger.info(f"create LoRA for Text Encoder {index}") + else: + index = None + logger.info("create LoRA for Text Encoder") + + text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + self.text_encoder_loras.extend(text_encoder_loras) - self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + # self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE @@ -316,7 +344,7 @@ class DyLoRANetwork(torch.nn.Module): target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras = create_modules(True, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") def set_multiplier(self, multiplier): self.multiplier = multiplier @@ -336,12 +364,12 @@ class DyLoRANetwork(torch.nn.Module): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -359,12 +387,12 @@ class DyLoRANetwork(torch.nn.Module): apply_unet = True if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -375,7 +403,7 @@ class DyLoRANetwork(torch.nn.Module): sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) - print(f"weights are merged") + logger.info(f"weights are merged") """ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): diff --git a/networks/extract_lora_from_dylora.py b/networks/extract_lora_from_dylora.py index 0abee983..1184cd8a 100644 --- a/networks/extract_lora_from_dylora.py +++ b/networks/extract_lora_from_dylora.py @@ -10,7 +10,10 @@ from safetensors.torch import load_file, save_file, safe_open from tqdm import tqdm from library import train_util, model_util import numpy as np - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name): if model_util.is_safetensors(file_name): @@ -40,13 +43,13 
@@ def split_lora_model(lora_sd, unit): rank = value.size()[0] if rank > max_rank: max_rank = rank - print(f"Max rank: {max_rank}") + logger.info(f"Max rank: {max_rank}") rank = unit split_models = [] new_alpha = None while rank < max_rank: - print(f"Splitting rank {rank}") + logger.info(f"Splitting rank {rank}") new_sd = {} for key, value in lora_sd.items(): if "lora_down" in key: @@ -57,7 +60,7 @@ def split_lora_model(lora_sd, unit): # なぜかscaleするとおかしくなる…… # this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0] # scale = math.sqrt(this_rank / rank) # rank is > unit - # print(key, value.size(), this_rank, rank, value, scale) + # logger.info(key, value.size(), this_rank, rank, value, scale) # new_alpha = value * scale # always same # new_sd[key] = new_alpha new_sd[key] = value @@ -69,10 +72,10 @@ def split_lora_model(lora_sd, unit): def split(args): - print("loading Model...") + logger.info("loading Model...") lora_sd, metadata = load_state_dict(args.model) - print("Splitting Model...") + logger.info("Splitting Model...") original_rank, split_models = split_lora_model(lora_sd, args.unit) comment = metadata.get("ss_training_comment", "") @@ -94,7 +97,7 @@ def split(args): filename, ext = os.path.splitext(args.save_to) model_file_name = filename + f"-{new_rank:04d}{ext}" - print(f"saving model to: {model_file_name}") + logger.info(f"saving model to: {model_file_name}") save_to_file(model_file_name, state_dict, new_metadata) diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index dba7cd4e..43c1d005 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -11,10 +11,13 @@ from safetensors.torch import load_file, save_file from tqdm import tqdm from library import sai_model_spec, model_util, sdxl_model_util import lora +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) - -CLAMP_QUANTILE = 0.99 -MIN_DIFF = 1e-1 +# CLAMP_QUANTILE = 0.99 +# MIN_DIFF = 1e-1 def save_to_file(file_name, model, state_dict, dtype): @@ -29,7 +32,24 @@ def save_to_file(file_name, model, state_dict, dtype): torch.save(model, file_name) -def svd(args): +def svd( + model_org=None, + model_tuned=None, + save_to=None, + dim=4, + v2=None, + sdxl=None, + conv_dim=None, + v_parameterization=None, + device=None, + save_precision=None, + clamp_quantile=0.99, + min_diff=0.01, + no_metadata=False, + load_precision=None, + load_original_model_to=None, + load_tuned_model_to=None, +): def str_to_dtype(p): if p == "float": return torch.float @@ -39,44 +59,65 @@ def svd(args): return torch.bfloat16 return None - assert args.v2 != args.sdxl or ( - not args.v2 and not args.sdxl - ), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません" - if args.v_parameterization is None: - args.v_parameterization = args.v2 + assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません" + if v_parameterization is None: + v_parameterization = v2 - save_dtype = str_to_dtype(args.save_precision) + load_dtype = str_to_dtype(load_precision) if load_precision else None + save_dtype = str_to_dtype(save_precision) + work_device = "cpu" # load models - if not args.sdxl: - print(f"loading original SD model : {args.model_org}") - text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) + if not sdxl: + logger.info(f"loading original SD model : {model_org}") + text_encoder_o, _, unet_o = 
model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org) text_encoders_o = [text_encoder_o] - print(f"loading tuned SD model : {args.model_tuned}") - text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) + if load_dtype is not None: + text_encoder_o = text_encoder_o.to(load_dtype) + unet_o = unet_o.to(load_dtype) + + logger.info(f"loading tuned SD model : {model_tuned}") + text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned) text_encoders_t = [text_encoder_t] - model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization) + if load_dtype is not None: + text_encoder_t = text_encoder_t.to(load_dtype) + unet_t = unet_t.to(load_dtype) + + model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization) else: - print(f"loading original SDXL model : {args.model_org}") + device_org = load_original_model_to if load_original_model_to else "cpu" + device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu" + + logger.info(f"loading original SDXL model : {model_org}") text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org ) text_encoders_o = [text_encoder_o1, text_encoder_o2] - print(f"loading original SDXL model : {args.model_tuned}") + if load_dtype is not None: + text_encoder_o1 = text_encoder_o1.to(load_dtype) + text_encoder_o2 = text_encoder_o2.to(load_dtype) + unet_o = unet_o.to(load_dtype) + + logger.info(f"loading original SDXL model : {model_tuned}") text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned ) text_encoders_t = [text_encoder_t1, text_encoder_t2] + if load_dtype is not None: + text_encoder_t1 = text_encoder_t1.to(load_dtype) + text_encoder_t2 = text_encoder_t2.to(load_dtype) + unet_t = unet_t.to(load_dtype) + model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 # create LoRA network to extract weights: Use dim (rank) as alpha - if args.conv_dim is None: + if conv_dim is None: kwargs = {} else: - kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim} + kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim} - lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs) - lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs) + lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs) + lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs) assert len(lora_network_o.text_encoder_loras) == len( lora_network_t.text_encoder_loras ), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) " @@ -88,50 +129,66 @@ def svd(args): lora_name = lora_o.lora_name module_o = lora_o.org_module module_t = lora_t.org_module - diff = module_t.weight - module_o.weight + diff = module_t.weight.to(work_device) - module_o.weight.to(work_device) + + # clear weight to save memory + module_o.weight = None + module_t.weight = None # Text Encoder might be same - if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF: + if not 
text_encoder_different and torch.max(torch.abs(diff)) > min_diff: text_encoder_different = True - print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}") + logger.info(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}") - diff = diff.float() diffs[lora_name] = diff + # clear target Text Encoder to save memory + for text_encoder in text_encoders_t: + del text_encoder + if not text_encoder_different: - print("Text encoder is same. Extract U-Net only.") + logger.warning("Text encoder is same. Extract U-Net only.") lora_network_o.text_encoder_loras = [] - diffs = {} + diffs = {} # clear diffs for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)): lora_name = lora_o.lora_name module_o = lora_o.org_module module_t = lora_t.org_module - diff = module_t.weight - module_o.weight - diff = diff.float() + diff = module_t.weight.to(work_device) - module_o.weight.to(work_device) - if args.device: - diff = diff.to(args.device) + # clear weight to save memory + module_o.weight = None + module_t.weight = None diffs[lora_name] = diff + # clear LoRA network, target U-Net to save memory + del lora_network_o + del lora_network_t + del unet_t + # make LoRA with svd - print("calculating by svd") + logger.info("calculating by svd") lora_weights = {} with torch.no_grad(): for lora_name, mat in tqdm(list(diffs.items())): - # if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3 + if device: + mat = mat.to(device) + mat = mat.to(torch.float) # calc by float + + # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3 conv2d = len(mat.size()) == 4 kernel_size = None if not conv2d else mat.size()[2:4] conv2d_3x3 = conv2d and kernel_size != (1, 1) - rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim + rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim out_dim, in_dim = mat.size()[0:2] - if args.device: - mat = mat.to(args.device) + if device: + mat = mat.to(device) - # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) + # logger.info(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim if conv2d: @@ -149,7 +206,7 @@ def svd(args): Vh = Vh[:rank, :] dist = torch.cat([U.flatten(), Vh.flatten()]) - hi_val = torch.quantile(dist, CLAMP_QUANTILE) + hi_val = torch.quantile(dist, clamp_quantile) low_val = -hi_val U = U.clamp(low_val, hi_val) @@ -159,8 +216,8 @@ def svd(args): U = U.reshape(out_dim, rank, 1, 1) Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1]) - U = U.to("cpu").contiguous() - Vh = Vh.to("cpu").contiguous() + U = U.to(work_device, dtype=save_dtype).contiguous() + Vh = Vh.to(work_device, dtype=save_dtype).contiguous() lora_weights[lora_name] = (U, Vh) @@ -176,36 +233,34 @@ def svd(args): lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict info = lora_network_save.load_state_dict(lora_sd) - print(f"Loading extracted LoRA weights: {info}") + logger.info(f"Loading extracted LoRA weights: {info}") - dir_name = os.path.dirname(args.save_to) + dir_name = os.path.dirname(save_to) if dir_name and not os.path.exists(dir_name): os.makedirs(dir_name, exist_ok=True) # minimum metadata net_kwargs = {} - if args.conv_dim is not None: - net_kwargs["conv_dim"] = args.conv_dim - net_kwargs["conv_alpha"] = args.conv_dim + if conv_dim is not None: + net_kwargs["conv_dim"] = str(conv_dim) + net_kwargs["conv_alpha"] =
str(float(conv_dim)) metadata = { - "ss_v2": str(args.v2), + "ss_v2": str(v2), "ss_base_model_version": model_version, "ss_network_module": "networks.lora", - "ss_network_dim": str(args.dim), - "ss_network_alpha": str(args.dim), + "ss_network_dim": str(dim), + "ss_network_alpha": str(float(dim)), "ss_network_args": json.dumps(net_kwargs), } - if not args.no_metadata: - title = os.path.splitext(os.path.basename(args.save_to))[0] - sai_metadata = sai_model_spec.build_metadata( - None, args.v2, args.v_parameterization, args.sdxl, True, False, time.time(), title=title - ) + if not no_metadata: + title = os.path.splitext(os.path.basename(save_to))[0] + sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title) metadata.update(sai_metadata) - lora_network_save.save_weights(args.save_to, save_dtype, metadata) - print(f"LoRA weights are saved to: {args.save_to}") + lora_network_save.save_weights(save_to, save_dtype, metadata) + logger.info(f"LoRA weights are saved to: {save_to}") def setup_parser() -> argparse.ArgumentParser: @@ -213,13 +268,20 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") parser.add_argument( "--v_parameterization", - type=bool, + action="store_true", default=None, help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)", ) parser.add_argument( "--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む" ) + parser.add_argument( + "--load_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる" + ) parser.add_argument( "--save_precision", type=str, @@ -231,16 +293,22 @@ def setup_parser() -> argparse.ArgumentParser: "--model_org", type=str, default=None, + required=True, help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors", ) parser.add_argument( "--model_tuned", type=str, default=None, + required=True, help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors", ) parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + "--save_to", + type=str, + default=None, + required=True, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", ) parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)") parser.add_argument( @@ -250,12 +318,37 @@ def setup_parser() -> argparse.ArgumentParser: help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)", ) parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") + parser.add_argument( + "--clamp_quantile", + type=float, + default=0.99, + help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99", + ) + parser.add_argument( + "--min_diff", + type=float, + default=0.01, + help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). 
Default = 0.01 /" + + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01", + ) parser.add_argument( "--no_metadata", action="store_true", help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", ) + parser.add_argument( + "--load_original_model_to", + type=str, + default=None, + help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効", + ) + parser.add_argument( + "--load_tuned_model_to", + type=str, + default=None, + help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効", + ) return parser @@ -264,4 +357,4 @@ if __name__ == "__main__": parser = setup_parser() args = parser.parse_args() - svd(args) + svd(**vars(args)) diff --git a/networks/lora.py b/networks/lora.py index 0c75cd42..d1208040 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -11,7 +11,12 @@ from transformers import CLIPTextModel import numpy as np import torch import re +from library.utils import setup_logging +setup_logging() +import logging + +logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -46,7 +51,7 @@ class LoRAModule(torch.nn.Module): # if limit_rank: # self.lora_dim = min(lora_dim, in_dim, out_dim) # if self.lora_dim != lora_dim: - # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") # else: self.lora_dim = lora_dim @@ -177,7 +182,7 @@ class LoRAInfModule(LoRAModule): else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + self.multiplier * conved * self.scale # set weight to org_module @@ -216,7 +221,7 @@ class LoRAInfModule(LoRAModule): self.region_mask = None def default_forward(self, x): - # print("default_forward", self.lora_name, x.size()) + # logger.info(f"default_forward {self.lora_name} {x.size()}") return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale def forward(self, x): @@ -242,13 +247,13 @@ class LoRAInfModule(LoRAModule): area = x.size()[1] mask = self.network.mask_dic.get(area, None) - if mask is None: - # raise ValueError(f"mask is None for resolution {area}") + if mask is None or len(x.size()) == 2: # emb_layers in SDXL doesn't have mask - # print(f"mask is None for resolution {area}, {x.size()}") + # if "emb" not in self.lora_name: + # print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}") mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1) return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts - if len(x.size()) != 4: + if len(x.size()) == 3: mask = torch.reshape(mask, (1, -1, 1)) return mask @@ -263,6 +268,8 @@ class LoRAInfModule(LoRAModule): lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale mask = self.get_mask_for_x(lx) # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) + # if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked) + # mask = mask.squeeze(-1) lx = lx * mask x = self.org_forward(x) @@ -291,7 
+298,7 @@ class LoRAInfModule(LoRAModule): if has_real_uncond: query[-self.network.batch_size :] = x[-self.network.batch_size :] - # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) + # logger.info(f"postp_to_q {self.lora_name} {x.size()} {query.size()} {self.network.num_sub_prompts}") return query def sub_prompt_forward(self, x): @@ -306,7 +313,7 @@ class LoRAInfModule(LoRAModule): lx = x[emb_idx :: self.network.num_sub_prompts] lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale - # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) + # logger.info(f"sub_prompt_forward {self.lora_name} {x.size()} {lx.size()} {emb_idx}") x = self.org_forward(x) x[emb_idx :: self.network.num_sub_prompts] += lx @@ -314,7 +321,7 @@ class LoRAInfModule(LoRAModule): return x def to_out_forward(self, x): - # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) + # logger.info(f"to_out_forward {self.lora_name} {x.size()} {self.network.is_last_network}") if self.network.is_last_network: masks = [None] * self.network.num_sub_prompts @@ -332,7 +339,7 @@ class LoRAInfModule(LoRAModule): ) self.network.shared[self.lora_name] = (lx, masks) - # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info(f"to_out_forward {lx.size()} {lx1.size()} {self.network.sub_prompt_index} {self.network.num_sub_prompts}") lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) @@ -351,7 +358,7 @@ class LoRAInfModule(LoRAModule): if has_real_uncond: out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond - # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info(f"to_out_forward {self.lora_name} {self.network.sub_prompt_index} {self.network.num_sub_prompts}") # if num_sub_prompts > num of LoRAs, fill with zero for i in range(len(masks)): if masks[i] is None: @@ -374,7 +381,7 @@ class LoRAInfModule(LoRAModule): x1 = x1 + lx1 out[self.network.batch_size + i] = x1 - # print("to_out_forward", x.size(), out.size(), has_real_uncond) + # logger.info(f"to_out_forward {x.size()} {out.size()} {has_real_uncond}") return out @@ -511,7 +518,9 @@ def get_block_dims_and_alphas( len(block_dims) == num_total_blocks ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" else: - print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") + logger.warning( + f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります" + ) block_dims = [network_dim] * num_total_blocks if block_alphas is not None: @@ -520,7 +529,7 @@ def get_block_dims_and_alphas( len(block_alphas) == num_total_blocks ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" else: - print( + logger.warning( f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" ) block_alphas = [network_alpha] * num_total_blocks @@ -540,13 +549,13 @@ def get_block_dims_and_alphas( else: if conv_alpha is None: conv_alpha = 1.0 - print( + logger.warning( f"conv_block_alphas is not specified. 
all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" ) conv_block_alphas = [conv_alpha] * num_total_blocks else: if conv_dim is not None: - print( + logger.warning( f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" ) conv_block_dims = [conv_dim] * num_total_blocks @@ -586,7 +595,7 @@ def get_block_lr_weight( elif name == "zeros": return [0.0 + base_lr] * max_len else: - print( + logger.error( "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" % (name) ) @@ -598,14 +607,14 @@ def get_block_lr_weight( up_lr_weight = get_list(up_lr_weight) if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): - print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) - print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) + logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) + logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) up_lr_weight = up_lr_weight[:max_len] down_lr_weight = down_lr_weight[:max_len] if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): - print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) - print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) + logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) + logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) if down_lr_weight != None and len(down_lr_weight) < max_len: down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) @@ -613,24 +622,24 @@ def get_block_lr_weight( up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): - print("apply block learning rate / 階層別学習率を適用します。") + logger.info("apply block learning rate / 階層別学習率を適用します。") if down_lr_weight != None: down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] - print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight) + logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}") else: - print("down_lr_weight: all 1.0, すべて1.0") + logger.info("down_lr_weight: all 1.0, すべて1.0") if mid_lr_weight != None: mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 - print("mid_lr_weight:", mid_lr_weight) + logger.info(f"mid_lr_weight: {mid_lr_weight}") else: - print("mid_lr_weight: 1.0") + logger.info("mid_lr_weight: 1.0") if up_lr_weight != None: up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] - print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight) + logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}") else: - print("up_lr_weight: all 1.0, すべて1.0") + logger.info("up_lr_weight: all 1.0, すべて1.0") return down_lr_weight, mid_lr_weight, up_lr_weight @@ -711,7 +720,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(lora_name, value.size(), dim) # support old LoRA without alpha 
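The `create_network_from_weights` context above recovers each module's dim (rank) and alpha by scanning the weight file's keys. A standalone sketch of that recovery (a hypothetical helper, not the repository's function; key names follow the `lora_*.alpha` / `lora_*.lora_down.weight` convention visible in this hunk):

```python
import torch

def recover_dims_and_alphas(weights_sd):
    """Scan a LoRA state dict and recover per-module dim (rank) and alpha.

    dim is the first dimension of lora_down.weight; alpha defaults to dim
    for old LoRA files that were saved without alpha keys.
    """
    modules_dim, modules_alpha = {}, {}
    for key, value in weights_sd.items():
        lora_name = key.split(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = value.item() if torch.is_tensor(value) else value
        elif "lora_down" in key:
            modules_dim[lora_name] = value.size()[0]  # rank = rows of the down weight

    # support old LoRA without alpha
    for lora_name, dim in modules_dim.items():
        modules_alpha.setdefault(lora_name, dim)
    return modules_dim, modules_alpha

# hypothetical example: one linear LoRA module of rank 4 with alpha 1.0
sd = {
    "lora_unet_mid_block.lora_down.weight": torch.zeros(4, 320),
    "lora_unet_mid_block.lora_up.weight": torch.zeros(320, 4),
    "lora_unet_mid_block.alpha": torch.tensor(1.0),
}
print(recover_dims_and_alphas(sd))  # ({'lora_unet_mid_block': 4}, {'lora_unet_mid_block': 1.0})
```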
for key in modules_dim.keys(): @@ -786,20 +795,26 @@ class LoRANetwork(torch.nn.Module): self.module_dropout = module_dropout if modules_dim is not None: - print(f"create LoRA network from weights") + logger.info(f"create LoRA network from weights") elif block_dims is not None: - print(f"create LoRA network from block_dims") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") - print(f"block_dims: {block_dims}") - print(f"block_alphas: {block_alphas}") + logger.info(f"create LoRA network from block_dims") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + logger.info(f"block_dims: {block_dims}") + logger.info(f"block_alphas: {block_alphas}") if conv_block_dims is not None: - print(f"conv_block_dims: {conv_block_dims}") - print(f"conv_block_alphas: {conv_block_alphas}") + logger.info(f"conv_block_dims: {conv_block_dims}") + logger.info(f"conv_block_alphas: {conv_block_alphas}") else: - print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) if self.conv_lora_dim is not None: - print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + logger.info( + f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + ) # create module instances def create_modules( @@ -884,15 +899,15 @@ class LoRANetwork(torch.nn.Module): for i, text_encoder in enumerate(text_encoders): if len(text_encoders) > 1: index = i + 1 - print(f"create LoRA for Text Encoder {index}:") + logger.info(f"create LoRA for Text Encoder {index}:") else: index = None - print(f"create LoRA for Text Encoder:") + logger.info(f"create LoRA for Text Encoder:") text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE @@ -900,15 +915,15 @@ class LoRANetwork(torch.nn.Module): target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: - print( + logger.warning( f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) for name in skipped: - print(f"\t{name}") + logger.info(f"\t{name}") self.up_lr_weight: List[float] = None self.down_lr_weight: List[float] = None @@ -926,6 +941,10 @@ class LoRANetwork(torch.nn.Module): for lora in self.text_encoder_loras + self.unet_loras: lora.multiplier = self.multiplier 
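`set_multiplier` above, like the `set_enabled` toggle added in the next hunk, adjusts every LoRA module's forward pass at once. What they gate is the standard LoRA residual; a minimal self-contained sketch (a hypothetical module, not the repository's `LoRAModule`):

```python
import torch
import torch.nn as nn

class TinyLoRA(nn.Module):
    """Minimal LoRA wrapper: y = W x + multiplier * scale * up(down(x))."""

    def __init__(self, org_module: nn.Linear, rank: int = 4, alpha: float = 1.0):
        super().__init__()
        self.org_module = org_module
        self.lora_down = nn.Linear(org_module.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, org_module.out_features, bias=False)
        nn.init.zeros_(self.lora_up.weight)  # zero-init up weight: LoRA starts as a no-op
        self.scale = alpha / rank            # alpha/dim scaling used throughout these scripts
        self.multiplier = 1.0                # what set_multiplier changes
        self.enabled = True                  # what set_enabled changes

    def forward(self, x):
        if not self.enabled:
            return self.org_module(x)
        return self.org_module(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale

lora = TinyLoRA(nn.Linear(8, 8))
x = torch.randn(2, 8)
assert torch.allclose(lora(x), lora.org_module(x))  # zero-init up => identical output
```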
+ def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file @@ -939,12 +958,12 @@ class LoRANetwork(torch.nn.Module): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -966,12 +985,12 @@ class LoRANetwork(torch.nn.Module): apply_unet = True if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -982,7 +1001,7 @@ class LoRANetwork(torch.nn.Module): sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) - print(f"weights are merged") + logger.info(f"weights are merged") # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない def set_block_lr_weight( @@ -1113,7 +1132,7 @@ class LoRANetwork(torch.nn.Module): for lora in self.text_encoder_loras + self.unet_loras: lora.set_network(self) - def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared): + def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None): self.batch_size = batch_size self.num_sub_prompts = num_sub_prompts self.current_size = (height, width) @@ -1128,7 +1147,7 @@ class LoRANetwork(torch.nn.Module): device = ref_weight.device def resize_add(mh, mw): - # print(mh, mw, mh * mw) + # logger.info(mh, mw, mh * mw) m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 m = m.to(device, dtype=dtype) mask_dic[mh * mw] = m @@ -1139,6 +1158,13 @@ class LoRANetwork(torch.nn.Module): resize_add(h, w) if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2 resize_add(h + h % 2, w + w % 2) + + # deep shrink + if ds_ratio is not None: + hd = int(h * ds_ratio) + wd = int(w * ds_ratio) + resize_add(hd, wd) + h = (h + 1) // 2 w = (w + 1) // 2 diff --git a/networks/lora_diffusers.py b/networks/lora_diffusers.py index 47d75ac4..b99b0244 100644 --- a/networks/lora_diffusers.py +++ b/networks/lora_diffusers.py @@ -9,8 +9,15 @@ from diffusers import UNet2DConditionModel import numpy as np from tqdm import tqdm from transformers import CLIPTextModel -import torch +import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def make_unet_conversion_map() -> Dict[str, str]: unet_conversion_map_layer = [] @@ -248,7 +255,7 @@ def create_network_from_weights( elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(f"{lora_name} {value.size()} {dim}") # support old LoRA without alpha for key in modules_dim.keys(): @@ -291,12 +298,12 @@ class LoRANetwork(torch.nn.Module): super().__init__() self.multiplier = multiplier - print(f"create LoRA network from weights") + logger.info("create LoRA network from weights") # convert SDXL Stability AI's U-Net modules to Diffusers converted = 
self.convert_unet_modules(modules_dim, modules_alpha) if converted: - print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)") + logger.info(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)") # create module instances def create_modules( @@ -331,7 +338,7 @@ class LoRANetwork(torch.nn.Module): lora_name = lora_name.replace(".", "_") if lora_name not in modules_dim: - # print(f"skipped {lora_name} (not found in modules_dim)") + # logger.info(f"skipped {lora_name} (not found in modules_dim)") skipped.append(lora_name) continue @@ -362,18 +369,18 @@ class LoRANetwork(torch.nn.Module): text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") if len(skipped_te) > 0: - print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.") + logger.warning(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.") # extend U-Net target modules to include Conv2d 3x3 target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras: List[LoRAModule] self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") if len(skipped_un) > 0: - print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.") + logger.warning(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.") # assertion names = set() @@ -420,11 +427,11 @@ class LoRANetwork(torch.nn.Module): def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") for lora in self.text_encoder_loras: lora.apply_to(multiplier) if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") for lora in self.unet_loras: lora.apply_to(multiplier) @@ -433,16 +440,16 @@ class LoRANetwork(torch.nn.Module): lora.unapply_to() def merge_to(self, multiplier=1.0): - print("merge LoRA weights to original weights") + logger.info("merge LoRA weights to original weights") for lora in tqdm(self.text_encoder_loras + self.unet_loras): lora.merge_to(multiplier) - print(f"weights are merged") + logger.info(f"weights are merged") def restore_from(self, multiplier=1.0): - print("restore LoRA weights from original weights") + logger.info("restore LoRA weights from original weights") for lora in tqdm(self.text_encoder_loras + self.unet_loras): lora.restore_from(multiplier) - print(f"weights are restored") + logger.info(f"weights are restored") def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): # convert SDXL Stability AI's state dict to Diffusers' based state dict @@ -463,7 +470,7 @@ class LoRANetwork(torch.nn.Module): my_state_dict = self.state_dict() for key in state_dict.keys(): if state_dict[key].size() != my_state_dict[key].size(): - # print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}") + # logger.info(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}") state_dict[key] = 
state_dict[key].view(my_state_dict[key].size()) return super().load_state_dict(state_dict, strict) @@ -476,7 +483,7 @@ if __name__ == "__main__": from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline import torch - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = get_preferred_device() parser = argparse.ArgumentParser() parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface") @@ -490,7 +497,7 @@ if __name__ == "__main__": image_prefix = args.model_id.replace("/", "_") + "_" # load Diffusers model - print(f"load model from {args.model_id}") + logger.info(f"load model from {args.model_id}") pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline] if args.sdxl: # use_safetensors=True does not work with 0.18.2 @@ -503,7 +510,7 @@ if __name__ == "__main__": text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder] # load LoRA weights - print(f"load LoRA weights from {args.lora_weights}") + logger.info(f"load LoRA weights from {args.lora_weights}") if os.path.splitext(args.lora_weights)[1] == ".safetensors": from safetensors.torch import load_file @@ -512,10 +519,10 @@ if __name__ == "__main__": lora_sd = torch.load(args.lora_weights) # create by LoRA weights and load weights - print(f"create LoRA network") + logger.info(f"create LoRA network") lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0) - print(f"load LoRA network weights") + logger.info(f"load LoRA network weights") lora_network.load_state_dict(lora_sd) lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this @@ -544,34 +551,34 @@ if __name__ == "__main__": random.seed(seed) # create image with original weights - print(f"create image with original weights") + logger.info(f"create image with original weights") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "original.png") # apply LoRA network to the model: slower than merge_to, but can be reverted easily - print(f"apply LoRA network to the model") + logger.info(f"apply LoRA network to the model") lora_network.apply_to(multiplier=1.0) - print(f"create image with applied LoRA") + logger.info(f"create image with applied LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "applied_lora.png") # unapply LoRA network to the model - print(f"unapply LoRA network to the model") + logger.info(f"unapply LoRA network to the model") lora_network.unapply_to() - print(f"create image with unapplied LoRA") + logger.info(f"create image with unapplied LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "unapplied_lora.png") # merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to) - print(f"merge LoRA network to the model") + logger.info(f"merge LoRA network to the model") lora_network.merge_to(multiplier=1.0) - print(f"create image with LoRA") + logger.info(f"create image with LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "merged_lora.png") @@ -579,31 +586,31 @@ if __name__ == "__main__": # restore (unmerge) LoRA weights: numerically unstable # マージされた重みを元に戻す。計算誤差のため、元の重みと完全に一致しないことがあるかもしれない # 
保存したstate_dictから元の重みを復元するのが確実 - print(f"restore (unmerge) LoRA weights") + logger.info(f"restore (unmerge) LoRA weights") lora_network.restore_from(multiplier=1.0) - print(f"create image without LoRA") + logger.info(f"create image without LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "unmerged_lora.png") # restore original weights - print(f"restore original weights") + logger.info(f"restore original weights") pipe.unet.load_state_dict(org_unet_sd) pipe.text_encoder.load_state_dict(org_text_encoder_sd) if args.sdxl: pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd) - print(f"create image with restored original weights") + logger.info(f"create image with restored original weights") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "restore_original.png") # use convenience function to merge LoRA weights - print(f"merge LoRA weights with convenience function") + logger.info(f"merge LoRA weights with convenience function") merge_lora_weights(pipe, lora_sd, multiplier=1.0) - print(f"create image with merged LoRA weights") + logger.info(f"create image with merged LoRA weights") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "convenience_merged_lora.png") diff --git a/networks/lora_fa.py b/networks/lora_fa.py index a357d7f7..919222ce 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -14,7 +14,10 @@ from transformers import CLIPTextModel import numpy as np import torch import re - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -49,7 +52,7 @@ class LoRAModule(torch.nn.Module): # if limit_rank: # self.lora_dim = min(lora_dim, in_dim, out_dim) # if self.lora_dim != lora_dim: - # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") # else: self.lora_dim = lora_dim @@ -197,7 +200,7 @@ class LoRAInfModule(LoRAModule): else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + self.multiplier * conved * self.scale # set weight to org_module @@ -236,7 +239,7 @@ class LoRAInfModule(LoRAModule): self.region_mask = None def default_forward(self, x): - # print("default_forward", self.lora_name, x.size()) + # logger.info("default_forward", self.lora_name, x.size()) return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale def forward(self, x): @@ -278,7 +281,7 @@ class LoRAInfModule(LoRAModule): # apply mask for LoRA result lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale mask = self.get_mask_for_x(lx) - # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) + # logger.info("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) lx = lx * mask x = self.org_forward(x) @@ -307,7 +310,7 @@ class LoRAInfModule(LoRAModule): if has_real_uncond: query[-self.network.batch_size :] = x[-self.network.batch_size :] - # print("postp_to_q", self.lora_name, x.size(), 
query.size(), self.network.num_sub_prompts) + # logger.info("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) return query def sub_prompt_forward(self, x): @@ -322,7 +325,7 @@ class LoRAInfModule(LoRAModule): lx = x[emb_idx :: self.network.num_sub_prompts] lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale - # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) + # logger.info("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) x = self.org_forward(x) x[emb_idx :: self.network.num_sub_prompts] += lx @@ -330,7 +333,7 @@ class LoRAInfModule(LoRAModule): return x def to_out_forward(self, x): - # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) + # logger.info("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) if self.network.is_last_network: masks = [None] * self.network.num_sub_prompts @@ -348,7 +351,7 @@ class LoRAInfModule(LoRAModule): ) self.network.shared[self.lora_name] = (lx, masks) - # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) @@ -367,7 +370,7 @@ class LoRAInfModule(LoRAModule): if has_real_uncond: out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond - # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) # for i in range(len(masks)): # if masks[i] is None: # masks[i] = torch.zeros_like(masks[-1]) @@ -389,7 +392,7 @@ class LoRAInfModule(LoRAModule): x1 = x1 + lx1 out[self.network.batch_size + i] = x1 - # print("to_out_forward", x.size(), out.size(), has_real_uncond) + # logger.info("to_out_forward", x.size(), out.size(), has_real_uncond) return out @@ -526,7 +529,7 @@ def get_block_dims_and_alphas( len(block_dims) == num_total_blocks ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" else: - print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") + logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") block_dims = [network_dim] * num_total_blocks if block_alphas is not None: @@ -535,7 +538,7 @@ def get_block_dims_and_alphas( len(block_alphas) == num_total_blocks ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" else: - print( + logger.warning( f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" ) block_alphas = [network_alpha] * num_total_blocks @@ -555,13 +558,13 @@ def get_block_dims_and_alphas( else: if conv_alpha is None: conv_alpha = 1.0 - print( + logger.warning( f"conv_block_alphas is not specified. 
all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" ) conv_block_alphas = [conv_alpha] * num_total_blocks else: if conv_dim is not None: - print( + logger.warning( f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" ) conv_block_dims = [conv_dim] * num_total_blocks @@ -601,7 +604,7 @@ def get_block_lr_weight( elif name == "zeros": return [0.0 + base_lr] * max_len else: - print( + logger.error( "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" % (name) ) @@ -613,14 +616,14 @@ def get_block_lr_weight( up_lr_weight = get_list(up_lr_weight) if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): - print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) - print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) + logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) + logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) up_lr_weight = up_lr_weight[:max_len] down_lr_weight = down_lr_weight[:max_len] if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): - print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) - print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) + logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) + logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) if down_lr_weight != None and len(down_lr_weight) < max_len: down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) @@ -628,24 +631,24 @@ def get_block_lr_weight( up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): - print("apply block learning rate / 階層別学習率を適用します。") + logger.info("apply block learning rate / 階層別学習率を適用します。") if down_lr_weight != None: down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] - print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight) + logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}") else: - print("down_lr_weight: all 1.0, すべて1.0") + logger.info("down_lr_weight: all 1.0, すべて1.0") if mid_lr_weight != None: mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 - print("mid_lr_weight:", mid_lr_weight) + logger.info(f"mid_lr_weight: {mid_lr_weight}") else: - print("mid_lr_weight: 1.0") + logger.info("mid_lr_weight: 1.0") if up_lr_weight != None: up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] - print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight) + logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}") else: - print("up_lr_weight: all 1.0, すべて1.0") + logger.info("up_lr_weight: all 1.0, すべて1.0") return down_lr_weight, mid_lr_weight, up_lr_weight @@ -726,7 +729,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(lora_name, value.size(), dim) # support old LoRA without alpha 
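The `get_block_lr_weight` changes above (duplicated here in `networks/lora_fa.py` from `networks/lora.py`) turn preset names into per-block learning-rate multipliers. A sketch of those presets, assuming the sine/linear curves implied by the error message's list (`cosine, sine, linear, reverse_linear, zeros`); the helper name and exact formulas are illustrative:

```python
import math

def make_block_lr_list(name: str, base_lr: float = 0.0, max_len: int = 12):
    """Generate per-block LR multipliers from a named preset curve.

    A sketch of the preset branches only; the real function also parses
    explicit comma-separated lists and pads/truncates them to max_len.
    """
    if name == "cosine":
        return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
    if name == "sine":
        return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
    if name == "linear":
        return [i / (max_len - 1) + base_lr for i in range(max_len)]
    if name == "reverse_linear":
        return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
    if name == "zeros":
        return [0.0 + base_lr] * max_len
    raise ValueError(f"unknown preset: {name}")

print([round(w, 2) for w in make_block_lr_list("sine", max_len=4)])  # [0.0, 0.5, 0.87, 1.0]
```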
for key in modules_dim.keys(): @@ -801,20 +804,20 @@ class LoRANetwork(torch.nn.Module): self.module_dropout = module_dropout if modules_dim is not None: - print(f"create LoRA network from weights") + logger.info(f"create LoRA network from weights") elif block_dims is not None: - print(f"create LoRA network from block_dims") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") - print(f"block_dims: {block_dims}") - print(f"block_alphas: {block_alphas}") + logger.info(f"create LoRA network from block_dims") + logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info(f"block_dims: {block_dims}") + logger.info(f"block_alphas: {block_alphas}") if conv_block_dims is not None: - print(f"conv_block_dims: {conv_block_dims}") - print(f"conv_block_alphas: {conv_block_alphas}") + logger.info(f"conv_block_dims: {conv_block_dims}") + logger.info(f"conv_block_alphas: {conv_block_alphas}") else: - print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") if self.conv_lora_dim is not None: - print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") # create module instances def create_modules( @@ -899,15 +902,15 @@ class LoRANetwork(torch.nn.Module): for i, text_encoder in enumerate(text_encoders): if len(text_encoders) > 1: index = i + 1 - print(f"create LoRA for Text Encoder {index}:") + logger.info(f"create LoRA for Text Encoder {index}:") else: index = None - print(f"create LoRA for Text Encoder:") + logger.info(f"create LoRA for Text Encoder:") text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE @@ -915,15 +918,15 @@ class LoRANetwork(torch.nn.Module): target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: - print( + logger.warning( f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) for name in skipped: - print(f"\t{name}") + logger.info(f"\t{name}") self.up_lr_weight: List[float] = None self.down_lr_weight: List[float] = None @@ -954,12 +957,12 @@ class LoRANetwork(torch.nn.Module): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: 
- print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -981,12 +984,12 @@ class LoRANetwork(torch.nn.Module): apply_unet = True if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -997,7 +1000,7 @@ class LoRANetwork(torch.nn.Module): sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) - print(f"weights are merged") + logger.info(f"weights are merged") # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない def set_block_lr_weight( @@ -1144,7 +1147,7 @@ class LoRANetwork(torch.nn.Module): device = ref_weight.device def resize_add(mh, mw): - # print(mh, mw, mh * mw) + # logger.info(mh, mw, mh * mw) m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 m = m.to(device, dtype=dtype) mask_dic[mh * mw] = m diff --git a/networks/lora_interrogator.py b/networks/lora_interrogator.py index 0dc066fd..6aaa5810 100644 --- a/networks/lora_interrogator.py +++ b/networks/lora_interrogator.py @@ -5,27 +5,34 @@ from library import model_util import library.train_util as train_util import argparse from transformers import CLIPTokenizer + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() import library.model_util as model_util import lora +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う -DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +DEVICE = get_preferred_device() def interrogate(args): weights_dtype = torch.float16 # いろいろ準備する - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") args.pretrained_model_name_or_path = args.sd_model args.vae = None text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE) - print(f"loading LoRA: {args.model}") + logger.info(f"loading LoRA: {args.model}") network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet) # text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい @@ -35,11 +42,11 @@ def interrogate(args): has_te_weight = True break if not has_te_weight: - print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません") + logger.error("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません") return del vae - print("loading tokenizer") + logger.info("loading tokenizer") if args.v2: tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") else: @@ -53,7 +60,7 @@ def interrogate(args): # トークンをひとつひとつ当たっていく token_id_start = 0 token_id_end = max(tokenizer.all_special_ids) - print(f"interrogate tokens are: {token_id_start} to {token_id_end}") + logger.info(f"interrogate tokens are: {token_id_start} to {token_id_end}") def get_all_embeddings(text_encoder): embs = [] @@ -79,24 +86,24 @@ def interrogate(args): embs.extend(encoder_hidden_states) return torch.stack(embs) - print("get original text encoder 
embeddings.") + logger.info("get original text encoder embeddings.") orig_embs = get_all_embeddings(text_encoder) network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0) info = network.load_state_dict(weights_sd, strict=False) - print(f"Loading LoRA weights: {info}") + logger.info(f"Loading LoRA weights: {info}") network.to(DEVICE, dtype=weights_dtype) network.eval() del unet - print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)") - print("get text encoder embeddings with lora.") + logger.info("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)") + logger.info("get text encoder embeddings with lora.") lora_embs = get_all_embeddings(text_encoder) # 比べる:とりあえず単純に差分の絶対値で - print("comparing...") + logger.info("comparing...") diffs = {} for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))): diff = torch.mean(torch.abs(orig_emb - lora_emb)) diff --git a/networks/merge_lora.py b/networks/merge_lora.py index 71492621..fea8a3f3 100644 --- a/networks/merge_lora.py +++ b/networks/merge_lora.py @@ -7,7 +7,10 @@ from safetensors.torch import load_file, save_file from library import sai_model_spec, train_util import library.model_util as model_util import lora - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -61,10 +64,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): name_to_module[lora_name] = child_module for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") @@ -73,10 +76,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): # find original module for this lora module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" if module_name not in name_to_module: - print(f"no module found for LoRA weight: {key}") + logger.info(f"no module found for LoRA weight: {key}") continue module = name_to_module[module_name] - # print(f"apply {key} to {module}") + # logger.info(f"apply {key} to {module}") down_weight = lora_sd[key] up_weight = lora_sd[up_key] @@ -104,7 +107,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + ratio * conved * scale module.weight = torch.nn.Parameter(weight) @@ -118,7 +121,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): v2 = None base_model = None for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) if lora_metadata is not None: @@ -151,10 +154,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, 
shuffle=False): if lora_module_name not in base_alphas: base_alphas[lora_module_name] = alpha - print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") # merge - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if "alpha" in key: continue @@ -196,8 +199,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): merged_sd[key_down] = merged_sd[key_down][perm] merged_sd[key_up] = merged_sd[key_up][:,perm] - print("merged model") - print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + logger.info("merged model") + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") # check all dims are same dims_list = list(set(base_dims.values())) @@ -239,7 +242,7 @@ def merge(args): save_dtype = merge_dtype if args.sd_model is not None: - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) @@ -264,18 +267,18 @@ def merge(args): ) if args.v2: # TODO read sai modelspec - print( + logger.warning( "Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" ) - print(f"saving SD model to: {args.save_to}") + logger.info(f"saving SD model to: {args.save_to}") model_util.save_stable_diffusion_checkpoint( args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae ) else: state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) - print(f"calculating hashes and creating metadata...") + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -289,12 +292,12 @@ def merge(args): ) if v2: # TODO read sai modelspec - print( + logger.warning( "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" ) metadata.update(sai_metadata) - print(f"saving model to: {args.save_to}") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) diff --git a/networks/merge_lora_old.py b/networks/merge_lora_old.py index ffd6b2b4..334d127b 100644 --- a/networks/merge_lora_old.py +++ b/networks/merge_lora_old.py @@ -6,7 +6,10 @@ import torch from safetensors.torch import load_file, save_file import library.model_util as model_util import lora - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == '.safetensors': @@ -54,10 +57,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): name_to_module[lora_name] = child_module for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd = load_state_dict(model, merge_dtype) - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") @@ -66,10 +69,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): # find original module for this lora 
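Both `merge_lora.py` above and this older script fold each LoRA pair into the base weight with the same update; for a Linear layer it is `W' = W + ratio * (up @ down) * (alpha / dim)`. A standalone sketch with hypothetical tensors (the Conv2d 1x1/3x3 branches reshape or convolve the factors instead, as the surrounding hunks show):

```python
import torch

def merge_linear_lora(weight, down, up, alpha, ratio=1.0):
    """Fold a linear LoRA pair into the base weight.

    W' = W + ratio * (up @ down) * (alpha / dim), where dim is the LoRA rank.
    """
    dim = down.size(0)           # rank = rows of the down weight
    scale = alpha / dim          # same alpha/dim scaling as training
    return weight + ratio * (up @ down) * scale

# hypothetical shapes: a 16x8 Linear with a rank-4 LoRA
W = torch.randn(16, 8)
down, up = torch.randn(4, 8), torch.randn(16, 4)
merged = merge_linear_lora(W, down, up, alpha=4.0)
assert merged.shape == W.shape
```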
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" if module_name not in name_to_module: - print(f"no module found for LoRA weight: {key}") + logger.info(f"no module found for LoRA weight: {key}") continue module = name_to_module[module_name] - # print(f"apply {key} to {module}") + # logger.info(f"apply {key} to {module}") down_weight = lora_sd[key] up_weight = lora_sd[up_key] @@ -96,10 +99,10 @@ def merge_lora_models(models, ratios, merge_dtype): alpha = None dim = None for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd = load_state_dict(model, merge_dtype) - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if 'alpha' in key: if key in merged_sd: @@ -117,7 +120,7 @@ def merge_lora_models(models, ratios, merge_dtype): dim = lora_sd[key].size()[0] merged_sd[key] = lora_sd[key] * ratio - print(f"dim (rank): {dim}, alpha: {alpha}") + logger.info(f"dim (rank): {dim}, alpha: {alpha}") if alpha is None: alpha = dim @@ -142,19 +145,21 @@ def merge(args): save_dtype = merge_dtype if args.sd_model is not None: - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) - print(f"\nsaving SD model to: {args.save_to}") + logger.info("") + logger.info(f"saving SD model to: {args.save_to}") model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae) else: state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) - print(f"\nsaving model to: {args.save_to}") + logger.info(f"") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype) diff --git a/networks/oft.py b/networks/oft.py index 1d088f87..461a9869 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -8,7 +8,10 @@ from transformers import CLIPTextModel import numpy as np import torch import re - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -237,7 +240,7 @@ class OFTNetwork(torch.nn.Module): self.dim = dim self.alpha = alpha - print( + logger.info( f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}" ) @@ -258,7 +261,7 @@ class OFTNetwork(torch.nn.Module): if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv): oft_name = prefix + "." + name + "." 
+ child_name oft_name = oft_name.replace(".", "_") - # print(oft_name) + # logger.info(oft_name) oft = module_class( oft_name, @@ -279,7 +282,7 @@ class OFTNetwork(torch.nn.Module): target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules) - print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.") + logger.info(f"create OFT for U-Net: {len(self.unet_ofts)} modules.") # assertion names = set() @@ -316,7 +319,7 @@ class OFTNetwork(torch.nn.Module): # TODO refactor to common function with apply_to def merge_to(self, text_encoder, unet, weights_sd, dtype, device): - print("enable OFT for U-Net") + logger.info("enable OFT for U-Net") for oft in self.unet_ofts: sd_for_lora = {} @@ -326,7 +329,7 @@ class OFTNetwork(torch.nn.Module): oft.load_state_dict(sd_for_lora, False) oft.merge_to() - print(f"weights are merged") + logger.info(f"weights are merged") # 二つのText Encoderに別々の学習率を設定できるようにするといいかも def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): @@ -338,11 +341,11 @@ class OFTNetwork(torch.nn.Module): for oft in ofts: params.extend(oft.parameters()) - # print num of params + # logger.info num of params num_params = 0 for p in params: num_params += p.numel() - print(f"OFT params: {num_params}") + logger.info(f"OFT params: {num_params}") return params param_data = {"params": enumerate_params(self.unet_ofts)} diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 03fc545e..7df7ef0c 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -2,80 +2,86 @@ # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py # Thanks to cloneofsimo +import os import argparse import torch from safetensors.torch import load_file, save_file, safe_open from tqdm import tqdm -from library import train_util, model_util import numpy as np +from library import train_util +from library import model_util +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + MIN_SV = 1e-6 # Model save and load functions + def load_state_dict(file_name, dtype): - if model_util.is_safetensors(file_name): - sd = load_file(file_name) - with safe_open(file_name, framework="pt") as f: - metadata = f.metadata() - else: - sd = torch.load(file_name, map_location='cpu') - metadata = None + if model_util.is_safetensors(file_name): + sd = load_file(file_name) + with safe_open(file_name, framework="pt") as f: + metadata = f.metadata() + else: + sd = torch.load(file_name, map_location="cpu") + metadata = None - for key in list(sd.keys()): - if type(sd[key]) == torch.Tensor: - sd[key] = sd[key].to(dtype) + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) - return sd, metadata + return sd, metadata -def save_to_file(file_name, model, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - - if model_util.is_safetensors(file_name): - save_file(model, file_name, metadata) - else: - torch.save(model, file_name) +def save_to_file(file_name, state_dict, metadata): + if model_util.is_safetensors(file_name): + save_file(state_dict, file_name, metadata) + else: + torch.save(state_dict, file_name) # Indexing functions -def index_sv_cumulative(S, target): - original_sum = float(torch.sum(S)) - cumulative_sums 
= torch.cumsum(S, dim=0)/original_sum - index = int(torch.searchsorted(cumulative_sums, target)) + 1 - index = max(1, min(index, len(S)-1)) - return index +def index_sv_cumulative(S, target): + original_sum = float(torch.sum(S)) + cumulative_sums = torch.cumsum(S, dim=0) / original_sum + index = int(torch.searchsorted(cumulative_sums, target)) + 1 + index = max(1, min(index, len(S) - 1)) + + return index def index_sv_fro(S, target): - S_squared = S.pow(2) - s_fro_sq = float(torch.sum(S_squared)) - sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq - index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 - index = max(1, min(index, len(S)-1)) + S_squared = S.pow(2) + S_fro_sq = float(torch.sum(S_squared)) + sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq + index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 + index = max(1, min(index, len(S) - 1)) - return index + return index def index_sv_ratio(S, target): - max_sv = S[0] - min_sv = max_sv/target - index = int(torch.sum(S > min_sv).item()) - index = max(1, min(index, len(S)-1)) + max_sv = S[0] + min_sv = max_sv / target + index = int(torch.sum(S > min_sv).item()) + index = max(1, min(index, len(S) - 1)) - return index + return index # Modified from Kohaku-blueleaf's extract/merge functions def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): out_size, in_size, kernel_size, _ = weight.size() U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device)) - + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) lora_rank = param_dict["new_rank"] @@ -92,17 +98,17 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): out_size, in_size = weight.size() - + U, S, Vh = torch.linalg.svd(weight.to(device)) - + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) lora_rank = param_dict["new_rank"] - + U = U[:, :lora_rank] S = S[:lora_rank] U = U @ torch.diag(S) Vh = Vh[:lora_rank, :] - + param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu() param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu() del U, S, Vh, weight @@ -113,7 +119,7 @@ def merge_conv(lora_down, lora_up, device): in_rank, in_size, kernel_size, k_ = lora_down.shape out_size, out_rank, _, _ = lora_up.shape assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch" - + lora_down = lora_down.to(device) lora_up = lora_up.to(device) @@ -127,236 +133,280 @@ def merge_linear(lora_down, lora_up, device): in_rank, in_size = lora_down.shape out_size, out_rank = lora_up.shape assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch" - + lora_down = lora_down.to(device) lora_up = lora_up.to(device) - + weight = lora_up @ lora_down del lora_up, lora_down return weight - + # Calculate new rank + def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): param_dict = {} - if dynamic_method=="sv_ratio": + if dynamic_method == "sv_ratio": # Calculate new dim and alpha based off ratio new_rank = index_sv_ratio(S, dynamic_param) + 1 - new_alpha = float(scale*new_rank) + new_alpha = float(scale * new_rank) - elif dynamic_method=="sv_cumulative": + elif dynamic_method == "sv_cumulative": # Calculate new dim and alpha based off cumulative sum new_rank = index_sv_cumulative(S, dynamic_param) + 1 - new_alpha = float(scale*new_rank) + new_alpha = float(scale * new_rank) - elif 
dynamic_method=="sv_fro": + elif dynamic_method == "sv_fro": # Calculate new dim and alpha based off sqrt sum of squares new_rank = index_sv_fro(S, dynamic_param) + 1 - new_alpha = float(scale*new_rank) + new_alpha = float(scale * new_rank) else: new_rank = rank - new_alpha = float(scale*new_rank) + new_alpha = float(scale * new_rank) - - if S[0] <= MIN_SV: # Zero matrix, set dim to 1 + if S[0] <= MIN_SV: # Zero matrix, set dim to 1 new_rank = 1 - new_alpha = float(scale*new_rank) - elif new_rank > rank: # cap max rank at rank + new_alpha = float(scale * new_rank) + elif new_rank > rank: # cap max rank at rank new_rank = rank - new_alpha = float(scale*new_rank) - + new_alpha = float(scale * new_rank) # Calculate resize info s_sum = torch.sum(torch.abs(S)) s_rank = torch.sum(torch.abs(S[:new_rank])) - + S_squared = S.pow(2) s_fro = torch.sqrt(torch.sum(S_squared)) s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank])) - fro_percent = float(s_red_fro/s_fro) + fro_percent = float(s_red_fro / s_fro) param_dict["new_rank"] = new_rank param_dict["new_alpha"] = new_alpha - param_dict["sum_retained"] = (s_rank)/s_sum + param_dict["sum_retained"] = (s_rank) / s_sum param_dict["fro_retained"] = fro_percent - param_dict["max_ratio"] = S[0]/S[new_rank - 1] + param_dict["max_ratio"] = S[0] / S[new_rank - 1] return param_dict -def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): - network_alpha = None - network_dim = None - verbose_str = "\n" - fro_list = [] +def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): + network_alpha = None + network_dim = None + verbose_str = "\n" + fro_list = [] - # Extract loaded lora dim and alpha - for key, value in lora_sd.items(): - if network_alpha is None and 'alpha' in key: - network_alpha = value - if network_dim is None and 'lora_down' in key and len(value.size()) == 2: - network_dim = value.size()[0] - if network_alpha is not None and network_dim is not None: - break - if network_alpha is None: - network_alpha = network_dim + # Extract loaded lora dim and alpha + for key, value in lora_sd.items(): + if network_alpha is None and "alpha" in key: + network_alpha = value + if network_dim is None and "lora_down" in key and len(value.size()) == 2: + network_dim = value.size()[0] + if network_alpha is not None and network_dim is not None: + break + if network_alpha is None: + network_alpha = network_dim - scale = network_alpha/network_dim + scale = network_alpha / network_dim - if dynamic_method: - print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}") + if dynamic_method: + logger.info( + f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}" + ) - lora_down_weight = None - lora_up_weight = None + lora_down_weight = None + lora_up_weight = None - o_lora_sd = lora_sd.copy() - block_down_name = None - block_up_name = None + o_lora_sd = lora_sd.copy() + block_down_name = None + block_up_name = None - with torch.no_grad(): - for key, value in tqdm(lora_sd.items()): - weight_name = None - if 'lora_down' in key: - block_down_name = key.rsplit('.lora_down', 1)[0] - weight_name = key.rsplit(".", 1)[-1] - lora_down_weight = value - else: - continue + with torch.no_grad(): + for key, value in tqdm(lora_sd.items()): + weight_name = None + if "lora_down" in key: + block_down_name = key.rsplit(".lora_down", 1)[0] + weight_name = key.rsplit(".", 
1)[-1] + lora_down_weight = value + else: + continue - # find corresponding lora_up and alpha - block_up_name = block_down_name - lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None) - lora_alpha = lora_sd.get(block_down_name + '.alpha', None) + # find corresponding lora_up and alpha + block_up_name = block_down_name + lora_up_weight = lora_sd.get(block_up_name + ".lora_up." + weight_name, None) + lora_alpha = lora_sd.get(block_down_name + ".alpha", None) - weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) + weights_loaded = lora_down_weight is not None and lora_up_weight is not None - if weights_loaded: + if weights_loaded: - conv2d = (len(lora_down_weight.size()) == 4) - if lora_alpha is None: - scale = 1.0 - else: - scale = lora_alpha/lora_down_weight.size()[0] + conv2d = len(lora_down_weight.size()) == 4 + if lora_alpha is None: + scale = 1.0 + else: + scale = lora_alpha / lora_down_weight.size()[0] - if conv2d: - full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) - param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) - else: - full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) - param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) + if conv2d: + full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) + param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale) + else: + full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) + param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) - if verbose: - max_ratio = param_dict['max_ratio'] - sum_retained = param_dict['sum_retained'] - fro_retained = param_dict['fro_retained'] - if not np.isnan(fro_retained): - fro_list.append(float(fro_retained)) + if verbose: + max_ratio = param_dict["max_ratio"] + sum_retained = param_dict["sum_retained"] + fro_retained = param_dict["fro_retained"] + if not np.isnan(fro_retained): + fro_list.append(float(fro_retained)) - verbose_str+=f"{block_down_name:75} | " - verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" + verbose_str += f"{block_down_name:75} | " + verbose_str += ( + f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" + ) - if verbose and dynamic_method: - verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n" - else: - verbose_str+=f"\n" + if verbose and dynamic_method: + verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n" + else: + verbose_str += "\n" - new_alpha = param_dict['new_alpha'] - o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype) + new_alpha = param_dict["new_alpha"] + o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." 
"alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype) - block_down_name = None - block_up_name = None - lora_down_weight = None - lora_up_weight = None - weights_loaded = False - del param_dict + block_down_name = None + block_up_name = None + lora_down_weight = None + lora_up_weight = None + weights_loaded = False + del param_dict - if verbose: - print(verbose_str) - - print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") - print("resizing complete") - return o_lora_sd, network_dim, new_alpha + if verbose: + print(verbose_str) + print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") + logger.info("resizing complete") + return o_lora_sd, network_dim, new_alpha def resize(args): - if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')): - raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.") + if args.save_to is None or not ( + args.save_to.endswith(".ckpt") + or args.save_to.endswith(".pt") + or args.save_to.endswith(".pth") + or args.save_to.endswith(".safetensors") + ): + raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.") - - def str_to_dtype(p): - if p == 'float': - return torch.float - if p == 'fp16': - return torch.float16 - if p == 'bf16': - return torch.bfloat16 - return None + args.new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank - if args.dynamic_method and not args.dynamic_param: - raise Exception("If using dynamic_method, then dynamic_param is required") + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None - merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32 - save_dtype = str_to_dtype(args.save_precision) - if save_dtype is None: - save_dtype = merge_dtype + if args.dynamic_method and not args.dynamic_param: + raise Exception("If using dynamic_method, then dynamic_param is required") - print("loading Model...") - lora_sd, metadata = load_state_dict(args.model, merge_dtype) + merge_dtype = str_to_dtype("float") # matmul method above only seems to work in float32 + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype - print("Resizing Lora...") - state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) + logger.info("loading Model...") + lora_sd, metadata = load_state_dict(args.model, merge_dtype) - # update metadata - if metadata is None: - metadata = {} + logger.info("Resizing Lora...") + state_dict, old_dim, new_alpha = resize_lora_model( + lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose + ) - comment = metadata.get("ss_training_comment", "") + # update metadata + if metadata is None: + metadata = {} - if not args.dynamic_method: - metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" - metadata["ss_network_dim"] = str(args.new_rank) - metadata["ss_network_alpha"] = str(new_alpha) - else: - metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; 
{comment}" - metadata["ss_network_dim"] = 'Dynamic' - metadata["ss_network_alpha"] = 'Dynamic' + comment = metadata.get("ss_training_comment", "") - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) - metadata["sshs_model_hash"] = model_hash - metadata["sshs_legacy_hash"] = legacy_hash + if not args.dynamic_method: + conv_desc = "" if args.new_rank == args.new_conv_rank else f" (conv: {args.new_conv_rank})" + metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}{conv_desc}; {comment}" + metadata["ss_network_dim"] = str(args.new_rank) + metadata["ss_network_alpha"] = str(new_alpha) + else: + metadata["ss_training_comment"] = ( + f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}" + ) + metadata["ss_network_dim"] = "Dynamic" + metadata["ss_network_alpha"] = "Dynamic" - print(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) + # cast to save_dtype before calculating hashes + for key in list(state_dict.keys()): + value = state_dict[key] + if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype: + state_dict[key] = value.to(save_dtype) + + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + logger.info(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, metadata) def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser() - parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat") - parser.add_argument("--new_rank", type=int, default=4, - help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") - parser.add_argument("--save_to", type=str, default=None, - help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") - parser.add_argument("--model", type=str, default=None, - help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors") - parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") - parser.add_argument("--verbose", action="store_true", - help="Display verbose resizing information / rank変更時の詳細情報を出力する") - parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], - help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank") - parser.add_argument("--dynamic_param", type=float, default=None, - help="Specify target for dynamic reduction") - - return parser + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat", + ) + parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") + parser.add_argument( + "--new_conv_rank", + type=int, + default=None, + help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ", + ) + parser.add_argument( + "--save_to", + type=str, + default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", + ) + parser.add_argument( + 
"--model", + type=str, + default=None, + help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors", + ) + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) + parser.add_argument( + "--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する" + ) + parser.add_argument( + "--dynamic_method", + type=str, + default=None, + choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], + help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank", + ) + parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction") + + return parser -if __name__ == '__main__': - parser = setup_parser() +if __name__ == "__main__": + parser = setup_parser() - args = parser.parse_args() - resize(args) + args = parser.parse_args() + resize(args) diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index c513eb59..b147eb44 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -1,13 +1,23 @@ +import itertools import math import argparse import os import time +import concurrent.futures import torch from safetensors.torch import load_file, save_file from tqdm import tqdm from library import sai_model_spec, sdxl_model_util, train_util import library.model_util as model_util import lora +import oft +from svd_merge_lora import format_lbws, get_lbw_block_index, LAYER26 +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) def load_state_dict(file_name, dtype): @@ -25,36 +35,58 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, model, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - +def save_to_file(file_name, model, metadata): if os.path.splitext(file_name)[1] == ".safetensors": save_file(model, file_name, metadata=metadata) else: torch.save(model, file_name) -def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): - text_encoder1.to(merge_dtype) +def detect_method_from_training_model(models, dtype): + for model in models: + # TODO It is better to use key names to detect the method + lora_sd, _ = load_state_dict(model, dtype) + for key in tqdm(lora_sd.keys()): + if "lora_up" in key or "lora_down" in key: + return "LoRA" + elif "oft_blocks" in key: + return "OFT" + + +def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws, merge_dtype): text_encoder1.to(merge_dtype) + text_encoder2.to(merge_dtype) unet.to(merge_dtype) + # detect the method: OFT or LoRA_module + method = detect_method_from_training_model(models, merge_dtype) + logger.info(f"method:{method}") + + if lbws: + lbws, _, LBW_TARGET_IDX = format_lbws(lbws) + else: + LBW_TARGET_IDX = [] + # create module map name_to_module = {} for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): - if i <= 1: - if i == 0: - prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 + if method == "LoRA": + if i <= 1: + if i == 0: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 + else: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2 + target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE else: - prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2 - target_replace_modules = 
lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE - else: - prefix = lora.LoRANetwork.LORA_PREFIX_UNET + prefix = lora.LoRANetwork.LORA_PREFIX_UNET + target_replace_modules = ( + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + ) + elif method == "OFT": + prefix = oft.OFTNetwork.OFT_PREFIX_UNET + # ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY target_replace_modules = ( - lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 ) for name, module in root_module.named_modules(): @@ -65,65 +97,172 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ lora_name = lora_name.replace(".", "_") name_to_module[lora_name] = child_module - for model, ratio in zip(models, ratios): - print(f"loading: {model}") + for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): + logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) - print(f"merging...") - for key in tqdm(lora_sd.keys()): - if "lora_down" in key: - up_key = key.replace("lora_down", "lora_up") - alpha_key = key[: key.index("lora_down")] + "alpha" + logger.info(f"merging...") - # find original module for this lora - module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if lbw: + lbw_weights = [1] * 26 + for index, value in zip(LBW_TARGET_IDX, lbw): + lbw_weights[index] = value + logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}") + + if method == "LoRA": + for key in tqdm(lora_sd.keys()): + if "lora_down" in key: + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + # find original module for this lora + module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if module_name not in name_to_module: + logger.info(f"no module found for LoRA weight: {key}") + continue + module = name_to_module[module_name] + # logger.info(f"apply {key} to {module}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + if lbw: + index = get_lbw_block_index(key, True) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける + + # W <- W + U * D + weight = module.weight + # logger.info(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + module.weight = torch.nn.Parameter(weight) + + elif method == "OFT": + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + for key in tqdm(lora_sd.keys()): + if "oft_blocks" in key: + oft_blocks = lora_sd[key] + dim = oft_blocks.shape[0] + break + for key in tqdm(lora_sd.keys()): + if "alpha" in key: + oft_blocks = lora_sd[key] + alpha = oft_blocks.item() + break + + def merge_to(key): + if 
"alpha" in key: + return + + # find original module for this OFT + module_name = ".".join(key.split(".")[:-1]) if module_name not in name_to_module: - print(f"no module found for LoRA weight: {key}") - continue + logger.info(f"no module found for OFT weight: {key}") + return module = name_to_module[module_name] - # print(f"apply {key} to {module}") - down_weight = lora_sd[key] - up_weight = lora_sd[up_key] + # logger.info(f"apply {key} to {module}") - dim = down_weight.size()[0] - alpha = lora_sd.get(alpha_key, dim) - scale = alpha / dim + oft_blocks = lora_sd[key] - # W <- W + U * D - weight = module.weight - # print(module_name, down_weight.size(), up_weight.size()) - if len(weight.size()) == 2: - # linear - weight = weight + ratio * (up_weight @ down_weight) * scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - weight - + ratio - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * scale - ) + if isinstance(module, torch.nn.Linear): + out_dim = module.out_features + elif isinstance(module, torch.nn.Conv2d): + out_dim = module.out_channels + + num_blocks = dim + block_size = out_dim // dim + constraint = (0 if alpha is None else alpha) * out_dim + + multiplier = 1 + if lbw: + index = get_lbw_block_index(key, False) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + multiplier *= lbw_weights[index] + + block_Q = oft_blocks - oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1) + block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + block_R_weighted = multiplier * block_R + (1 - multiplier) * I + R = torch.block_diag(*block_R_weighted) + + # get org weight + org_sd = module.state_dict() + org_weight = org_sd["weight"].to(device) + + R = R.to(org_weight.device, dtype=org_weight.dtype) + + if org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", org_weight, R) else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + ratio * conved * scale + weight = torch.einsum("oi, op -> pi", org_weight, R) + + weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor module.weight = torch.nn.Parameter(weight) + # TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough + max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys()))) -def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): + +def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=False): base_alphas = {} # alpha for merged model base_dims = {} + # detect the method: OFT or LoRA_module + method = detect_method_from_training_model(models, merge_dtype) + if method == "OFT": + raise ValueError( + "OFT model is not supported for merging OFT models. 
/ OFTモデルはOFTモデル同士のマージには対応していません" + ) + + if lbws: + lbws, _, LBW_TARGET_IDX = format_lbws(lbws) + else: + LBW_TARGET_IDX = [] + merged_sd = {} v2 = None base_model = None - for model, ratio in zip(models, ratios): - print(f"loading: {model}") + for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): + logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) + if lbw: + lbw_weights = [1] * 26 + for index, value in zip(LBW_TARGET_IDX, lbw): + lbw_weights[index] = value + logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}") + if lora_metadata is not None: if v2 is None: v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず @@ -154,14 +293,14 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_module_name not in base_alphas: base_alphas[lora_module_name] = alpha - print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") # merge - print(f"merging...") + logger.info(f"merging...") for key in tqdm(lora_sd.keys()): if "alpha" in key: continue - + if "lora_up" in key and concat: concat_dim = 1 elif "lora_down" in key and concat: @@ -175,8 +314,14 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): alpha = alphas[lora_module_name] scale = math.sqrt(alpha / base_alpha) * ratio - scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 - + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + + if lbw: + index = get_lbw_block_index(key, True) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける + if key in merged_sd: assert ( merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None @@ -198,10 +343,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): dim = merged_sd[key_down].shape[0] perm = torch.randperm(dim) merged_sd[key_down] = merged_sd[key_down][perm] - merged_sd[key_up] = merged_sd[key_up][:,perm] + merged_sd[key_up] = merged_sd[key_up][:, perm] - print("merged model") - print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + logger.info("merged model") + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") # check all dims are same dims_list = list(set(base_dims.values())) @@ -226,7 +371,15 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + assert len(args.models) == len( + args.ratios + ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + if args.lbws: + assert len(args.models) == len( + args.lbws + ), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" + else: + args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく def str_to_dtype(p): if p == "float": @@ -243,7 +396,7 @@ def merge(args): save_dtype = merge_dtype if args.sd_model is not None: - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") ( text_model1, @@ -254,7 +407,7 @@ def merge(args): ckpt_info, ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu") - merge_to_sd_model(text_model1, text_model2, unet, 
args.models, args.ratios, merge_dtype) + merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, args.lbws, merge_dtype) if args.no_metadata: sai_metadata = None @@ -265,14 +418,20 @@ def merge(args): None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from ) - print(f"saving SD model to: {args.save_to}") + logger.info(f"saving SD model to: {args.save_to}") sdxl_model_util.save_stable_diffusion_checkpoint( args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype ) else: - state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle) - print(f"calculating hashes and creating metadata...") + # cast to save_dtype before calculating hashes + for key in list(state_dict.keys()): + value = state_dict[key] + if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype: + state_dict[key] = value.to(save_dtype) + + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -286,8 +445,8 @@ def merge(args): ) metadata.update(sai_metadata) - print(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) + logger.info(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, metadata) def setup_parser() -> argparse.ArgumentParser: @@ -313,12 +472,19 @@ def setup_parser() -> argparse.ArgumentParser: help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", ) parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + "--save_to", + type=str, + default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", ) parser.add_argument( - "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + "--models", + type=str, + nargs="*", + help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors", ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率") parser.add_argument( "--no_metadata", action="store_true", @@ -334,8 +500,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--shuffle", action="store_true", - help="shuffle lora weight./ " - + "LoRAの重みをシャッフルする", + help="shuffle lora weight./ " + "LoRAの重みをシャッフルする", ) return parser diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 16e813b3..c79b45ac 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -1,6 +1,8 @@ -import math import argparse +import itertools +import json import os +import re import time import torch from safetensors.torch import load_file, save_file @@ -8,10 +10,196 @@ from tqdm import tqdm from library import sai_model_spec, train_util import library.model_util as model_util import lora +from library.utils import setup_logging +setup_logging() +import logging + +logger = logging.getLogger(__name__) CLAMP_QUANTILE = 0.99 +ACCEPTABLE 
= [12, 17, 20, 26] +SDXL_LAYER_NUM = [12, 20] + +LAYER12 = { + "BASE": True, + "IN00": False, + "IN01": False, + "IN02": False, + "IN03": False, + "IN04": True, + "IN05": True, + "IN06": False, + "IN07": True, + "IN08": True, + "IN09": False, + "IN10": False, + "IN11": False, + "MID": True, + "OUT00": True, + "OUT01": True, + "OUT02": True, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": False, + "OUT07": False, + "OUT08": False, + "OUT09": False, + "OUT10": False, + "OUT11": False, +} + +LAYER17 = { + "BASE": True, + "IN00": False, + "IN01": True, + "IN02": True, + "IN03": False, + "IN04": True, + "IN05": True, + "IN06": False, + "IN07": True, + "IN08": True, + "IN09": False, + "IN10": False, + "IN11": False, + "MID": True, + "OUT00": False, + "OUT01": False, + "OUT02": False, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": True, + "OUT07": True, + "OUT08": True, + "OUT09": True, + "OUT10": True, + "OUT11": True, +} + +LAYER20 = { + "BASE": True, + "IN00": True, + "IN01": True, + "IN02": True, + "IN03": True, + "IN04": True, + "IN05": True, + "IN06": True, + "IN07": True, + "IN08": True, + "IN09": False, + "IN10": False, + "IN11": False, + "MID": True, + "OUT00": True, + "OUT01": True, + "OUT02": True, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": True, + "OUT07": True, + "OUT08": True, + "OUT09": False, + "OUT10": False, + "OUT11": False, +} + +LAYER26 = { + "BASE": True, + "IN00": True, + "IN01": True, + "IN02": True, + "IN03": True, + "IN04": True, + "IN05": True, + "IN06": True, + "IN07": True, + "IN08": True, + "IN09": True, + "IN10": True, + "IN11": True, + "MID": True, + "OUT00": True, + "OUT01": True, + "OUT02": True, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": True, + "OUT07": True, + "OUT08": True, + "OUT09": True, + "OUT10": True, + "OUT11": True, +} + +assert len([v for v in LAYER12.values() if v]) == 12 +assert len([v for v in LAYER17.values() if v]) == 17 +assert len([v for v in LAYER20.values() if v]) == 20 +assert len([v for v in LAYER26.values() if v]) == 26 + +RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") + + +def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int: + # lbw block index is 0-based, but 0 for text encoder, so we return 0 for text encoder + if "text_model_encoder_" in lora_name: # LoRA for text encoder + return 0 + + # lbw block index is 1-based for U-Net, and no "input_blocks.0" in CompVis SD, so "input_blocks.1" have index 2 + block_idx = -1 # invalid lora name + if not is_sdxl: + NUM_OF_BLOCKS = 12 # up/down blocks + m = RE_UPDOWN.search(lora_name) + if m: + g = m.groups() + up_down = g[0] + i = int(g[1]) + j = int(g[3]) + if up_down == "down": + if g[2] == "resnets" or g[2] == "attentions": + idx = 3 * i + j + 1 + elif g[2] == "downsamplers": + idx = 3 * (i + 1) + else: + return block_idx # invalid lora name + elif up_down == "up": + if g[2] == "resnets" or g[2] == "attentions": + idx = 3 * i + j + elif g[2] == "upsamplers": + idx = 3 * i + 2 + else: + return block_idx # invalid lora name + + if g[0] == "down": + block_idx = 1 + idx # 1-based index, down block index + elif g[0] == "up": + block_idx = 1 + NUM_OF_BLOCKS + 1 + idx # 1-based index, num blocks, mid block, up block index + + elif "mid_block_" in lora_name: + block_idx = 1 + NUM_OF_BLOCKS # 1-based index, num blocks, mid block + else: + # SDXL: some numbers are skipped + if lora_name.startswith("lora_unet_"): + name = lora_name[len("lora_unet_") :] + if 
name.startswith("time_embed_") or name.startswith("label_emb_"): # 1, No LoRA in sd-scripts + block_idx = 1 + elif name.startswith("input_blocks_"): # 1-8 to 2-9 + block_idx = 1 + int(name.split("_")[2]) + elif name.startswith("middle_block_"): # 13 + block_idx = 13 + elif name.startswith("output_blocks_"): # 0-8 to 14-22 + block_idx = 14 + int(name.split("_")[2]) + elif name.startswith("out_"): # 23, No LoRA in sd-scripts + block_idx = 23 + + return block_idx + def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -28,25 +216,54 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - +def save_to_file(file_name, state_dict, metadata): if os.path.splitext(file_name)[1] == ".safetensors": save_file(state_dict, file_name, metadata=metadata) else: torch.save(state_dict, file_name) -def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): - print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") +def format_lbws(lbws): + try: + # lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している + lbws = [json.loads(lbw) for lbw in lbws] + except Exception: + raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください") + assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください" + assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください" + assert all( + len(lbw) in ACCEPTABLE for lbw in lbws + ), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください" + assert all( + all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws + ), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください" + + layer_num = len(lbws[0]) + is_sdxl = True if layer_num in SDXL_LAYER_NUM else False + FLAGS = { + "12": LAYER12.values(), + "17": LAYER17.values(), + "20": LAYER20.values(), + "26": LAYER26.values(), + }[str(layer_num)] + LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag] + return lbws, is_sdxl, LBW_TARGET_IDX + + +def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype): + logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") merged_sd = {} - v2 = None + v2 = None # This is meaning LoRA Metadata v2, Not meaning SD2 base_model = None - for model, ratio in zip(models, ratios): - print(f"loading: {model}") + + if lbws: + lbws, is_sdxl, LBW_TARGET_IDX = format_lbws(lbws) + else: + is_sdxl = False + LBW_TARGET_IDX = [] + + for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): + logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) if lora_metadata is not None: @@ -55,8 +272,14 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty if base_model is None: base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + if lbw: + lbw_weights = [1] * 26 + for index, value in zip(LBW_TARGET_IDX, lbw): + lbw_weights[index] = value + logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}") + # merge - print(f"merging...") + logger.info(f"merging...") for key in tqdm(list(lora_sd.keys())): if "lora_down" not in key: continue @@ -73,15 +296,15 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty out_dim = 
up_weight.size()[0] conv2d = len(down_weight.size()) == 4 kernel_size = None if not conv2d else down_weight.size()[2:4] - # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) + # logger.info(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) # make original weight if not exist if lora_module_name not in merged_sd: weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype) - if device: - weight = weight.to(device) else: weight = merged_sd[lora_module_name] + if device: + weight = weight.to(device) # merge to weight if device: @@ -91,6 +314,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty # W <- W + U * D scale = alpha / network_dim + if lbw: + index = get_lbw_block_index(key, is_sdxl) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける + if device: # and isinstance(scale, torch.Tensor): scale = scale.to(device) @@ -107,13 +336,16 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) weight = weight + ratio * conved * scale - merged_sd[lora_module_name] = weight + merged_sd[lora_module_name] = weight.to("cpu") # extract from merged weights - print("extract new lora...") + logger.info("extract new lora...") merged_lora_sd = {} with torch.no_grad(): for lora_module_name, mat in tqdm(list(merged_sd.items())): + if device: + mat = mat.to(device) + conv2d = len(mat.size()) == 4 kernel_size = None if not conv2d else mat.size()[2:4] conv2d_3x3 = conv2d and kernel_size != (1, 1) @@ -152,7 +384,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous() merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous() - merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank) + merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank, device="cpu") # build minimum metadata dims = f"{new_rank}" @@ -167,7 +399,15 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + assert len(args.models) == len( + args.ratios + ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + if args.lbws: + assert len(args.models) == len( + args.lbws + ), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" + else: + args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく def str_to_dtype(p): if p == "float": @@ -185,10 +425,16 @@ def merge(args): new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank state_dict, metadata, v2, base_model = merge_lora_models( - args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype + args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype ) - print(f"calculating hashes and creating metadata...") + # cast to save_dtype before calculating hashes + for key in list(state_dict.keys()): + value = state_dict[key] + if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype: + state_dict[key] = value.to(save_dtype) + + logger.info(f"calculating hashes and creating 
metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -203,13 +449,13 @@ def merge(args): ) if v2: # TODO read sai modelspec - print( + logger.warning( "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" ) metadata.update(sai_metadata) - print(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, metadata) + logger.info(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, metadata) def setup_parser() -> argparse.ArgumentParser: @@ -229,12 +475,19 @@ def setup_parser() -> argparse.ArgumentParser: help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", ) parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + "--save_to", + type=str, + default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", ) parser.add_argument( - "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + "--models", + type=str, + nargs="*", + help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors", ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率") parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") parser.add_argument( "--new_conv_rank", @@ -242,7 +495,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ", ) - parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) parser.add_argument( "--no_metadata", action="store_true", diff --git a/requirements.txt b/requirements.txt index c27131cd..977c5cd9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,20 +1,24 @@ -accelerate==0.23.0 -transformers==4.30.2 -diffusers[torch]==0.21.2 +accelerate==0.25.0 +transformers==4.36.2 +diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 -opencv-python==4.7.0.68 -einops==0.6.0 +opencv-python==4.8.1.78 +einops==0.7.0 pytorch-lightning==1.9.0 -# bitsandbytes==0.39.1 -tensorboard==2.10.1 -safetensors==0.3.1 +bitsandbytes==0.43.0 +prodigyopt==1.0 +lion-pytorch==0.0.6 +tensorboard +safetensors==0.4.2 # gradio==3.16.2 altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 -huggingface-hub==0.15.1 +huggingface-hub==0.20.1 +# for Image utils +imagesize==1.4.1 # for BLIP captioning # requests==2.28.2 # timm==0.6.12 @@ -22,12 +26,17 @@ huggingface-hub==0.15.1 # for WD14 captioning (tensorflow) # tensorflow==2.10.1 # for WD14 captioning (onnx) -# onnx==1.14.1 -# onnxruntime-gpu==1.16.0 -# onnxruntime==1.16.0 +# onnx==1.15.0 +# onnxruntime-gpu==1.17.1 +# onnxruntime==1.17.1 +# for cuda 12.1(default 11.8) +# onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ + # this is for onnx: # protobuf==3.20.3 # open clip for SDXL 
-open-clip-torch==2.20.0 +# open-clip-torch==2.20.0 +# For logging +rich==13.7.0 # for kohya_ss library -e . diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index c31ae007..d52f85a8 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -16,17 +16,11 @@ import re import diffusers import numpy as np + import torch +from library.device_utils import init_ipex, clean_memory, get_preferred_device +init_ipex() -try: - import intel_extension_for_pytorch as ipex - - if torch.xpu.is_available(): - from library.ipex import ipex_init - - ipex_init() -except Exception: - pass import torchvision from diffusers import ( AutoencoderKL, @@ -57,9 +51,16 @@ import library.train_util as train_util import library.sdxl_model_util as sdxl_model_util import library.sdxl_train_util as sdxl_train_util from networks.lora import LoRANetwork -from library.sdxl_original_unet import SdxlUNet2DConditionModel +from library.sdxl_original_unet import InferSdxlUNet2DConditionModel from library.original_unet import FlashAttentionFunction from networks.control_net_lllite import ControlNetLLLite +from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # scheduler: SCHEDULER_LINEAR_START = 0.00085 @@ -81,12 +82,12 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: - print("Enable memory efficient attention for U-Net") + logger.info("Enable memory efficient attention for U-Net") # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い unet.set_use_memory_efficient_attention(False, True) elif xformers: - print("Enable xformers for U-Net") + logger.info("Enable xformers for U-Net") try: import xformers.ops except ImportError: @@ -94,7 +95,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio unet.set_use_memory_efficient_attention(True, False) elif sdpa: - print("Enable SDPA for U-Net") + logger.info("Enable SDPA for U-Net") unet.set_use_memory_efficient_attention(False, False) unet.set_use_sdpa(True) @@ -111,7 +112,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform def replace_vae_attn_to_memory_efficient(): - print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction def forward_flash_attn(self, hidden_states, **kwargs): @@ -167,7 +168,7 @@ def replace_vae_attn_to_memory_efficient(): def replace_vae_attn_to_xformers(): - print("VAE: Attention.forward has been replaced to xformers") + logger.info("VAE: Attention.forward has been replaced to xformers") import xformers.ops def forward_xformers(self, hidden_states, **kwargs): @@ -223,7 +224,7 @@ def replace_vae_attn_to_xformers(): def replace_vae_attn_to_sdpa(): - print("VAE: Attention.forward has been replaced to sdpa") + logger.info("VAE: Attention.forward has been replaced to sdpa") def forward_sdpa(self, hidden_states, **kwargs): residual = hidden_states @@ -290,7 +291,7 @@ class PipelineLike: vae: AutoencoderKL, text_encoders: List[CLIPTextModel], tokenizers: List[CLIPTokenizer], - unet: SdxlUNet2DConditionModel, + unet: InferSdxlUNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], clip_skip: int, ): @@ -328,7 +329,7 @@ class 
PipelineLike:
         self.vae = vae
         self.text_encoders = text_encoders
         self.tokenizers = tokenizers
-        self.unet: SdxlUNet2DConditionModel = unet
+        self.unet: InferSdxlUNet2DConditionModel = unet
         self.scheduler = scheduler
         self.safety_checker = None

@@ -345,6 +346,8 @@ class PipelineLike:
         self.control_nets: List[ControlNetLLLite] = []
         self.control_net_enabled = True  # control_netsが空ならTrueでもFalseでもControlNetは動作しない

+        self.gradual_latent: GradualLatent = None
+
     # Textual Inversion
     def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids):
         self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids

@@ -357,7 +360,7 @@ class PipelineLike:
         token_replacements = self.token_replacements_list[tokenizer_index]

         def replace_tokens(tokens):
-            # print("replace_tokens", tokens, "=>", token_replacements)
+            # logger.info("replace_tokens", tokens, "=>", token_replacements)
             if isinstance(tokens, torch.Tensor):
                 tokens = tokens.tolist()

@@ -375,6 +378,14 @@ class PipelineLike:
     def set_control_nets(self, ctrl_nets):
         self.control_nets = ctrl_nets

+    def set_gradual_latent(self, gradual_latent):
+        if gradual_latent is None:
+            logger.info("gradual_latent is disabled")
+            self.gradual_latent = None
+        else:
+            logger.info(f"gradual_latent is enabled: {gradual_latent}")
+            self.gradual_latent = gradual_latent  # (ds_ratio, start_timesteps, every_n_steps, ratio_step)
+
     @torch.no_grad()
     def __call__(
         self,
@@ -449,7 +460,7 @@ class PipelineLike:
         do_classifier_free_guidance = guidance_scale > 1.0

         if not do_classifier_free_guidance and negative_scale is not None:
-            print(f"negative_scale is ignored if guidance scalle <= 1.0")
+            logger.info(f"negative_scale is ignored if guidance scale <= 1.0")
             negative_scale = None

         # get unconditional embeddings for classifier free guidance
@@ -504,7 +515,8 @@ class PipelineLike:
             uncond_embeddings = tes_uncond_embs[0]
             for i in range(1, len(tes_text_embs)):
                 text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2)  # n,77,2048
-                uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2)  # n,77,2048
+                if do_classifier_free_guidance:
+                    uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2)  # n,77,2048

         if do_classifier_free_guidance:
             if negative_scale is None:
@@ -552,7 +564,7 @@ class PipelineLike:
                 text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts]  # last subprompt

         if init_image is not None and self.clip_vision_model is not None:
-            print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
+            logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
             vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device)
             pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype)

@@ -567,9 +579,11 @@ class PipelineLike:

             text_pool = clip_vision_embeddings  # replace: same as ComfyUI (?) 
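# A minimal sketch, not this pipeline's actual code, of the classifier-free
# guidance batching that the do_classifier_free_guidance guards in this hunk
# protect; `unet`, `cond`, and `uncond` are illustrative stand-ins. With CFG
# enabled, the conditional and unconditional branches run as one doubled batch,
# so the unconditional embeddings must exist; with guidance_scale <= 1.0 they
# are never built, which is why the concatenations are now guarded.
#
#     if guidance_scale > 1.0:
#         latent_in = latents.repeat((2, 1, 1, 1))
#         context = torch.cat([uncond, cond])  # batch = [uncond | cond]
#         noise_uncond, noise_cond = unet(latent_in, t, context).chunk(2)
#         noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
#     else:
#         noise_pred = unet(latents, t, cond)  # conditional batch only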
c_vector = torch.cat([text_pool, c_vector], dim=1) - uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) - - vector_embeddings = torch.cat([uc_vector, c_vector]) + if do_classifier_free_guidance: + uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) + vector_embeddings = torch.cat([uc_vector, c_vector]) + else: + vector_embeddings = c_vector # set timesteps self.scheduler.set_timesteps(num_inference_steps, self.device) @@ -642,8 +656,7 @@ class PipelineLike: init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist init_latents = init_latent_dist.sample(generator=generator) else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() init_latents = [] for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): init_latent_dist = self.vae.encode( @@ -706,7 +719,116 @@ class PipelineLike: control_net.set_cond_image(None) each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) + + # # first, we downscale the latents to the half of the size + # # 最初に1/2に縮小する + # height, width = latents.shape[-2:] + # # latents = torch.nn.functional.interpolate(latents.float(), scale_factor=0.5, mode="bicubic", align_corners=False).to( + # # latents.dtype + # # ) + # latents = latents[:, :, ::2, ::2] + # current_scale = 0.5 + + # # how much to increase the scale at each step: .125 seems to work well (because it's 1/8?) + # # 各ステップに拡大率をどのくらい増やすか:.125がよさそう(たぶん1/8なので) + # scale_step = 0.125 + + # # timesteps at which to start increasing the scale: 1000 seems to be enough + # # 拡大を開始するtimesteps: 1000で十分そうである + # start_timesteps = 1000 + + # # how many steps to wait before increasing the scale again + # # small values leads to blurry images (because the latents are blurry after the upscale, so some denoising might be needed) + # # large values leads to flat images + + # # 何ステップごとに拡大するか + # # 小さいとボケる(拡大後のlatentsはボケた感じになるので、そこから数stepのdenoiseが必要と思われる) + # # 大きすぎると細部が書き込まれずのっぺりした感じになる + # every_n_steps = 5 + + # scale_step = input("scale step:") + # scale_step = float(scale_step) + # start_timesteps = input("start timesteps:") + # start_timesteps = int(start_timesteps) + # every_n_steps = input("every n steps:") + # every_n_steps = int(every_n_steps) + + # # for i, t in enumerate(tqdm(timesteps)): + # i = 0 + # last_step = 0 + # while i < len(timesteps): + # t = timesteps[i] + # print(f"[{i}] t={t}") + + # print(i, t, current_scale, latents.shape) + # if t < start_timesteps and current_scale < 1.0 and i % every_n_steps == 0: + # if i == last_step: + # pass + # else: + # print("upscale") + # current_scale = min(current_scale + scale_step, 1.0) + + # h = int(height * current_scale) // 8 * 8 + # w = int(width * current_scale) // 8 * 8 + + # latents = torch.nn.functional.interpolate(latents.float(), size=(h, w), mode="bicubic", align_corners=False).to( + # latents.dtype + # ) + # last_step = i + # i = max(0, i - every_n_steps + 1) + + # diff = timesteps[i] - timesteps[last_step] + # # resized_init_noise = torch.nn.functional.interpolate( + # # init_noise.float(), size=(h, w), mode="bicubic", align_corners=False + # # ).to(latents.dtype) + # # latents = self.scheduler.add_noise(latents, resized_init_noise, diff) + # latents = self.scheduler.add_noise(latents, torch.randn_like(latents), diff * 4) + # # latents += torch.randn_like(latents) / 100 * diff + # continue + + enable_gradual_latent = False + if self.gradual_latent: + if not hasattr(self.scheduler, "set_gradual_latent_params"): + logger.info("gradual_latent is not supported for this 
scheduler. Ignoring.") + logger.info(f'{self.scheduler.__class__.__name__}') + else: + enable_gradual_latent = True + step_elapsed = 1000 + current_ratio = self.gradual_latent.ratio + + # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする + height, width = latents.shape[-2:] + org_dtype = latents.dtype + if org_dtype == torch.bfloat16: + latents = latents.float() + latents = torch.nn.functional.interpolate( + latents, scale_factor=current_ratio, mode="bicubic", align_corners=False + ).to(org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if self.gradual_latent.gaussian_blur_ksize: + latents = self.gradual_latent.apply_unshark_mask(latents) + for i, t in enumerate(tqdm(timesteps)): + resized_size = None + if enable_gradual_latent: + # gradually upscale the latents / latentsを徐々にアップスケールする + if ( + t < self.gradual_latent.start_timesteps + and current_ratio < 1.0 + and step_elapsed >= self.gradual_latent.every_n_steps + ): + current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0) + # make divisible by 8 because size of latents must be divisible at bottom of UNet + h = int(height * current_ratio) // 8 * 8 + w = int(width * current_ratio) // 8 * 8 + resized_size = (h, w) + self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent) + step_elapsed = 0 + else: + self.scheduler.set_gradual_latent_params(None, None) + step_elapsed += 1 + # expand the latents if we are doing classifier free guidance latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -717,7 +839,7 @@ class PipelineLike: if not enabled or ratio >= 1.0: continue if ratio < i / len(timesteps): - print(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") control_net.set_cond_image(None) each_control_net_enabled[j] = False @@ -775,6 +897,8 @@ class PipelineLike: if is_cancelled_callback is not None and is_cancelled_callback(): return None + i += 1 + if return_latents: return latents @@ -782,8 +906,7 @@ class PipelineLike: if vae_batch_size >= batch_size: image = self.vae.decode(latents.to(self.vae.dtype)).sample else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() images = [] for i in tqdm(range(0, batch_size, vae_batch_size)): images.append( @@ -798,8 +921,7 @@ class PipelineLike: # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() if output_type == "pil": # image = self.numpy_to_pil(image) @@ -937,7 +1059,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L if word.strip() == "BREAK": # pad until next multiple of tokenizer's max token length pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) - print(f"BREAK pad_len: {pad_len}") + logger.info(f"BREAK pad_len: {pad_len}") for i in range(pad_len): # v2のときEOSをつけるべきかどうかわからないぜ # if i == 0: @@ -967,7 +1089,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L tokens.append(text_token) weights.append(text_weight) if truncated: - print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + logger.warning("warning: Prompt was truncated. 
Try to shorten the prompt or increase max_embeddings_multiples") return tokens, weights @@ -1240,7 +1362,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): elif len(count_range) == 2: count_range = [int(count_range[0]), int(count_range[1])] else: - print(f"invalid count range: {count_range}") + logger.warning(f"invalid count range: {count_range}") count_range = [1, 1] if count_range[0] > count_range[1]: count_range = [count_range[1], count_range[0]] @@ -1308,9 +1430,8 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): # endregion - # def load_clip_l14_336(dtype): -# print(f"loading CLIP: {CLIP_ID_L14_336}") +# logger.info(f"loading CLIP: {CLIP_ID_L14_336}") # text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) # return text_encoder @@ -1325,6 +1446,7 @@ class BatchDataBase(NamedTuple): mask_image: Any clip_prompt: str guide_image: Any + raw_prompt: str class BatchDataExt(NamedTuple): @@ -1371,6 +1493,7 @@ def main(args): (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype ) + unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) # xformers、Hypernetwork対応 if not args.diffusers_xformers: @@ -1379,7 +1502,7 @@ def main(args): replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) # tokenizerを読み込む - print("loading tokenizer") + logger.info("loading tokenizer") tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) # schedulerを用意する @@ -1407,7 +1530,7 @@ def main(args): scheduler_module = diffusers.schedulers.scheduling_euler_discrete has_clip_sample = False elif args.sampler == "euler_a" or args.sampler == "k_euler_a": - scheduler_cls = EulerAncestralDiscreteScheduler + scheduler_cls = EulerAncestralDiscreteSchedulerGL scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete has_clip_sample = False elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": @@ -1453,7 +1576,7 @@ def main(args): self.sampler_noises = noises def randn(self, shape, device=None, dtype=None, layout=None, generator=None): - # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + # logger.info("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): noise = self.sampler_noises[self.sampler_noise_index] if shape != noise.shape: @@ -1462,7 +1585,7 @@ def main(args): noise = None if noise == None: - print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}") noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) self.sampler_noise_index += 1 @@ -1494,11 +1617,11 @@ def main(args): # ↓以下は結局PipeでFalseに設定されるので意味がなかった # # clip_sample=Trueにする # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # print("set clip_sample to True") + # logger.info("set clip_sample to True") # scheduler.config.clip_sample = True # deviceを決定する - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない + device = get_preferred_device() # custom pipelineをコピったやつを生成する if args.vae_slices: @@ -1523,13 +1646,17 @@ def main(args): vae_dtype = dtype if args.no_half_vae: - print("set vae_dtype to float32") + logger.info("set vae_dtype to float32") vae_dtype = torch.float32 vae.to(vae_dtype).to(device) + vae.eval() 
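For reference, the upscaling cadence that the Gradual Latent loop above implements can be reproduced in isolation. A rough sketch, assuming the defaults added by this patch (`ratio=0.5`, `ratio_step=0.125`, `every_n_steps=3`) and a start timestep of 650 (the fallback used elsewhere in this patch); the timestep list is illustrative:

```python
def gradual_latent_sizes(height, width, timesteps, start_ts=650, ratio=0.5,
                         ratio_step=0.125, every_n_steps=3):
    """Return (timestep, h, w) at each point the latents would be upscaled."""
    sizes = []
    step_elapsed = 1000  # large initial value so the first eligible step triggers
    for t in timesteps:
        if t < start_ts and ratio < 1.0 and step_elapsed >= every_n_steps:
            ratio = min(ratio + ratio_step, 1.0)
            # keep sizes divisible by 8, as required at the bottom of the U-Net
            h, w = int(height * ratio) // 8 * 8, int(width * ratio) // 8 * 8
            sizes.append((t, h, w))
            step_elapsed = 0
        step_elapsed += 1
    return sizes

# latent-space size for a 1024x1024 image is 128x128; 40 evenly spaced timesteps
print(gradual_latent_sizes(128, 128, list(range(1000, 0, -25))))
# [(625, 80, 80), (550, 96, 96), (475, 112, 112), (400, 128, 128)]
```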
text_encoder1.to(dtype).to(device) text_encoder2.to(dtype).to(device) unet.to(dtype).to(device) + text_encoder1.eval() + text_encoder2.eval() + unet.eval() # networkを組み込む if args.network_module: @@ -1544,10 +1671,10 @@ def main(args): network_merge = args.network_merge_n_models else: network_merge = 0 - print(f"network_merge: {network_merge}") + logger.info(f"network_merge: {network_merge}") for i, network_module in enumerate(args.network_module): - print("import network module:", network_module) + logger.info(f"import network module: {network_module}") imported_module = importlib.import_module(network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] @@ -1565,7 +1692,7 @@ def main(args): raise ValueError("No weight. Weight is required.") network_weight = args.network_weights[i] - print("load network weights from:", network_weight) + logger.info(f"load network weights from: {network_weight}") if model_util.is_safetensors(network_weight) and args.network_show_meta: from safetensors.torch import safe_open @@ -1573,7 +1700,7 @@ def main(args): with safe_open(network_weight, framework="pt") as f: metadata = f.metadata() if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") + logger.info(f"metadata for: {network_weight}: {metadata}") network, weights_sd = imported_module.create_network_from_weights( network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs @@ -1583,20 +1710,20 @@ def main(args): mergeable = network.is_mergeable() if network_merge and not mergeable: - print("network is not mergiable. ignore merge option.") + logger.warning("network is not mergiable. ignore merge option.") if not mergeable or i >= network_merge: # not merging network.apply_to([text_encoder1, text_encoder2], unet) info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい - print(f"weights are loaded: {info}") + logger.info(f"weights are loaded: {info}") if args.opt_channels_last: network.to(memory_format=torch.channels_last) network.to(dtype).to(device) if network_pre_calc: - print("backup original weights") + logger.info("backup original weights") network.backup_weights() networks.append(network) @@ -1610,7 +1737,7 @@ def main(args): # upscalerの指定があれば取得する upscaler = None if args.highres_fix_upscaler: - print("import upscaler module:", args.highres_fix_upscaler) + logger.info(f"import upscaler module: {args.highres_fix_upscaler}") imported_module = importlib.import_module(args.highres_fix_upscaler) us_kwargs = {} @@ -1619,7 +1746,7 @@ def main(args): key, value = net_arg.split("=") us_kwargs[key] = value - print("create upscaler") + logger.info("create upscaler") upscaler = imported_module.create_upscaler(**us_kwargs) upscaler.to(dtype).to(device) @@ -1636,7 +1763,7 @@ def main(args): # control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) if args.control_net_lllite_models: for i, model_file in enumerate(args.control_net_lllite_models): - print(f"loading ControlNet-LLLite: {model_file}") + logger.info(f"loading ControlNet-LLLite: {model_file}") from safetensors.torch import load_file @@ -1667,7 +1794,7 @@ def main(args): control_nets.append((control_net, ratio)) if args.opt_channels_last: - print(f"set optimizing: channels last") + logger.info(f"set optimizing: channels last") text_encoder1.to(memory_format=torch.channels_last) text_encoder2.to(memory_format=torch.channels_last) vae.to(memory_format=torch.channels_last) @@ -1691,11 
+1818,38 @@ def main(args): args.clip_skip, ) pipe.set_control_nets(control_nets) - print("pipeline is ready.") + logger.info("pipeline is ready.") if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() + # Deep Shrink + if args.ds_depth_1 is not None: + unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + + # Gradual Latent + if args.gradual_latent_timesteps is not None: + if args.gradual_latent_unsharp_params: + us_params = args.gradual_latent_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] + us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + + gradual_latent = GradualLatent( + args.gradual_latent_ratio, + args.gradual_latent_timesteps, + args.gradual_latent_every_n_steps, + args.gradual_latent_ratio_step, + args.gradual_latent_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + # Textual Inversionを処理する if args.textual_inversion_embeddings: token_ids_embeds1 = [] @@ -1729,7 +1883,7 @@ def main(args): token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings) token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings) - print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") + logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") assert ( min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 ), f"token ids1 is not ordered" @@ -1759,10 +1913,10 @@ def main(args): # promptを取得する if args.from_file is not None: - print(f"reading prompts from {args.from_file}") + logger.info(f"reading prompts from {args.from_file}") with open(args.from_file, "r", encoding="utf-8") as f: prompt_list = f.read().splitlines() - prompt_list = [d for d in prompt_list if len(d.strip()) > 0] + prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] elif args.prompt is not None: prompt_list = [args.prompt] else: @@ -1788,7 +1942,7 @@ def main(args): for p in paths: image = Image.open(p) if image.mode != "RGB": - print(f"convert image to RGB from {image.mode}: {p}") + logger.info(f"convert image to RGB from {image.mode}: {p}") image = image.convert("RGB") images.append(image) @@ -1804,14 +1958,14 @@ def main(args): return resized if args.image_path is not None: - print(f"load image for img2img: {args.image_path}") + logger.info(f"load image for img2img: {args.image_path}") init_images = load_images(args.image_path) assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - print(f"loaded {len(init_images)} images for img2img") + logger.info(f"loaded {len(init_images)} images for img2img") # CLIP Vision if args.clip_vision_strength is not None: - print(f"load CLIP Vision model: {CLIP_VISION_MODEL}") + logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}") vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) vision_model.to(device, dtype) processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) @@ -1819,22 +1973,22 @@ def main(args): pipe.clip_vision_model = vision_model pipe.clip_vision_processor = processor pipe.clip_vision_strength = args.clip_vision_strength - print(f"CLIP Vision model loaded.") + logger.info(f"CLIP Vision model loaded.") else: 
init_images = None if args.mask_path is not None: - print(f"load mask for inpainting: {args.mask_path}") + logger.info(f"load mask for inpainting: {args.mask_path}") mask_images = load_images(args.mask_path) assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - print(f"loaded {len(mask_images)} mask images for inpainting") + logger.info(f"loaded {len(mask_images)} mask images for inpainting") else: mask_images = None # promptがないとき、画像のPngInfoから取得する if init_images is not None and len(prompt_list) == 0 and not args.interactive: - print("get prompts from images' metadata") + logger.info("get prompts from images' metadata") for img in init_images: if "prompt" in img.text: prompt = img.text["prompt"] @@ -1863,17 +2017,17 @@ def main(args): h = int(h * args.highres_fix_scale + 0.5) if init_images is not None: - print(f"resize img2img source images to {w}*{h}") + logger.info(f"resize img2img source images to {w}*{h}") init_images = resize_images(init_images, (w, h)) if mask_images is not None: - print(f"resize img2img mask images to {w}*{h}") + logger.info(f"resize img2img mask images to {w}*{h}") mask_images = resize_images(mask_images, (w, h)) regional_network = False if networks and mask_images: # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 regional_network = True - print("use mask as region") + logger.info("use mask as region") size = None for i, network in enumerate(networks): @@ -1898,14 +2052,16 @@ def main(args): prev_image = None # for VGG16 guided if args.guide_image_path is not None: - print(f"load image for ControlNet guidance: {args.guide_image_path}") + logger.info(f"load image for ControlNet guidance: {args.guide_image_path}") guide_images = [] for p in args.guide_image_path: guide_images.extend(load_images(p)) - print(f"loaded {len(guide_images)} guide images for guidance") + logger.info(f"loaded {len(guide_images)} guide images for guidance") if len(guide_images) == 0: - print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") + logger.warning( + f"No guide image, use previous generated image. 
/ ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" + ) guide_images = None else: guide_images = None @@ -1931,7 +2087,7 @@ def main(args): max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples for gen_iter in range(args.n_iter): - print(f"iteration {gen_iter+1}/{args.n_iter}") + logger.info(f"iteration {gen_iter+1}/{args.n_iter}") iter_seed = random.randint(0, 0x7FFFFFFF) # バッチ処理の関数 @@ -1943,7 +2099,7 @@ def main(args): # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling - print("process 1st stage") + logger.info("process 1st stage") batch_1st = [] for _, base, ext in batch: @@ -1988,7 +2144,7 @@ def main(args): images_1st = process_batch(batch_1st, True, True) # 2nd stageのバッチを作成して以下処理する - print("process 2nd stage") + logger.info("process 2nd stage") width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height if upscaler: @@ -2034,7 +2190,7 @@ def main(args): # このバッチの情報を取り出す ( return_latents, - (step_first, _, _, _, init_image, mask_image, _, guide_image), + (step_first, _, _, _, init_image, mask_image, _, guide_image, _), ( width, height, @@ -2056,6 +2212,7 @@ def main(args): prompts = [] negative_prompts = [] + raw_prompts = [] start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) noises = [ torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) @@ -2086,11 +2243,16 @@ def main(args): all_images_are_same = True all_masks_are_same = True all_guide_images_are_same = True - for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): + for i, ( + _, + (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + _, + ) in enumerate(batch): prompts.append(prompt) negative_prompts.append(negative_prompt) seeds.append(seed) clip_prompts.append(clip_prompt) + raw_prompts.append(raw_prompt) if init_image is not None: init_images.append(init_image) @@ -2154,7 +2316,7 @@ def main(args): n.restore_weights() for n in networks: n.pre_calculation() - print("pre-calculation... done") + logger.info("pre-calculation... 
done") images = pipe( prompts, @@ -2188,8 +2350,8 @@ def main(args): # save image highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts) + for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) ): if highres_fix: seed -= 1 # record original seed @@ -2205,6 +2367,8 @@ def main(args): metadata.add_text("negative-scale", str(negative_scale)) if clip_prompt is not None: metadata.add_text("clip-prompt", clip_prompt) + if raw_prompt is not None: + metadata.add_text("raw-prompt", raw_prompt) metadata.add_text("original-height", str(original_height)) metadata.add_text("original-width", str(original_width)) metadata.add_text("original-height-negative", str(original_height_negative)) @@ -2233,7 +2397,9 @@ def main(args): cv2.waitKey() cv2.destroyAllWindows() except ImportError: - print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + logger.error( + "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" + ) return images @@ -2246,7 +2412,8 @@ def main(args): # interactive valid = False while not valid: - print("\nType prompt:") + logger.info("") + logger.info("Type prompt:") try: raw_prompt = input() except EOFError: @@ -2286,76 +2453,91 @@ def main(args): clip_prompt = None network_muls = None + # Deep Shrink + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio + + # Gradual Latent + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params + prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") for parg in prompt_args[1:]: try: m = re.match(r"w (\d+)", parg, re.IGNORECASE) if m: width = int(m.group(1)) - print(f"width: {width}") + logger.info(f"width: {width}") continue m = re.match(r"h (\d+)", parg, re.IGNORECASE) if m: height = int(m.group(1)) - print(f"height: {height}") + logger.info(f"height: {height}") continue m = re.match(r"ow (\d+)", parg, re.IGNORECASE) if m: original_width = int(m.group(1)) - print(f"original width: {original_width}") + logger.info(f"original width: {original_width}") continue m = re.match(r"oh (\d+)", parg, re.IGNORECASE) if m: original_height = int(m.group(1)) - print(f"original height: {original_height}") + logger.info(f"original height: {original_height}") continue m = re.match(r"nw (\d+)", parg, re.IGNORECASE) if m: original_width_negative = int(m.group(1)) - print(f"original width negative: {original_width_negative}") + logger.info(f"original width negative: {original_width_negative}") continue m = re.match(r"nh (\d+)", parg, re.IGNORECASE) if m: original_height_negative = int(m.group(1)) - print(f"original height negative: {original_height_negative}") + logger.info(f"original height negative: {original_height_negative}") continue m = re.match(r"ct (\d+)", parg, re.IGNORECASE) if m: 
crop_top = int(m.group(1)) - print(f"crop top: {crop_top}") + logger.info(f"crop top: {crop_top}") continue m = re.match(r"cl (\d+)", parg, re.IGNORECASE) if m: crop_left = int(m.group(1)) - print(f"crop left: {crop_left}") + logger.info(f"crop left: {crop_left}") continue m = re.match(r"s (\d+)", parg, re.IGNORECASE) if m: # steps steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") + logger.info(f"steps: {steps}") continue m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) if m: # seed seeds = [int(d) for d in m.group(1).split(",")] - print(f"seeds: {seeds}") + logger.info(f"seeds: {seeds}") continue m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) if m: # scale scale = float(m.group(1)) - print(f"scale: {scale}") + logger.info(f"scale: {scale}") continue m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) @@ -2364,25 +2546,25 @@ def main(args): negative_scale = None else: negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") + logger.info(f"negative scale: {negative_scale}") continue m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) if m: # strength strength = float(m.group(1)) - print(f"strength: {strength}") + logger.info(f"strength: {strength}") continue m = re.match(r"n (.+)", parg, re.IGNORECASE) if m: # negative prompt negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") + logger.info(f"negative prompt: {negative_prompt}") continue m = re.match(r"c (.+)", parg, re.IGNORECASE) if m: # clip prompt clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") + logger.info(f"clip prompt: {clip_prompt}") continue m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) @@ -2390,12 +2572,161 @@ def main(args): network_muls = [float(v) for v in m.group(1).split(",")] while len(network_muls) < len(networks): network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") + logger.info(f"network mul: {network_muls}") + continue + + # Deep Shrink + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + logger.info(f"deep shrink depth 1: {ds_depth_1}") + continue + + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue + + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink depth 2: {ds_depth_2}") + continue + + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue + + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink ratio: {ds_ratio}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + 
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1  # -1 means override
+                        logger.info(f"gradual latent ratio: {gl_ratio}")
+                        continue
+
+                    m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
+                    if m:  # gradual latent every n steps
+                        gl_every_n_steps = int(m.group(1))
+                        gl_timesteps = gl_timesteps if gl_timesteps is not None else -1  # -1 means override
+                        logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
+                        continue
+
+                    m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
+                    if m:  # gradual latent ratio step
+                        gl_ratio_step = float(m.group(1))
+                        gl_timesteps = gl_timesteps if gl_timesteps is not None else -1  # -1 means override
+                        logger.info(f"gradual latent ratio step: {gl_ratio_step}")
+                        continue
+
+                    m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
+                    if m:  # gradual latent s noise
+                        gl_s_noise = float(m.group(1))
+                        gl_timesteps = gl_timesteps if gl_timesteps is not None else -1  # -1 means override
+                        logger.info(f"gradual latent s noise: {gl_s_noise}")
+                        continue
+
+                    m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
+                    if m:  # gradual latent unsharp params
+                        gl_unsharp_params = m.group(1)
+                        gl_timesteps = gl_timesteps if gl_timesteps is not None else -1  # -1 means override
+                        logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
                         continue
                 except ValueError as ex:
-                    print(f"Exception in parsing / 解析エラー: {parg}")
-                    print(ex)
+                    logger.error(f"Exception in parsing / 解析エラー: {parg}")
+                    logger.error(f"{ex}")
+
+            # override Deep Shrink
+            if ds_depth_1 is not None:
+                if ds_depth_1 < 0:
+                    ds_depth_1 = args.ds_depth_1 or 3
+                unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
+
+            # override Gradual Latent
+            if gl_timesteps is not None:
+                if gl_timesteps < 0:
+                    gl_timesteps = args.gradual_latent_timesteps or 650
+                if gl_unsharp_params is not None:
+                    unsharp_params = gl_unsharp_params.split(",")
+                    us_ksize, us_sigma, 
us_strength = [float(v) for v in unsharp_params[:3]] + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + gradual_latent = GradualLatent( + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) # prepare seed if seeds is not None: # given in prompt @@ -2407,7 +2738,7 @@ def main(args): if len(predefined_seeds) > 0: seed = predefined_seeds.pop(0) else: - print("predefined seeds are exhausted") + logger.error("predefined seeds are exhausted") seed = None elif args.iter_same_seed: seeds = iter_seed @@ -2417,7 +2748,7 @@ def main(args): if seed is None: seed = random.randint(0, 0x7FFFFFFF) if args.interactive: - print(f"seed: {seed}") + logger.info(f"seed: {seed}") # prepare init image, guide image and mask init_image = mask_image = guide_image = None @@ -2433,7 +2764,7 @@ def main(args): width = width - width % 32 height = height - height % 32 if width != init_image.size[0] or height != init_image.size[1]: - print( + logger.warning( f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" ) @@ -2458,7 +2789,9 @@ def main(args): b1 = BatchData( False, - BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), + BatchDataBase( + global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + ), BatchDataExt( width, height, @@ -2493,18 +2826,25 @@ def main(args): process_batch(batch_data, highres_fix) batch_data.clear() - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) + parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") parser.add_argument( - "--from_file", type=str, default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む" + "--from_file", + type=str, + default=None, + help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", ) parser.add_argument( - "--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)" + "--interactive", + action="store_true", + help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)", ) parser.add_argument( "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" @@ -2516,7 +2856,9 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") - parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする") + parser.add_argument( + "--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする" + ) parser.add_argument( "--use_original_file_name", action="store_true", @@ -2527,10 +2869,16 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") parser.add_argument("--W", type=int, 
default=None, help="image width, in pixel space / 生成画像幅") parser.add_argument( - "--original_height", type=int, default=None, help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値" + "--original_height", + type=int, + default=None, + help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値", ) parser.add_argument( - "--original_width", type=int, default=None, help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値" + "--original_width", + type=int, + default=None, + help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値", ) parser.add_argument( "--original_height_negative", @@ -2544,8 +2892,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値", ) - parser.add_argument("--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値") - parser.add_argument("--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値") + parser.add_argument( + "--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値" + ) + parser.add_argument( + "--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値" + ) parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") parser.add_argument( "--vae_batch_size", @@ -2559,7 +2911,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", ) - parser.add_argument("--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない") + parser.add_argument( + "--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない" + ) parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") parser.add_argument( "--sampler", @@ -2591,9 +2945,14 @@ def setup_parser() -> argparse.ArgumentParser: default=7.5, help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", ) - parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ") parser.add_argument( - "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + "--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", ) parser.add_argument( "--tokenizer_cache_dir", @@ -2624,25 +2983,46 @@ def setup_parser() -> argparse.ArgumentParser: help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", ) parser.add_argument( - "--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する" + "--opt_channels_last", + action="store_true", + help="set channels last option to model / モデルにchannels lastを指定し最適化する", ) parser.add_argument( - "--network_module", type=str, default=None, nargs="*", help="additional network module to use / 追加ネットワークを使う時そのモジュール名" + 
"--network_module", + type=str, + default=None, + nargs="*", + help="additional network module to use / 追加ネットワークを使う時そのモジュール名", ) parser.add_argument( "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" ) - parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率") parser.add_argument( - "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" + "--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率" ) - parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") parser.add_argument( - "--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする" + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", ) - parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument( - "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" + "--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する" + ) + parser.add_argument( + "--network_merge_n_models", + type=int, + default=None, + help="merge this number of networks / この数だけネットワークをマージする", + ) + parser.add_argument( + "--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする" + ) + parser.add_argument( + "--network_pre_calc", + action="store_true", + help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する", ) parser.add_argument( "--network_regional_mask_max_color_codes", @@ -2657,7 +3037,9 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", ) - parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") + parser.add_argument( + "--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う" + ) parser.add_argument( "--max_embeddings_multiples", type=int, @@ -2674,7 +3056,10 @@ def setup_parser() -> argparse.ArgumentParser: help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", ) parser.add_argument( - "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数" + "--highres_fix_steps", + type=int, + default=28, + help="1st stage steps for highres fix / highres fixの最初のステージのステップ数", ) parser.add_argument( "--highres_fix_strength", @@ -2683,7 +3068,9 @@ def setup_parser() -> argparse.ArgumentParser: help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", ) parser.add_argument( - "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する" + "--highres_fix_save_1st", + action="store_true", + help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する", ) parser.add_argument( "--highres_fix_latents_upscaling", @@ -2691,7 +3078,10 @@ def setup_parser() -> argparse.ArgumentParser: help="use latents upscaling for highres fix / highres fixでlatentで拡大する", ) parser.add_argument( - 
"--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名" + "--highres_fix_upscaler", + type=str, + default=None, + help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名", ) parser.add_argument( "--highres_fix_upscaler_args", @@ -2706,11 +3096,18 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" + "--negative_scale", + type=float, + default=None, + help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する", ) parser.add_argument( - "--control_net_lllite_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + "--control_net_lllite_models", + type=str, + default=None, + nargs="*", + help="ControlNet models to use / 使用するControlNetのモデル名", ) # parser.add_argument( # "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" @@ -2734,6 +3131,70 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する", ) + + # Deep Shrink + parser.add_argument( + "--ds_depth_1", + type=int, + default=None, + help="Enable Deep Shrink with this depth 1, valid values are 0 to 8 / Deep Shrinkをこのdepthで有効にする", + ) + parser.add_argument( + "--ds_timesteps_1", + type=int, + default=650, + help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps", + ) + parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2") + parser.add_argument( + "--ds_timesteps_2", + type=int, + default=650, + help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps", + ) + parser.add_argument( + "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" + ) + + # gradual latent + parser.add_argument( + "--gradual_latent_timesteps", + type=int, + default=None, + help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する", + ) + parser.add_argument( + "--gradual_latent_ratio", + type=float, + default=0.5, + help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する", + ) + parser.add_argument( + "--gradual_latent_ratio_step", + type=float, + default=0.125, + help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか", + ) + parser.add_argument( + "--gradual_latent_every_n_steps", + type=int, + default=3, + help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", + ) + parser.add_argument( + "--gradual_latent_s_noise", + type=float, + default=1.0, + help="s_noise for Gradual Latent / Gradual Latentのs_noise", + ) + parser.add_argument( + "--gradual_latent_unsharp_params", + type=str, + default=None, + help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" + + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. 
`3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", + ) + # # parser.add_argument( # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # ) @@ -2745,4 +3206,5 @@ if __name__ == "__main__": parser = setup_parser() args = parser.parse_args() + setup_logging(args, reset=True) main(args) diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 45b9edd6..a1e93b7f 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -8,23 +8,28 @@ import os import random from einops import repeat import numpy as np + import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.device_utils import init_ipex, get_preferred_device + +init_ipex() + from tqdm import tqdm from transformers import CLIPTokenizer from diffusers import EulerDiscreteScheduler from PIL import Image -import open_clip + +# import open_clip from safetensors.torch import load_file from library import model_util, sdxl_model_util import networks.lora as lora +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # scheduler: このあたりの設定はSD1/2と同じでいいらしい # scheduler: The settings around here seem to be the same as SD1/2 @@ -87,7 +92,7 @@ if __name__ == "__main__": guidance_scale = 7 seed = None # 1 - DEVICE = "cuda" + DEVICE = get_preferred_device() DTYPE = torch.float16 # bfloat16 may work parser = argparse.ArgumentParser() @@ -142,7 +147,7 @@ if __name__ == "__main__": vae_dtype = DTYPE if DTYPE == torch.float16: - print("use float32 for vae") + logger.info("use float32 for vae") vae_dtype = torch.float32 vae.to(DEVICE, dtype=vae_dtype) vae.eval() @@ -153,12 +158,13 @@ if __name__ == "__main__": text_model2.eval() unet.set_use_memory_efficient_attention(True, False) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える vae.set_use_memory_efficient_attention_xformers(True) # Tokenizers tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name) - tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77) + # tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77) + tokenizer2 = CLIPTokenizer.from_pretrained(text_encoder_2_name) # LoRA for weights_file in args.lora_weights: @@ -189,9 +195,11 @@ if __name__ == "__main__": emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256) - # print("emb1", emb1.shape) + # logger.info("emb1", emb1.shape) c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE) - uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right + uc_vector = c_vector.clone().to( + DEVICE, dtype=DTYPE + ) # ちょっとここ正しいかどうかわからない I'm not sure if this is right # crossattn @@ -214,13 +222,22 @@ if __name__ == "__main__": # text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい # text encoder 2 - with torch.no_grad(): - tokens = tokenizer2(text2).to(DEVICE) + # tokens = tokenizer2(text2).to(DEVICE) + tokens = tokenizer2( + text, + truncation=True, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + 
return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(DEVICE) + with torch.no_grad(): enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True) text_embedding2_penu = enc_out["hidden_states"][-2] - # print("hidden_states2", text_embedding2_penu.shape) - text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion + # logger.info("hidden_states2", text_embedding2_penu.shape) + text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion # 連結して終了 concat and finish text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2) @@ -228,7 +245,7 @@ if __name__ == "__main__": # cond c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2) - # print(c_ctx.shape, c_ctx_p.shape, c_vector.shape) + # logger.info(c_ctx.shape, c_ctx_p.shape, c_vector.shape) c_vector = torch.cat([c_ctx_pool, c_vector], dim=1) # uncond @@ -325,4 +342,4 @@ if __name__ == "__main__": seed = int(seed) generate_image(prompt, prompt2, negative_prompt, seed) - print("Done!") + logger.info("Done!") diff --git a/sdxl_train.py b/sdxl_train.py index fd775624..46d7860b 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -1,7 +1,6 @@ # training with captions import argparse -import gc import math import os from multiprocessing import Value @@ -9,22 +8,26 @@ from typing import List import toml from tqdm import tqdm + import torch +from library.device_utils import init_ipex, clean_memory_on_device -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import sdxl_model_util +from library import deepspeed_utils, sdxl_model_util import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + import library.config_util as config_util import library.sdxl_train_util as sdxl_train_util from library.config_util import ( @@ -38,6 +41,7 @@ from library.custom_train_functions import ( scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, apply_debiased_estimation, + apply_masked_loss, ) from library.sdxl_original_unet import SdxlUNet2DConditionModel @@ -96,8 +100,12 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) - assert not args.weighted_captions, "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + assert ( + not args.weighted_captions + ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" assert ( not args.train_text_encoder or not args.cache_text_encoder_outputs ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" @@ -120,20 +128,20 @@ def train(args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = 
["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") user_config = { "datasets": [ { @@ -144,7 +152,7 @@ def train(args): ] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -174,7 +182,7 @@ def train(args): train_util.debug_dataset(train_dataset_group, True) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" ) return @@ -190,7 +198,7 @@ def train(args): ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -257,9 +265,7 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -352,8 +358,8 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, @@ -368,7 +374,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
+    )

     # データセット側にも学習ステップを送信
     train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -394,26 +402,40 @@ def train(args):
         text_encoder1.to(weight_dtype)
         text_encoder2.to(weight_dtype)

-    # acceleratorがなんかよろしくやってくれるらしい
-    if train_unet:
-        unet = accelerator.prepare(unet)
-        (unet,) = train_util.transform_models_if_DDP([unet])
+    # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
     if train_text_encoder1:
-        text_encoder1 = accelerator.prepare(text_encoder1)
-        (text_encoder1,) = train_util.transform_models_if_DDP([text_encoder1])
-    if train_text_encoder2:
-        text_encoder2 = accelerator.prepare(text_encoder2)
-        (text_encoder2,) = train_util.transform_models_if_DDP([text_encoder2])
+        text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
+        text_encoder1.text_model.final_layer_norm.requires_grad_(False)

-    optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
+    if args.deepspeed:
+        ds_model = deepspeed_utils.prepare_deepspeed_model(
+            args,
+            unet=unet if train_unet else None,
+            text_encoder1=text_encoder1 if train_text_encoder1 else None,
+            text_encoder2=text_encoder2 if train_text_encoder2 else None,
+        )
+        # most ZeRO stages use optimizer partitioning, so the optimizer and ds_model must be prepared at the same time.  # pull/1139#issuecomment-1986790007
+        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            ds_model, optimizer, train_dataloader, lr_scheduler
+        )
+        training_models = [ds_model]
+
+    else:
+        # acceleratorがなんかよろしくやってくれるらしい
+        if train_unet:
+            unet = accelerator.prepare(unet)
+        if train_text_encoder1:
+            text_encoder1 = accelerator.prepare(text_encoder1)
+        if train_text_encoder2:
+            text_encoder2 = accelerator.prepare(text_encoder2)
+        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
     if args.cache_text_encoder_outputs:
         # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
         text_encoder1.to("cpu", dtype=torch.float32)
         text_encoder2.to("cpu", dtype=torch.float32)
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clean_memory_on_device(accelerator.device)
     else:
         # make sure Text Encoders are on GPU
         text_encoder1.to(accelerator.device)
@@ -421,6 +443,8 @@ def train(args):

     # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
     if args.full_fp16:
+        # During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine does.
+        # -> But we think it's ok to patch the accelerator even if DeepSpeed is enabled. 
train_util.patch_accelerator_for_fp16_training(accelerator) # resumeする @@ -438,7 +462,9 @@ def train(args): accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) # accelerator.print( # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" # ) @@ -457,10 +483,17 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + # For --sample_at_first + sdxl_train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet + ) + loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") @@ -482,7 +515,7 @@ def train(args): # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: @@ -503,6 +536,7 @@ def train(args): # else: input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) + # unwrap_model is fine for models not wrapped by accelerator encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( args.max_token_length, input_ids1, @@ -512,6 +546,7 @@ def train(args): text_encoder1, text_encoder2, None if not args.full_fp16 else weight_dtype, + accelerator=accelerator, ) else: encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) @@ -533,7 +568,7 @@ def train(args): # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # print("text encoder outputs verified") + # logger.info("text encoder outputs verified") # get size embeddings orig_size = batch["original_sizes_hw"] @@ -547,7 +582,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -562,9 +597,12 @@ def train(args): or 
args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss or args.debiased_estimation_loss + or args.masked_loss ): # do not mean over batch dimension for snr weight or scale v-pred loss - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + if args.masked_loss: + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: @@ -578,7 +616,7 @@ def train(args): loss = loss.mean() # mean over batch dimension else: - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") + loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: @@ -698,7 +736,7 @@ def train(args): accelerator.end_training() - if args.save_state: # and is_main_process: + if args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す @@ -720,15 +758,18 @@ def train(args): logit_scale, ckpt_info, ) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) @@ -748,7 +789,9 @@ def setup_parser() -> argparse.ArgumentParser: help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率", ) - parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") + parser.add_argument( + "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + ) parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") parser.add_argument( "--no_half_vae", @@ -762,7 +805,6 @@ def setup_parser() -> argparse.ArgumentParser: help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", ) - return parser @@ -770,6 +812,7 @@ if __name__ == "__main__": parser = setup_parser() args = parser.parse_args() + train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) train(args) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 54abf697..f89c3628 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -2,7 +2,6 @@ # training code for ControlNet-LLLite with passing cond_image to U-Net's forward import argparse -import gc import json import math import os @@ -13,20 +12,17 @@ from types import SimpleNamespace import toml from tqdm import tqdm + import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.device_utils import init_ipex, clean_memory_on_device +init_ipex() + from torch.nn.parallel import 
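# ---------------------------------------------------------------------------
# [hedged editor's sketch, not part of the patch] The hunks above route every
# former mse_loss call through train_util.conditional_loss(), and
# get_noise_noisy_latents_and_timesteps() now additionally returns a huber_c
# scale. A dispatcher consistent with these call sites could look like this;
# the pseudo-Huber branch is inferred from the loss_type/huber_c parameters,
# not verbatim library code.
import torch
import torch.nn.functional as F

def conditional_loss(model_pred, target, reduction="mean", loss_type="l2", huber_c=0.1):
    if loss_type == "l2":
        return F.mse_loss(model_pred, target, reduction=reduction)
    if loss_type == "huber":
        # pseudo-Huber: quadratic near zero, linear in the tails, scaled by huber_c
        loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
        return loss.mean() if reduction == "mean" else loss
    raise NotImplementedError(f"Unsupported loss type: {loss_type}")
# ---------------------------------------------------------------------------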
DistributedDataParallel as DDP from accelerate.utils import set_seed import accelerate from diffusers import DDPMScheduler, ControlNetModel from safetensors.torch import load_file -from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util +from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util import library.model_util as model_util import library.train_util as train_util @@ -47,6 +43,12 @@ from library.custom_train_functions import ( apply_debiased_estimation, ) import networks.control_net_lllite_for_train as control_net_lllite_for_train +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # TODO 他のスクリプトと共通化する @@ -67,6 +69,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) sdxl_train_util.verify_sdxl_training_args(args) + setup_logging(args, reset=True) cache_latents = args.cache_latents use_user_config = args.dataset_config is not None @@ -80,11 +83,11 @@ def train(args): # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) if use_user_config: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "conditioning_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -116,7 +119,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -126,7 +129,9 @@ def train(args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" else: - print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") + logger.warning( + "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" + ) if args.cache_text_encoder_outputs: assert ( @@ -134,7 +139,7 @@ def train(args): ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -166,9 +171,7 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -233,14 +236,14 @@ def train(args): accelerator.print("prepare optimizer, data loader etc.") trainable_params = list(unet.prepare_params()) - print(f"trainable params count: {len(trainable_params)}") - print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") + logger.info(f"trainable params count: {len(trainable_params)}") + logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") _, _, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, @@ -256,7 +259,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -283,9 +288,6 @@ def train(args): # acceleratorがなんかよろしくやってくれるらしい unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) - # transform DDP after prepare (train_network here only) - unet = train_util.transform_models_if_DDP([unet])[0] - if args.gradient_checkpointing: unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる else: @@ -296,8 +298,7 @@ def train(args): # move Text Encoders for sampling images. 
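# ---------------------------------------------------------------------------
# [hedged editor's sketch, not part of the patch] clean_memory_on_device()
# from library.device_utils replaces the scattered gc.collect() +
# torch.cuda.empty_cache() pairs throughout this patch, so one call covers
# CUDA, XPU and MPS alike. A plausible implementation, reconstructed from
# the call sites (assumed, not verbatim):
import gc
import torch

def clean_memory_on_device(device: torch.device) -> None:
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "xpu":
        torch.xpu.empty_cache()  # available once intel_extension_for_pytorch is initialized
    elif device.type == "mps":
        torch.mps.empty_cache()
# ---------------------------------------------------------------------------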
Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) @@ -328,8 +329,10 @@ def train(args): accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -345,6 +348,8 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( @@ -389,15 +394,15 @@ def train(args): with accelerator.accumulate(unet): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: @@ -434,7 +439,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -453,14 +458,14 @@ def train(args): else: target = noise - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, 
args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: @@ -544,22 +549,24 @@ def train(args): accelerator.end_training() - if is_main_process and args.save_state: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) if is_main_process: ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) @@ -572,8 +579,12 @@ def setup_parser() -> argparse.ArgumentParser: choices=[None, "ckpt", "pt", "safetensors"], help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", ) - parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数") - parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument( + "--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数" + ) + parser.add_argument( + "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" + ) parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数") parser.add_argument( "--network_dropout", @@ -601,6 +612,7 @@ if __name__ == "__main__": parser = setup_parser() args = parser.parse_args() + train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) train(args) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index f00f10ea..e85e978c 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -1,5 +1,4 @@ import argparse -import gc import json import math import os @@ -10,19 +9,16 @@ from types import SimpleNamespace import toml from tqdm import tqdm + import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.device_utils import init_ipex, clean_memory_on_device +init_ipex() + from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel from safetensors.torch import load_file -from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util +from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util import library.model_util as model_util import library.train_util as train_util @@ -43,6 +39,12 @@ from library.custom_train_functions import ( apply_debiased_estimation, ) import networks.control_net_lllite as control_net_lllite +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger 
= logging.getLogger(__name__) # TODO 他のスクリプトと共通化する @@ -63,6 +65,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) sdxl_train_util.verify_sdxl_training_args(args) + setup_logging(args, reset=True) cache_latents = args.cache_latents use_user_config = args.dataset_config is not None @@ -76,11 +79,11 @@ def train(args): # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) if use_user_config: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "conditioning_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -112,7 +115,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -122,7 +125,9 @@ def train(args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" else: - print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") + logger.warning( + "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" + ) if args.cache_text_encoder_outputs: assert ( @@ -130,7 +135,7 @@ def train(args): ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -165,9 +170,7 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -201,14 +204,14 @@ def train(args): accelerator.print("prepare optimizer, data loader etc.") trainable_params = list(network.prepare_optimizer_params()) - print(f"trainable params count: {len(trainable_params)}") - print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") + logger.info(f"trainable params count: {len(trainable_params)}") + logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") _, _, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, @@ -224,7 +227,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( 
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -254,9 +259,6 @@ def train(args): ) network: control_net_lllite.ControlNetLLLite - # transform DDP after prepare (train_network here only) - unet, network = train_util.transform_models_if_DDP([unet, network]) - if args.gradient_checkpointing: unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる else: @@ -269,8 +271,7 @@ def train(args): # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) @@ -301,8 +302,10 @@ def train(args): accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -358,15 +361,15 @@ def train(args): with accelerator.accumulate(network): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: @@ -403,7 +406,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not 
weight_dtype @@ -423,14 +426,14 @@ def train(args): else: target = noise - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: @@ -521,15 +524,17 @@ def train(args): ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) @@ -542,8 +547,12 @@ def setup_parser() -> argparse.ArgumentParser: choices=[None, "ckpt", "pt", "safetensors"], help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", ) - parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数") - parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument( + "--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数" + ) + parser.add_argument( + "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" + ) parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数") parser.add_argument( "--network_dropout", @@ -571,6 +580,7 @@ if __name__ == "__main__": parser = setup_parser() args = parser.parse_args() + train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) train(args) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 199c4e03..83969bb1 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,15 +1,15 @@ import argparse + import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.device_utils import init_ipex, clean_memory_on_device +init_ipex() + from library import sdxl_model_util, sdxl_train_util, train_util import train_network - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class SdxlNetworkTrainer(train_network.NetworkTrainer): def __init__(self): @@ -62,13 +62,12 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): if args.cache_text_encoder_outputs: if not args.lowram: # メモリ消費を減らす - print("move vae and unet to cpu to save memory") + logger.info("move vae and unet to cpu to save memory") org_vae_device = vae.device 
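# ---------------------------------------------------------------------------
# [hedged editor's sketch, not part of the patch] apply_snr_weight() now
# receives args.v_parameterization because the min-SNR-gamma weight differs
# for v-prediction, where the effective SNR of the target is shifted by one.
# A reconstruction consistent with the new signature (assumed, not verbatim):
import torch

def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_parameterization=False):
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(loss.device)[timesteps]
    snr = alphas_cumprod / (1.0 - alphas_cumprod)  # per-sample signal-to-noise ratio
    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
    weight = min_snr_gamma / (snr + 1) if v_parameterization else min_snr_gamma / snr
    return loss * weight  # loss is per-sample here (already averaged over C, H, W)
# ---------------------------------------------------------------------------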
org_unet_device = unet.device vae.to("cpu") unet.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) # When TE is not be trained, it will not be prepared so we need to use explicit autocast with accelerator.autocast(): @@ -83,17 +82,16 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU text_encoders[1].to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) if not args.lowram: - print("move vae and unet back to original device") + logger.info("move vae and unet back to original device") vae.to(org_vae_device) unet.to(org_unet_device) else: # Text Encoderから毎回出力を取得するので、GPUに乗せておく - text_encoders[0].to(accelerator.device) - text_encoders[1].to(accelerator.device) + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: @@ -123,6 +121,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): text_encoders[0], text_encoders[1], None if not args.full_fp16 else weight_dtype, + accelerator=accelerator, ) else: encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) @@ -144,7 +143,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # print("text encoder outputs verified") + # logger.info("text encoder outputs verified") return encoder_hidden_states1, encoder_hidden_states2, pool2 @@ -179,6 +178,7 @@ if __name__ == "__main__": parser = setup_parser() args = parser.parse_args() + train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) trainer = SdxlNetworkTrainer() diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index f5cca17b..5df739e2 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -2,15 +2,11 @@ import argparse import os import regex + import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass -import open_clip +from library.device_utils import init_ipex +init_ipex() + from library import sdxl_model_util, sdxl_train_util, train_util import train_textual_inversion @@ -64,6 +60,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine text_encoders[0], text_encoders[1], None if not args.full_fp16 else weight_dtype, + accelerator=accelerator, ) return encoder_hidden_states1, encoder_hidden_states2, pool2 @@ -134,6 +131,7 @@ if __name__ == "__main__": parser = setup_parser() args = parser.parse_args() + train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) trainer = SdxlTextualInversionTrainer() diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 17916ef7..32101de3 100644 --- a/tools/cache_latents.py +++ 
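# ---------------------------------------------------------------------------
# [editor's note] The same logging preamble recurs in every converted script;
# the pattern, exactly as this patch applies it, is:
from library.utils import setup_logging

setup_logging()  # install handlers/format once, before the first getLogger()
import logging

logger = logging.getLogger(__name__)
# and inside each entry point, logging is re-initialized from the CLI options
# added by add_logging_arguments(parser):
#     setup_logging(args, reset=True)
# ---------------------------------------------------------------------------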
b/tools/cache_latents.py @@ -16,9 +16,13 @@ from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) - +from library.utils import setup_logging, add_logging_arguments +setup_logging() +import logging +logger = logging.getLogger(__name__) def cache_to_disk(args: argparse.Namespace) -> None: + setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) # check cache latents arg @@ -41,18 +45,18 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") user_config = { "datasets": [ { @@ -63,7 +67,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -90,7 +94,8 @@ def cache_to_disk(args: argparse.Namespace) -> None: collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") + args.deepspeed = False accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -98,7 +103,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: vae_dtype = torch.float32 if args.no_half_vae else weight_dtype # モデルを読み込む - print("load model") + logger.info("load model") if args.sdxl: (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) else: @@ -113,8 +118,8 @@ def cache_to_disk(args: argparse.Namespace) -> None: # dataloaderを準備する train_dataset_group.set_caching_mode("latents") - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, @@ -152,7 +157,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.skip_existing: if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug): - print(f"Skipping {image_info.latents_npz} because it already exists.") + logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") continue image_infos.append(image_info) @@ -167,6 +172,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 7d9b13d6..a75d9da7 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -16,9 +16,13 @@ from 
library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) - +from library.utils import setup_logging, add_logging_arguments +setup_logging() +import logging +logger = logging.getLogger(__name__) def cache_to_disk(args: argparse.Namespace) -> None: + setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) # check cache arg @@ -48,18 +52,18 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") user_config = { "datasets": [ { @@ -70,7 +74,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -95,14 +99,15 @@ def cache_to_disk(args: argparse.Namespace) -> None: collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") + args.deepspeed = False accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, _ = train_util.prepare_dtype(args) # モデルを読み込む - print("load model") + logger.info("load model") if args.sdxl: (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) text_encoders = [text_encoder1, text_encoder2] @@ -118,8 +123,8 @@ def cache_to_disk(args: argparse.Namespace) -> None: # dataloaderを準備する train_dataset_group.set_caching_mode("text") - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, @@ -147,7 +152,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.skip_existing: if os.path.exists(image_info.text_encoder_outputs_npz): - print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") + logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") continue image_info.input_ids1 = input_ids1 @@ -168,6 +173,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) diff --git a/tools/canny.py b/tools/canny.py index 5e080689..f2190975 100644 --- a/tools/canny.py +++ b/tools/canny.py @@ -1,6 +1,10 @@ import argparse import cv2 +import logging +from library.utils import setup_logging +setup_logging() +logger = logging.getLogger(__name__) def canny(args): img = cv2.imread(args.input) @@ -10,7 +14,7 
@@ def canny(args): # canny_img = 255 - canny_img cv2.imwrite(args.output, canny_img) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/tools/convert_diffusers20_original_sd.py b/tools/convert_diffusers20_original_sd.py index b9365b51..572ee2f0 100644 --- a/tools/convert_diffusers20_original_sd.py +++ b/tools/convert_diffusers20_original_sd.py @@ -6,7 +6,10 @@ import torch from diffusers import StableDiffusionPipeline import library.model_util as model_util - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def convert(args): # 引数を確認する @@ -23,21 +26,23 @@ def convert(args): is_load_ckpt = os.path.isfile(args.model_to_load) is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 - assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です" + assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です" # assert ( # is_save_ckpt or args.reference_model is not None # ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です" # モデルを読み込む msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) - print(f"loading {msg}: {args.model_to_load}") + logger.info(f"loading {msg}: {args.model_to_load}") if is_load_ckpt: v2_model = args.v2 - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection) + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint( + v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection + ) else: pipe = StableDiffusionPipeline.from_pretrained( - args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None + args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant ) text_encoder = pipe.text_encoder vae = pipe.vae @@ -46,26 +51,37 @@ def convert(args): if args.v1 == args.v2: # 自動判定する v2_model = unet.config.cross_attention_dim == 1024 - print("checking model version: model is " + ("v2" if v2_model else "v1")) + logger.info("checking model version: model is " + ("v2" if v2_model else "v1")) else: v2_model = not args.v1 # 変換して保存する msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" - print(f"converting and saving as {msg}: {args.model_to_save}") + logger.info(f"converting and saving as {msg}: {args.model_to_save}") if is_save_ckpt: original_model = args.model_to_load if is_load_ckpt else None key_count = model_util.save_stable_diffusion_checkpoint( - v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, save_dtype, vae + v2_model, + args.model_to_save, + text_encoder, + unet, + original_model, + args.epoch, + args.global_step, + None if args.metadata is None else eval(args.metadata), + save_dtype=save_dtype, + vae=vae, ) - print(f"model saved. total converted state_dict keys: {key_count}") + logger.info(f"model saved. 
total converted state_dict keys: {key_count}") else: - print(f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}") + logger.info( + f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}" + ) model_util.save_diffusers_checkpoint( v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors ) - print(f"model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: @@ -77,7 +93,9 @@ def setup_parser() -> argparse.ArgumentParser: "--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む" ) parser.add_argument( - "--unet_use_linear_projection", action="store_true", help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)" + "--unet_use_linear_projection", + action="store_true", + help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)", ) parser.add_argument( "--fp16", @@ -99,6 +117,18 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値" ) + parser.add_argument( + "--metadata", + type=str, + default=None, + help='モデルに保存されるメタデータ、Pythonの辞書形式で指定 / metadata: metadata written in to the model in Python Dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'', + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="読む込むDiffusersのvariantを指定する、例: fp16 / variant: Diffusers variant to load. Example: fp16", + ) parser.add_argument( "--reference_model", type=str, diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py index 68dec6ca..bbc643ed 100644 --- a/tools/detect_face_rotate.py +++ b/tools/detect_face_rotate.py @@ -15,6 +15,10 @@ import os from anime_face_detector import create_detector from tqdm import tqdm import numpy as np +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) KP_REYE = 11 KP_LEYE = 19 @@ -24,7 +28,7 @@ SCORE_THRES = 0.90 def detect_faces(detector, image, min_size): preds = detector(image) # bgr - # print(len(preds)) + # logger.info(len(preds)) faces = [] for pred in preds: @@ -78,7 +82,7 @@ def process(args): assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません" # アニメ顔検出モデルを読み込む - print("loading face detector.") + logger.info("loading face detector.") detector = create_detector('yolov3') # cropの引数を解析する @@ -97,7 +101,7 @@ def process(args): crop_h_ratio, crop_v_ratio = [float(t) for t in tokens] # 画像を処理する - print("processing.") + logger.info("processing.") output_extension = ".png" os.makedirs(args.dst_dir, exist_ok=True) @@ -111,7 +115,7 @@ def process(args): if len(image.shape) == 2: image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) if image.shape[2] == 4: - print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}") + logger.warning(f"image has alpha. 
ignore / 画像の透明度が設定されているため無視します: {path}") image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい h, w = image.shape[:2] @@ -144,11 +148,11 @@ def process(args): # 顔サイズを基準にリサイズする scale = args.resize_face_size / face_size if scale < cur_crop_width / w: - print( + logger.warning( f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") scale = cur_crop_width / w if scale < cur_crop_height / h: - print( + logger.warning( f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") scale = cur_crop_height / h elif crop_h_ratio is not None: @@ -157,10 +161,10 @@ def process(args): else: # 切り出しサイズ指定あり if w < cur_crop_width: - print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}") + logger.warning(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}") scale = cur_crop_width / w if h < cur_crop_height: - print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}") + logger.warning(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}") scale = cur_crop_height / h if args.resize_fit: scale = max(cur_crop_width / w, cur_crop_height / h) @@ -198,7 +202,7 @@ def process(args): face_img = face_img[y:y + cur_crop_height] # # debug - # print(path, cx, cy, angle) + # logger.info(path, cx, cy, angle) # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8)) # cv2.imshow("image", crp) # if cv2.waitKey() == 27: diff --git a/tools/latent_upscaler.py b/tools/latent_upscaler.py index ab1fa339..f05cf719 100644 --- a/tools/latent_upscaler.py +++ b/tools/latent_upscaler.py @@ -11,10 +11,16 @@ from typing import Dict, List import numpy as np import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from torch import nn from tqdm import tqdm from PIL import Image - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1): @@ -216,7 +222,7 @@ class Upscaler(nn.Module): upsampled_images = upsampled_images / 127.5 - 1.0 # convert upsample images to latents with batch size - # print("Encoding upsampled (LANCZOS4) images...") + # logger.info("Encoding upsampled (LANCZOS4) images...") upsampled_latents = [] for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)): batch = upsampled_images[i : i + vae_batch_size].to(vae.device) @@ -227,7 +233,7 @@ class Upscaler(nn.Module): upsampled_latents = torch.cat(upsampled_latents, dim=0) # upscale (refine) latents with this model with batch size - print("Upscaling latents...") + logger.info("Upscaling latents...") upscaled_latents = [] for i in range(0, upsampled_latents.shape[0], batch_size): with torch.no_grad(): @@ -242,7 +248,7 @@ def create_upscaler(**kwargs): weights = kwargs["weights"] model = Upscaler() - print(f"Loading weights from {weights}...") + logger.info(f"Loading weights from {weights}...") if os.path.splitext(weights)[1] == ".safetensors": from safetensors.torch import load_file @@ -255,20 +261,20 @@ def create_upscaler(**kwargs): # another interface: upscale images with a model for given images from command line def upscale_images(args: argparse.Namespace): - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + DEVICE = get_preferred_device() us_dtype = torch.float16 # TODO: support fp32/bf16 os.makedirs(args.output_dir, exist_ok=True) # load 
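# ---------------------------------------------------------------------------
# [hedged editor's sketch, not part of the patch] get_preferred_device()
# generalizes the old `cuda if available else cpu` choice in
# latent_upscaler.py. A reasonable reconstruction from its use here
# (assumed, not verbatim):
import torch

def get_preferred_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # Intel GPUs via IPEX
        return torch.device("xpu")
    if torch.backends.mps.is_available():  # Apple Silicon
        return torch.device("mps")
    return torch.device("cpu")
# ---------------------------------------------------------------------------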
VAE with Diffusers assert args.vae_path is not None, "VAE path is required" - print(f"Loading VAE from {args.vae_path}...") + logger.info(f"Loading VAE from {args.vae_path}...") vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae") vae.to(DEVICE, dtype=us_dtype) # prepare model - print("Preparing model...") + logger.info("Preparing model...") upscaler: Upscaler = create_upscaler(weights=args.weights) - # print("Loading weights from", args.weights) + # logger.info("Loading weights from", args.weights) # upscaler.load_state_dict(torch.load(args.weights)) upscaler.eval() upscaler.to(DEVICE, dtype=us_dtype) @@ -303,14 +309,14 @@ def upscale_images(args: argparse.Namespace): image_debug.save(dest_file_name) # upscale - print("Upscaling...") + logger.info("Upscaling...") upscaled_latents = upscaler.upscale( vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size ) upscaled_latents /= 0.18215 # decode with batch - print("Decoding...") + logger.info("Decoding...") upscaled_images = [] for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)): with torch.no_grad(): diff --git a/tools/merge_models.py b/tools/merge_models.py index 391bfe67..8f1fbf2f 100644 --- a/tools/merge_models.py +++ b/tools/merge_models.py @@ -5,7 +5,10 @@ import torch from safetensors import safe_open from safetensors.torch import load_file, save_file from tqdm import tqdm - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def is_unet_key(key): # VAE or TextEncoder, the last one is for SDXL @@ -45,10 +48,10 @@ def merge(args): # check if all models are safetensors for model in args.models: if not model.endswith("safetensors"): - print(f"Model {model} is not a safetensors model") + logger.info(f"Model {model} is not a safetensors model") exit() if not os.path.isfile(model): - print(f"Model {model} does not exist") + logger.info(f"Model {model} does not exist") exit() assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models" @@ -65,7 +68,7 @@ def merge(args): if merged_sd is None: # load first model - print(f"Loading model {model}, ratio = {ratio}...") + logger.info(f"Loading model {model}, ratio = {ratio}...") merged_sd = {} with safe_open(model, framework="pt", device=args.device) as f: for key in tqdm(f.keys()): @@ -81,11 +84,11 @@ def merge(args): value = ratio * value.to(dtype) # first model's value * ratio merged_sd[key] = value - print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else "")) + logger.info(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else "")) continue # load other models - print(f"Loading model {model}, ratio = {ratio}...") + logger.info(f"Loading model {model}, ratio = {ratio}...") with safe_open(model, framework="pt", device=args.device) as f: model_keys = f.keys() @@ -93,7 +96,7 @@ def merge(args): _, new_key = replace_text_encoder_key(key) if new_key not in merged_sd: if args.show_skipped and new_key not in first_model_keys: - print(f"Skip: {new_key}") + logger.info(f"Skip: {new_key}") continue value = f.get_tensor(key) @@ -104,7 +107,7 @@ def merge(args): for key in merged_sd.keys(): if key in model_keys: continue - print(f"Key {key} not in model {model}, use first model's value") + logger.warning(f"Key {key} not in model {model}, use first model's value") if key in supplementary_key_ratios: supplementary_key_ratios[key] += ratio else: @@ -112,7 +115,7 
@@ def merge(args):
     # add supplementary keys' value (including VAE and TextEncoder)
     if len(supplementary_key_ratios) > 0:
-        print("add first model's value")
+        logger.info("add first model's value")
         with safe_open(args.models[0], framework="pt", device=args.device) as f:
             for key in tqdm(f.keys()):
                 _, new_key = replace_text_encoder_key(key)
                 if new_key not in supplementary_key_ratios:
                     continue
                 if is_unet_key(new_key):  # not VAE or TextEncoder
-                    print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
+                    logger.warning(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
                 value = f.get_tensor(key)  # original key
@@ -134,7 +137,7 @@ def merge(args):
     if not output_file.endswith(".safetensors"):
         output_file = output_file + ".safetensors"
-    print(f"Saving to {output_file}...")
+    logger.info(f"Saving to {output_file}...")
     # convert to save_dtype
     for k in merged_sd.keys():
@@ -142,7 +145,7 @@ def merge(args):
     save_file(merged_sd, output_file)
-    print("Done!")
+    logger.info("Done!")
 if __name__ == "__main__":
diff --git a/tools/original_control_net.py b/tools/original_control_net.py
index cd47bd76..5640d542 100644
--- a/tools/original_control_net.py
+++ b/tools/original_control_net.py
@@ -7,7 +7,10 @@ from safetensors.torch import load_file
 from library.original_unet import UNet2DConditionModel, SampleOutput
 import library.model_util as model_util
-
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 class ControlNetInfo(NamedTuple):
     unet: Any
@@ -51,7 +54,7 @@ def load_control_net(v2, unet, model):
     # control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
     # state dictを読み込む
-    print(f"ControlNet: loading control SD model : {model}")
+    logger.info(f"ControlNet: loading control SD model : {model}")
     if model_util.is_safetensors(model):
         ctrl_sd_sd = load_file(model)
@@ -61,7 +64,7 @@
     # 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
     is_difference = "difference" in ctrl_sd_sd
-    print("ControlNet: loading difference:", is_difference)
+    logger.info(f"ControlNet: loading difference: {is_difference}")
     # ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
     # またTransfer Controlの元weightとなる
@@ -89,13 +92,13 @@
     # ControlNetのU-Netを作成する
     ctrl_unet = UNet2DConditionModel(**unet_config)
     info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
-    print("ControlNet: loading Control U-Net:", info)
+    logger.info(f"ControlNet: loading Control U-Net: {info}")
     # U-Net以外のControlNetを作成する
     # TODO support middle only
     ctrl_net = ControlNet()
     info = ctrl_net.load_state_dict(zero_conv_sd)
-    print("ControlNet: loading ControlNet:", info)
+    logger.info(f"ControlNet: loading ControlNet: {info}")
     ctrl_unet.to(unet.device, dtype=unet.dtype)
     ctrl_net.to(unet.device, dtype=unet.dtype)
@@ -117,7 +120,7 @@ def load_preprocess(prep_type: str):
         return canny
-    print("Unsupported prep type:", prep_type)
+    logger.info(f"Unsupported prep type: {prep_type}")
     return None
@@ -174,13 +177,26 @@ def call_unet_and_control_net(
         cnet_idx = step % cnet_cnt
         cnet_info = control_nets[cnet_idx]
-        # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
+        # logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
         if cnet_info.ratio < current_ratio:
             return original_unet(sample, timestep, encoder_hidden_states)
         guided_hint = guided_hints[cnet_idx]
+
+        # gradual latent support: match the size of guided_hint to the size of sample
+        if
guided_hint.shape[-2:] != sample.shape[-2:]: + # print(f"guided_hint.shape={guided_hint.shape}, sample.shape={sample.shape}") + org_dtype = guided_hint.dtype + if org_dtype == torch.bfloat16: + guided_hint = guided_hint.to(torch.float32) + guided_hint = torch.nn.functional.interpolate(guided_hint, size=sample.shape[-2:], mode="bicubic") + if org_dtype == torch.bfloat16: + guided_hint = guided_hint.to(org_dtype) + guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1)) - outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net) + outs = unet_forward( + True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net + ) outs = [o * cnet_info.weight for o in outs] # U-Net @@ -192,7 +208,7 @@ def call_unet_and_control_net( # ControlNet cnet_outs_list = [] for i, cnet_info in enumerate(control_nets): - # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) + # logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) if cnet_info.ratio < current_ratio: continue guided_hint = guided_hints[i] @@ -232,7 +248,7 @@ def unet_forward( upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - print("Forward upsample size to force interpolation output size.") + logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # 1. time diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py index 2d3224c4..b8069fc1 100644 --- a/tools/resize_images_to_resolution.py +++ b/tools/resize_images_to_resolution.py @@ -6,7 +6,10 @@ import shutil import math from PIL import Image import numpy as np - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False): # Split the max_resolution string by "," and strip any whitespaces @@ -83,7 +86,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi image.save(os.path.join(dst_img_folder, new_filename), quality=100) proc = "Resized" if current_pixels > max_pixels else "Saved" - print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") + logger.info(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") # If other files with same basename, copy them with resolution suffix if copy_associated_files: @@ -94,7 +97,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi continue for max_resolution in max_resolutions: new_asoc_file = base + '+' + max_resolution + ext - print(f"Copy {asoc_file} as {new_asoc_file}") + logger.info(f"Copy {asoc_file} as {new_asoc_file}") shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file)) diff --git a/tools/show_metadata.py b/tools/show_metadata.py index 92ca7b1c..05bfbe0a 100644 --- a/tools/show_metadata.py +++ b/tools/show_metadata.py @@ -1,6 +1,10 @@ import json import argparse from safetensors import safe_open +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True) @@ -10,10 +14,10 @@ with safe_open(args.model, framework="pt") as f: metadata = 
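# ---------------------------------------------------------------------------
# [editor's note, hedged] The float32 round-trip in the gradual-latent hunk
# above exists because bicubic interpolation lacks a bfloat16 kernel in the
# PyTorch versions this repository targets. The pattern in isolation:
import torch
import torch.nn.functional as F

def interpolate_bf16_safe(x: torch.Tensor, size) -> torch.Tensor:
    org_dtype = x.dtype
    if org_dtype == torch.bfloat16:
        x = x.to(torch.float32)  # bicubic has no bf16 implementation here
    x = F.interpolate(x, size=size, mode="bicubic")
    return x.to(org_dtype)
# ---------------------------------------------------------------------------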
diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py
index 2d3224c4..b8069fc1 100644
--- a/tools/resize_images_to_resolution.py
+++ b/tools/resize_images_to_resolution.py
@@ -6,7 +6,10 @@ import shutil
 import math
 from PIL import Image
 import numpy as np
-
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 
 def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
     # Split the max_resolution string by "," and strip any whitespaces
@@ -83,7 +86,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
         image.save(os.path.join(dst_img_folder, new_filename), quality=100)
 
         proc = "Resized" if current_pixels > max_pixels else "Saved"
-        print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
+        logger.info(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
 
     # If other files with same basename, copy them with resolution suffix
     if copy_associated_files:
@@ -94,7 +97,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
                 continue
             for max_resolution in max_resolutions:
                 new_asoc_file = base + '+' + max_resolution + ext
-                print(f"Copy {asoc_file} as {new_asoc_file}")
+                logger.info(f"Copy {asoc_file} as {new_asoc_file}")
                 shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
 
 
diff --git a/tools/show_metadata.py b/tools/show_metadata.py
index 92ca7b1c..05bfbe0a 100644
--- a/tools/show_metadata.py
+++ b/tools/show_metadata.py
@@ -1,6 +1,10 @@
 import json
 import argparse
 from safetensors import safe_open
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--model", type=str, required=True)
@@ -10,10 +14,10 @@ with safe_open(args.model, framework="pt") as f:
     metadata = f.metadata()
 
 if metadata is None:
-    print("No metadata found")
+    logger.error("No metadata found")
 else:
     # metadata is json dict, but not pretty printed
     # sort by key and pretty print
     print(json.dumps(metadata, indent=4, sort_keys=True))
-    
\ No newline at end of file
+    
diff --git a/train_controlnet.py b/train_controlnet.py
index bbd915cb..f4c94e8d 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -1,5 +1,4 @@
 import argparse
-import gc
 import json
 import math
 import os
@@ -10,14 +9,12 @@ from types import SimpleNamespace
 import toml
 from tqdm import tqdm
+
 import torch
-try:
-    import intel_extension_for_pytorch as ipex
-    if torch.xpu.is_available():
-        from library.ipex import ipex_init
-        ipex_init()
-except Exception:
-    pass
+from library import deepspeed_utils
+from library.device_utils import init_ipex, clean_memory_on_device
+init_ipex()
+
 from torch.nn.parallel import DistributedDataParallel as DDP
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler, ControlNetModel
@@ -37,6 +34,12 @@ from library.custom_train_functions import (
     pyramid_noise_like,
     apply_noise_offset,
 )
+from library.utils import setup_logging, add_logging_arguments
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
 
 # TODO 他のスクリプトと共通化する
@@ -58,6 +61,7 @@ def train(args):
     # training_started_at = time.time()
     train_util.verify_training_args(args)
     train_util.prepare_dataset_args(args, True)
+    setup_logging(args, reset=True)
 
     cache_latents = args.cache_latents
     use_user_config = args.dataset_config is not None
@@ -71,11 +75,11 @@ def train(args):
     # データセットを準備する
     blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
     if use_user_config:
-        print(f"Load dataset config from {args.dataset_config}")
+        logger.info(f"Load dataset config from {args.dataset_config}")
         user_config = config_util.load_user_config(args.dataset_config)
         ignored = ["train_data_dir", "conditioning_data_dir"]
         if any(getattr(args, attr) is not None for attr in ignored):
-            print(
+            logger.warning(
                 "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                     ", ".join(ignored)
                 )
             )
@@ -105,7 +109,7 @@ def train(args):
         train_util.debug_dataset(train_dataset_group)
         return
     if len(train_dataset_group) == 0:
-        print(
+        logger.error(
            "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -116,7 +120,7 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -221,10 +225,8 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - + clean_memory_on_device(accelerator.device) + accelerator.wait_for_everyone() if args.gradient_checkpointing: @@ -238,8 +240,8 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, @@ -255,7 +257,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -311,8 +315,10 @@ def train(args): accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -333,9 +339,13 @@ def train(args): ) if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + accelerator.init_trackers( + "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + ) loss_recorder = train_util.LossRecorder() del train_dataset_group @@ -371,6 +381,11 @@ def train(args): accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) + # For --sample_at_first + train_util.sample_images( + 
accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet + ) + # training loop for epoch in range(num_train_epochs): if is_main_process: @@ -382,7 +397,7 @@ def train(args): with accelerator.accumulate(controlnet): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: # latentに変換 latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() @@ -405,13 +420,8 @@ def train(args): ) # Sample a random timestep for each image - timesteps = torch.randint( - 0, - noise_scheduler.config.num_train_timesteps, - (b_size,), - device=latents.device, - ) - timesteps = timesteps.long() + timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device) + # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) @@ -442,14 +452,14 @@ def train(args): else: target = noise - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし @@ -551,7 +561,7 @@ def train(args): accelerator.end_training() - if is_main_process and args.save_state: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) # del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく @@ -560,15 +570,17 @@ def train(args): ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as)
         save_model(ckpt_name, controlnet, force_sync_upload=True)
 
-        print("model saved.")
+        logger.info("model saved.")
 
 
 def setup_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
 
+    add_logging_arguments(parser)
     train_util.add_sd_models_arguments(parser)
     train_util.add_dataset_arguments(parser, False, True, True)
     train_util.add_training_arguments(parser, False)
+    deepspeed_utils.add_deepspeed_arguments(parser)
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)
@@ -600,6 +612,7 @@ if __name__ == "__main__":
     parser = setup_parser()
 
     args = parser.parse_args()
+    train_util.verify_command_line_training_args(args)
     args = train_util.read_config_from_file(args, parser)
 
     train(args)
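From here on, every trainer swaps the hard-coded MSE for `train_util.conditional_loss`, fed by the per-batch `huber_c` that `get_timesteps_and_huber_c` / `get_noise_noisy_latents_and_timesteps` now return alongside the sampled timesteps. A rough sketch of the dispatch this implies (an illustration under assumptions, not the actual `library/train_util.py` code):

```python
import torch
import torch.nn.functional as F


def conditional_loss(pred, target, reduction="none", loss_type="l2", huber_c=0.1):
    # Illustrative sketch: "l2" keeps the previous MSE behaviour; "huber" uses
    # the smooth pseudo-Huber form 2c * (sqrt(d^2 + c^2) - c), which behaves
    # like d^2 near zero and like 2c|d| for large residuals, so outliers
    # contribute less than under pure MSE.
    if loss_type == "l2":
        return F.mse_loss(pred, target, reduction=reduction)
    if loss_type == "huber":
        loss = 2 * huber_c * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c)
        if reduction == "mean":
            return loss.mean()
        if reduction == "sum":
            return loss.sum()
        return loss  # "none": keep per-element values for later weighting
    raise NotImplementedError(f"unknown loss_type: {loss_type}")
```

Keeping `reduction="none"` matters here: the callers retain the per-pixel loss so that masked loss, per-sample loss weights, and min-SNR weighting can still be applied before the final mean.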
diff --git a/train_db.py b/train_db.py
index 4ff19645..e9f8f9f8 100644
--- a/train_db.py
+++ b/train_db.py
@@ -1,7 +1,6 @@
 # DreamBooth training
 # XXX dropped option: fine_tune
 
-import gc
 import argparse
 import itertools
 import math
@@ -10,17 +9,14 @@ from multiprocessing import Value
 
 import toml
 from tqdm import tqdm
+
 import torch
+from library import deepspeed_utils
+from library.device_utils import init_ipex, clean_memory_on_device
 
-try:
-    import intel_extension_for_pytorch as ipex
-    if torch.xpu.is_available():
-        from library.ipex import ipex_init
+init_ipex()
 
-        ipex_init()
-except Exception:
-    pass
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler
 
@@ -39,7 +35,14 @@ from library.custom_train_functions import (
     apply_noise_offset,
     scale_v_prediction_loss_like_noise_prediction,
     apply_debiased_estimation,
+    apply_masked_loss,
 )
+from library.utils import setup_logging, add_logging_arguments
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
 
 # perlin_noise,
 
@@ -47,6 +50,8 @@ from library.custom_train_functions import (
 def train(args):
     train_util.verify_training_args(args)
     train_util.prepare_dataset_args(args, False)
+    deepspeed_utils.prepare_deepspeed_args(args)
+    setup_logging(args, reset=True)
 
     cache_latents = args.cache_latents
 
@@ -57,13 +62,13 @@ def train(args):
 
     # データセットを準備する
     if args.dataset_class is None:
-        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, args.masked_loss, True))
         if args.dataset_config is not None:
-            print(f"Load dataset config from {args.dataset_config}")
+            logger.info(f"Load dataset config from {args.dataset_config}")
             user_config = config_util.load_user_config(args.dataset_config)
             ignored = ["train_data_dir", "reg_data_dir"]
             if any(getattr(args, attr) is not None for attr in ignored):
-                print(
+                logger.warning(
                     "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                         ", ".join(ignored)
                     )
                 )
@@ -98,13 +103,13 @@ def train(args):
     ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
 
     # acceleratorを準備する
-    print("prepare accelerator")
+    logger.info("prepare accelerator")
 
     if args.gradient_accumulation_steps > 1:
-        print(
+        logger.warning(
             f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
         )
-        print(
+        logger.warning(
             f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
         )
 
@@ -112,6 +117,7 @@ def train(args):
 
     # mixed precisionに対応した型を用意しておき適宜castする
     weight_dtype, save_dtype = train_util.prepare_dtype(args)
+    vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
 
     # モデルを読み込む
     text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -136,15 +142,13 @@ def train(args):
 
     # 学習を準備する
     if cache_latents:
-        vae.to(accelerator.device, dtype=weight_dtype)
+        vae.to(accelerator.device, dtype=vae_dtype)
         vae.requires_grad_(False)
         vae.eval()
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory_on_device(accelerator.device)
 
         accelerator.wait_for_everyone()
 
@@ -181,8 +185,8 @@ def train(args):
     _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
     # dataloaderを準備する
-    # DataLoaderのプロセス数:0はメインプロセスになる
-    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
+    # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
+    n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset_group,
         batch_size=1,
@@ -197,7 +201,9 @@ def train(args):
         args.max_train_steps = args.max_train_epochs * math.ceil(
             len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
         )
-        accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+        accelerator.print(
+            f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
+        )
 
     # データセット側にも学習ステップを送信
     train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -218,15 +224,25 @@ def train(args):
         text_encoder.to(weight_dtype)
 
     # acceleratorがなんかよろしくやってくれるらしい
-    if train_text_encoder:
-        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+    if args.deepspeed:
+        if args.train_text_encoder:
+            ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
+        else:
+            ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
+        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            ds_model, optimizer, train_dataloader, lr_scheduler
         )
-    else:
-        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+        training_models = [ds_model]
 
-    # transform DDP after prepare
-    text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
+    else:
+        if train_text_encoder:
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+            )
+            training_models = [unet, text_encoder]
+        else:
+            unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+            training_models = [unet]
 
     if not train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)  # to avoid 'cpu' vs 'cuda' error
@@ -270,10 +286,15 @@ def train(args):
 
     if accelerator.is_main_process:
         init_kwargs = {}
+        if args.wandb_run_name:
+            init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
+    # For --sample_at_first
+    train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
     loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -293,12 +314,14 @@ def train(args):
                     if not args.gradient_checkpointing:
                         text_encoder.train(False)
                     text_encoder.requires_grad_(False)
+                    if len(training_models) == 2:
+                        training_models = [training_models[0]]  # remove text_encoder from training_models
 
-            with accelerator.accumulate(unet):
+            with accelerator.accumulate(*training_models):
                 with torch.no_grad():
                     # latentに変換
                     if cache_latents:
-                        latents = batch["latents"].to(accelerator.device)
+                        latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                     else:
                         latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
                         latents = latents * 0.18215
                 b_size = latents.shape[0]
@@ -323,7 +346,7 @@ def train(args):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
 
                 # Predict the noise residual
                 with accelerator.autocast():
@@ -335,14 +358,16 @@ def train(args):
                 else:
                     target = noise
 
-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                if args.masked_loss:
+                    loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])
 
                 loss_weights = batch["loss_weights"]  # 各sampleごとのweight
                 loss = loss * loss_weights
 
                 if args.min_snr_gamma:
-                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
                 if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
@@ -441,7 +466,7 @@ def train(args):
 
     accelerator.end_training()
 
-    if args.save_state and is_main_process:
+    if is_main_process and (args.save_state or args.save_state_on_train_end):
         train_util.save_state_on_train_end(args, accelerator)
 
     del accelerator  # この後メモリを使うのでこれは消す
@@ -451,15 +476,18 @@ def train(args):
         train_util.save_sd_model_on_train_end(
             args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
         )
-        print("model saved.")
+        logger.info("model saved.")
 
 
 def setup_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
 
+    add_logging_arguments(parser)
     train_util.add_sd_models_arguments(parser)
     train_util.add_dataset_arguments(parser, True, False, True)
     train_util.add_training_arguments(parser, True)
+    train_util.add_masked_loss_arguments(parser)
+    deepspeed_utils.add_deepspeed_arguments(parser)
     train_util.add_sd_saving_arguments(parser)
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
@@ -482,6 +510,11 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
     )
+    parser.add_argument(
+        "--no_half_vae",
+        action="store_true",
+        help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
+    )
 
     return parser
 
@@ -490,6 +523,7 @@ if __name__ == "__main__":
     parser = setup_parser()
 
     args = parser.parse_args()
+    train_util.verify_command_line_training_args(args)
     args = train_util.read_config_from_file(args, parser)
 
     train(args)
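`train_db.py` above (and the remaining trainers) add `--masked_loss`, where `apply_masked_loss` scales the per-pixel loss with a mask before the spatial mean is taken. A hedged sketch of the idea, assuming the mask arrives in the batch as a `[-1, 1]`-ranged conditioning image whose first channel is the mask (the exact batch layout lives in the dataset code):

```python
import torch
import torch.nn.functional as F


def apply_masked_loss(loss: torch.Tensor, batch: dict) -> torch.Tensor:
    # loss: (B, C, H, W) per-pixel loss in latent space.
    # Assumption: batch["conditioning_images"] holds images in [-1, 1];
    # channel 0 serves as the mask and is resized to the latent resolution.
    mask = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0:1]
    mask = F.interpolate(mask, size=loss.shape[2:], mode="area")
    mask = mask / 2 + 0.5  # map [-1, 1] to [0, 1]
    return loss * mask
```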
diff --git a/train_network.py b/train_network.py
index d50916b7..c99d3724 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1,6 +1,5 @@
 import importlib
 import argparse
-import gc
 import math
 import os
 import sys
@@ -11,25 +10,18 @@ from multiprocessing import Value
 
 import toml
 from tqdm import tqdm
+
 import torch
+from library.device_utils import init_ipex, clean_memory_on_device
 
-try:
-    import intel_extension_for_pytorch as ipex
+init_ipex()
 
-    if torch.xpu.is_available():
-        from library.ipex import ipex_init
-
-        ipex_init()
-except Exception:
-    pass
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler
-from library import model_util
+from library import deepspeed_utils, model_util
 
 import library.train_util as train_util
-from library.train_util import (
-    DreamBoothDataset,
-)
+from library.train_util import DreamBoothDataset
 import library.config_util as config_util
 from library.config_util import (
     ConfigSanitizer,
@@ -44,7 +36,14 @@ from library.custom_train_functions import (
     scale_v_prediction_loss_like_noise_prediction,
     add_v_prediction_like_loss,
     apply_debiased_estimation,
+    apply_masked_loss,
 )
+from library.utils import setup_logging, add_logging_arguments
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 class NetworkTrainer:
@@ -116,7 +115,7 @@ class NetworkTrainer:
         self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
     ):
         for t_enc in text_encoders:
-            t_enc.to(accelerator.device)
+            t_enc.to(accelerator.device, dtype=weight_dtype)
 
     def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
         input_ids = batch["input_ids"].to(accelerator.device)
@@ -127,6 +126,11 @@ class NetworkTrainer:
         noise_pred = unet(noisy_latents, timesteps, text_conds).sample
         return noise_pred
 
+    def all_reduce_network(self, accelerator, network):
+        for param in network.parameters():
+            if param.grad is not None:
+                param.grad = accelerator.reduce(param.grad, reduction="mean")
+
     def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
         train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
 
@@ -135,6 +139,8 @@ class NetworkTrainer:
         training_started_at = time.time()
         train_util.verify_training_args(args)
         train_util.prepare_dataset_args(args, True)
+        deepspeed_utils.prepare_deepspeed_args(args)
+        setup_logging(args, reset=True)
 
         cache_latents = args.cache_latents
         use_dreambooth_method = args.in_json is None
@@ -150,20 +156,20 @@ class NetworkTrainer:
 
         # データセットを準備する
         if args.dataset_class is None:
-            blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
+            blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
             if use_user_config:
-                print(f"Loading dataset config from {args.dataset_config}")
+                logger.info(f"Loading dataset config from {args.dataset_config}")
                 user_config = config_util.load_user_config(args.dataset_config)
                 ignored = ["train_data_dir", "reg_data_dir", "in_json"]
                 if any(getattr(args, attr) is not None for attr in ignored):
-                    print(
+                    logger.warning(
                         "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                             ", ".join(ignored)
                         )
                     )
             else:
                 if use_dreambooth_method:
-                    print("Using DreamBooth method.")
+                    logger.info("Using DreamBooth method.")
                     user_config = {
                         "datasets": [
                             {
@@ -174,7 +180,7 @@ class NetworkTrainer:
                         ]
                     }
                 else:
-                    print("Training with captions.")
+                    logger.info("Training with captions.")
                    user_config = {
                         "datasets": [
                             {
@@ -203,7 +209,7 @@ class NetworkTrainer:
             train_util.debug_dataset(train_dataset_group)
             return
         if len(train_dataset_group) == 0:
-            print(
+            logger.error(
                "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -216,7 +222,7 @@ class NetworkTrainer: self.assert_extra_args(args, train_dataset_group) # acceleratorを準備する - print("preparing accelerator") + logger.info("preparing accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -265,13 +271,12 @@ class NetworkTrainer: with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される + # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu self.cache_text_encoder_outputs_if_needed( args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype ) @@ -303,11 +308,12 @@ class NetworkTrainer: ) if network is None: return + network_has_multiplier = hasattr(network, "set_multiplier") if hasattr(network, "prepare_network"): network.prepare_network(args) if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"): - print( + logger.warning( "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません" ) args.scale_weight_norms = False @@ -342,8 +348,8 @@ class NetworkTrainer: optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, @@ -383,53 +389,59 @@ class NetworkTrainer: accelerator.print("enable full bf16 training.") network.to(weight_dtype) + unet_weight_dtype = te_weight_dtype = weight_dtype + # Experimental Feature: Put base model into fp8 to save vram + if args.fp8_base: + assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" + assert ( + args.mixed_precision != "no" + ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" + accelerator.print("enable fp8 training.") + unet_weight_dtype = torch.float8_e4m3fn + te_weight_dtype = torch.float8_e4m3fn + unet.requires_grad_(False) - unet.to(dtype=weight_dtype) + unet.to(dtype=unet_weight_dtype) for t_enc in text_encoders: t_enc.requires_grad_(False) - # acceleratorがなんかよろしくやってくれるらしい - # TODO めちゃくちゃ冗長なのでコードを整理する - if train_unet and train_text_encoder: - if len(text_encoders) > 1: - unet, t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler - ) - text_encoder = text_encoders = [t_enc1, t_enc2] - del t_enc1, t_enc2 - else: - unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler - ) - text_encoders = [text_encoder] - elif train_unet: - unet, network, optimizer, train_dataloader, lr_scheduler = 
accelerator.prepare( - unet, network, optimizer, train_dataloader, lr_scheduler - ) - for t_enc in text_encoders: - t_enc.to(accelerator.device, dtype=weight_dtype) - elif train_text_encoder: - if len(text_encoders) > 1: - t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler - ) - text_encoder = text_encoders = [t_enc1, t_enc2] - del t_enc1, t_enc2 - else: - text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, network, optimizer, train_dataloader, lr_scheduler - ) - text_encoders = [text_encoder] + # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 + if t_enc.device.type != "cpu": + t_enc.to(dtype=te_weight_dtype) + # nn.Embedding not support FP8 + t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) - unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator + # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model( + args, + unet=unet if train_unet else None, + text_encoder1=text_encoders[0] if train_text_encoder else None, + text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None, + network=network, + ) + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_model = ds_model else: + if train_unet: + unet = accelerator.prepare(unet) + else: + unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator + if train_text_encoder: + if len(text_encoders) > 1: + text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders] + else: + text_encoder = accelerator.prepare(text_encoder) + text_encoders = [text_encoder] + else: + pass # if text_encoder is not trained, no need to prepare. 
and device and dtype are already set + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( network, optimizer, train_dataloader, lr_scheduler ) - - # transform DDP after prepare (train_network here only) - text_encoders = train_util.transform_models_if_DDP(text_encoders) - unet, network = train_util.transform_models_if_DDP([unet, network]) + training_model = network if args.gradient_checkpointing: # according to TI example in Diffusers, train is required @@ -441,9 +453,6 @@ class NetworkTrainer: if train_text_encoder: t_enc.text_model.embeddings.requires_grad_(True) - # set top parameter requires_grad = True for gradient checkpointing works - if not train_text_encoder: # train U-Net only - unet.parameters().__next__().requires_grad_(True) else: unet.eval() for t_enc in text_encoders: @@ -451,7 +460,7 @@ class NetworkTrainer: del t_enc - network.prepare_grad_etc(text_encoder, unet) + accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet) if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する vae.requires_grad_(False) @@ -462,6 +471,31 @@ class NetworkTrainer: if args.full_fp16: train_util.patch_accelerator_for_fp16_training(accelerator) + # before resuming make hook for saving/loading to save/load the network weights only + def save_model_hook(models, weights, output_dir): + # pop weights of other models than network to save only network weights + if accelerator.is_main_process: + remove_indices = [] + for i, model in enumerate(models): + if not isinstance(model, type(accelerator.unwrap_model(network))): + remove_indices.append(i) + for i in reversed(remove_indices): + weights.pop(i) + # print(f"save model hook: {len(weights)} weights will be saved") + + def load_model_hook(models, input_dir): + # remove models except network + remove_indices = [] + for i, model in enumerate(models): + if not isinstance(model, type(accelerator.unwrap_model(network))): + remove_indices.append(i) + for i in reversed(remove_indices): + models.pop(i) + # print(f"load model hook: {len(models)} models will be loaded") + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + # resumeする train_util.resume_from_local_or_hf_if_specified(accelerator, args) @@ -535,6 +569,11 @@ class NetworkTrainer: "ss_scale_weight_norms": args.scale_weight_norms, "ss_ip_noise_gamma": args.ip_noise_gamma, "ss_debiased_estimation": bool(args.debiased_estimation_loss), + "ss_noise_offset_random_strength": args.noise_offset_random_strength, + "ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength, + "ss_loss_type": args.loss_type, + "ss_huber_schedule": args.huber_schedule, + "ss_huber_c": args.huber_c, } if use_user_config: @@ -570,6 +609,11 @@ class NetworkTrainer: "random_crop": bool(subset.random_crop), "shuffle_caption": bool(subset.shuffle_caption), "keep_tokens": subset.keep_tokens, + "keep_tokens_separator": subset.keep_tokens_separator, + "secondary_separator": subset.secondary_separator, + "enable_wildcard": bool(subset.enable_wildcard), + "caption_prefix": subset.caption_prefix, + "caption_suffix": subset.caption_suffix, } image_dir_or_metadata_file = None @@ -704,6 +748,8 @@ class NetworkTrainer: if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( @@ -714,8 +760,8 @@ class NetworkTrainer: del train_dataset_group # 
callback for step start - if hasattr(network, "on_step_start"): - on_step_start = network.on_step_start + if hasattr(accelerator.unwrap_model(network), "on_step_start"): + on_step_start = accelerator.unwrap_model(network).on_step_start else: on_step_start = lambda *args, **kwargs: None @@ -743,6 +789,9 @@ class NetworkTrainer: accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) + # For --sample_at_first + self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + # training loop for epoch in range(num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") @@ -750,26 +799,36 @@ class NetworkTrainer: metadata["ss_epoch"] = str(epoch + 1) - network.on_epoch_start(text_encoder, unet) + accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) for step, batch in enumerate(train_dataloader): current_step.value = global_step - with accelerator.accumulate(network): + with accelerator.accumulate(training_model): on_step_start(text_encoder, unet) - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) - latents = latents * self.vae_scale_factor - b_size = latents.shape[0] + latents = torch.nan_to_num(latents, 0, out=latents) + latents = latents * self.vae_scale_factor + + # get multiplier for each sample + if network_has_multiplier: + multipliers = batch["network_multipliers"] + # if all multipliers are same, use single multiplier + if torch.all(multipliers == multipliers[0]): + multipliers = multipliers[0].item() + else: + raise NotImplementedError("multipliers for each sample is not supported yet") + # print(f"set multiplier: {multipliers}") + accelerator.unwrap_model(network).set_multiplier(multipliers) with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning @@ -789,14 +848,28 @@ class NetworkTrainer: # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) + # ensure the hidden state will require grad + if args.gradient_checkpointing: + for x in noisy_latents: + x.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) + # Predict the noise residual with accelerator.autocast(): noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + args, + accelerator, + unet, + noisy_latents.requires_grad_(train_unet), + timesteps, + text_encoder_conds, + batch, + weight_dtype, ) if args.v_parameterization: @@ -805,14 +878,18 @@ class NetworkTrainer: else: target = noise - loss = 
torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) + if args.masked_loss: + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: @@ -823,16 +900,18 @@ class NetworkTrainer: loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = network.get_trainable_params() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + if accelerator.sync_gradients: + self.all_reduce_network(accelerator, network) # sync DDP grad manually + if args.max_grad_norm != 0.0: + params_to_clip = accelerator.unwrap_model(network).get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) if args.scale_weight_norms: - keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization( + keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization( args.scale_weight_norms, accelerator.device ) max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} @@ -910,27 +989,32 @@ class NetworkTrainer: accelerator.end_training() - if is_main_process and args.save_state: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) if is_main_process: ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) - parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない") + parser.add_argument( + "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない" + ) parser.add_argument( "--save_model_as", type=str, @@ -942,10 +1026,17 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") - parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") - parser.add_argument("--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール") parser.add_argument( - "--network_dim", type=int, default=None, help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)" + "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" + ) + parser.add_argument( + "--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール" + ) + parser.add_argument( + "--network_dim", + type=int, + default=None, + help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)", ) parser.add_argument( "--network_alpha", @@ -960,14 +1051,25 @@ def setup_parser() -> argparse.ArgumentParser: help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)", ) parser.add_argument( - "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" - ) - parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する") - parser.add_argument( - "--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する" + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", ) parser.add_argument( - "--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列" + "--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する" + ) + parser.add_argument( + "--network_train_text_encoder_only", + action="store_true", + help="only training Text Encoder part / Text Encoder関連部分のみ学習する", + ) + parser.add_argument( + "--training_comment", + type=str, + default=None, + help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列", ) parser.add_argument( "--dim_from_weights", @@ -1006,6 +1108,7 @@ if __name__ == "__main__": parser = 
setup_parser()
 
     args = parser.parse_args()
+    train_util.verify_command_line_training_args(args)
     args = train_util.read_config_from_file(args, parser)
 
     trainer = NetworkTrainer()
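A detail that recurs in every trainer in this patch: `apply_snr_weight` now receives `args.v_parameterization`, because the min-SNR-gamma weight needs a different denominator for v-prediction than for epsilon-prediction. Approximately (a sketch of the weighting only, not the library function, which also derives the per-timestep SNR from the noise scheduler):

```python
import torch


def min_snr_weight(snr: torch.Tensor, gamma: float, v_parameterization: bool) -> torch.Tensor:
    # min-SNR-gamma weighting (Hang et al., 2023): clamp the per-sample SNR at
    # gamma, then normalize by SNR for epsilon-prediction, or by SNR + 1 for
    # v-prediction, whose target already mixes signal and noise.
    clamped = torch.minimum(snr, torch.full_like(snr, gamma))
    return clamped / (snr + 1) if v_parameterization else clamped / snr
```

The resulting per-sample weight multiplies the already spatially reduced loss before the batch mean, which is what the `loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)` lines above do.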
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index faa2ff61..cbb6daaa 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -1,23 +1,21 @@
 import argparse
-import gc
 import math
 import os
 from multiprocessing import Value
 
 import toml
 from tqdm import tqdm
+
 import torch
-try:
-    import intel_extension_for_pytorch as ipex
-    if torch.xpu.is_available():
-        from library.ipex import ipex_init
-        ipex_init()
-except Exception:
-    pass
+from library.device_utils import init_ipex, clean_memory_on_device
+
+
+init_ipex()
+
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler
 from transformers import CLIPTokenizer
-from library import model_util
+from library import deepspeed_utils, model_util
 
 import library.train_util as train_util
 import library.huggingface_util as huggingface_util
@@ -33,7 +31,14 @@ from library.custom_train_functions import (
     scale_v_prediction_loss_like_noise_prediction,
     add_v_prediction_like_loss,
     apply_debiased_estimation,
+    apply_masked_loss,
 )
+from library.utils import setup_logging, add_logging_arguments
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
 
 imagenet_templates_small = [
     "a photo of a {}",
@@ -170,6 +175,7 @@ class TextualInversionTrainer:
 
         train_util.verify_training_args(args)
         train_util.prepare_dataset_args(args, True)
+        setup_logging(args, reset=True)
 
         cache_latents = args.cache_latents
 
@@ -180,7 +186,7 @@ class TextualInversionTrainer:
         tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
 
         # acceleratorを準備する
-        print("prepare accelerator")
+        logger.info("prepare accelerator")
         accelerator = train_util.prepare_accelerator(args)
 
         # mixed precisionに対応した型を用意しておき適宜castする
@@ -265,7 +271,7 @@ class TextualInversionTrainer:
 
         # データセットを準備する
         if args.dataset_class is None:
-            blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
+            blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False))
             if args.dataset_config is not None:
                 accelerator.print(f"Load dataset config from {args.dataset_config}")
                 user_config = config_util.load_user_config(args.dataset_config)
@@ -290,7 +296,7 @@ class TextualInversionTrainer:
                     ]
                 }
             else:
-                print("Train with captions.")
+                logger.info("Train with captions.")
                 user_config = {
                     "datasets": [
                         {
@@ -365,9 +371,7 @@ class TextualInversionTrainer:
             with torch.no_grad():
                 train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
             vae.to("cpu")
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-            gc.collect()
+            clean_memory_on_device(accelerator.device)
 
             accelerator.wait_for_everyone()
 
@@ -384,8 +388,8 @@ class TextualInversionTrainer:
         _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
         # dataloaderを準備する
-        # DataLoaderのプロセス数:0はメインプロセスになる
-        n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
+        # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
+        n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
         train_dataloader = torch.utils.data.DataLoader(
             train_dataset_group,
             batch_size=1,
@@ -415,15 +419,11 @@ class TextualInversionTrainer:
             text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                 text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
             )
 
-            # transform DDP after prepare
-            text_encoder_or_list, unet = train_util.transform_if_model_is_DDP(text_encoder_or_list, unet)
 
         elif len(text_encoders) == 2:
             text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                 text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
             )
 
-            # transform DDP after prepare
-            text_encoder1, text_encoder2, unet = train_util.transform_if_model_is_DDP(text_encoder1, text_encoder2, unet)
 
             text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
 
@@ -442,9 +442,10 @@ class TextualInversionTrainer:
 
         # Freeze all parameters except for the token embeddings in text encoder
         text_encoder.requires_grad_(True)
-        text_encoder.text_model.encoder.requires_grad_(False)
-        text_encoder.text_model.final_layer_norm.requires_grad_(False)
-        text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+        unwrapped_text_encoder = accelerator.unwrap_model(text_encoder)
+        unwrapped_text_encoder.text_model.encoder.requires_grad_(False)
+        unwrapped_text_encoder.text_model.final_layer_norm.requires_grad_(False)
+        unwrapped_text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
         # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
 
         unet.requires_grad_(False)
@@ -504,6 +505,8 @@ class TextualInversionTrainer:
 
         if accelerator.is_main_process:
             init_kwargs = {}
+            if args.wandb_run_name:
+                init_kwargs["wandb"] = {"name": args.wandb_run_name}
             if args.log_tracker_config is not None:
                 init_kwargs = toml.load(args.log_tracker_config)
             accelerator.init_trackers(
@@ -529,6 +532,20 @@ class TextualInversionTrainer:
                     accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
                     os.remove(old_ckpt_file)
 
+        # For --sample_at_first
+        self.sample_images(
+            accelerator,
+            args,
+            0,
+            global_step,
+            accelerator.device,
+            vae,
+            tokenizer_or_list,
+            text_encoder_or_list,
+            unet,
+            prompt_replacement,
+        )
+
         # training loop
         for epoch in range(num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -544,10 +561,10 @@ class TextualInversionTrainer:
                 with accelerator.accumulate(text_encoders[0]):
                     with torch.no_grad():
                         if "latents" in batch and batch["latents"] is not None:
-                            latents = batch["latents"].to(accelerator.device)
+                            latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                         else:
                             # latentに変換
-                            latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
+                            latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
                         latents = latents * self.vae_scale_factor
 
                 # Get the text embedding for conditioning
@@ -555,7 +572,7 @@ class TextualInversionTrainer:
 
                     # Sample noise, sample a random timestep for each image, and add noise to the latents,
                     # with noise offset and/or multires noise if specified
-                    noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
+                    noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
                         args, noise_scheduler, latents
                     )
 
@@ -571,14 +588,16 @@ class TextualInversionTrainer:
                     else:
                         target = noise
 
-                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                    if args.masked_loss:
+                        loss = apply_masked_loss(loss, batch)
                     loss = loss.mean([1, 2, 3])
 
                    loss_weights = batch["loss_weights"]  # 
各sampleごとのweight loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: @@ -590,7 +609,7 @@ class TextualInversionTrainer: accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = text_encoder.get_input_embeddings().parameters() + params_to_clip = accelerator.unwrap_model(text_encoder).get_input_embeddings().parameters() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() @@ -602,9 +621,11 @@ class TextualInversionTrainer: for text_encoder, orig_embeds_params, index_no_updates in zip( text_encoders, orig_embeds_params_list, index_no_updates_list ): - accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ + # if full_fp16/bf16, input_embeddings_weight is fp16/bf16, orig_embeds_params is fp32 + input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight + input_embeddings_weight[index_no_updates] = orig_embeds_params.to(input_embeddings_weight.dtype)[ index_no_updates - ] = orig_embeds_params[index_no_updates] + ] # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: @@ -712,27 +733,29 @@ class TextualInversionTrainer: is_main_process = accelerator.is_main_process if is_main_process: text_encoder = accelerator.unwrap_model(text_encoder) + updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() accelerator.end_training() - if args.save_state and is_main_process: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) - updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() - if is_main_process: ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser, False) @@ -745,7 +768,9 @@ def setup_parser() -> argparse.ArgumentParser: help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)", ) - parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み") + parser.add_argument( + "--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み" + ) parser.add_argument( "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数" ) @@ -755,7 +780,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること", ) - parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可") + parser.add_argument( + "--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可" + ) parser.add_argument( "--use_object_template", action="store_true", @@ -779,6 +806,7 @@ if __name__ == "__main__": parser = setup_parser() args = parser.parse_args() + train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) trainer = TextualInversionTrainer() diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 66474ce7..55b67035 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -1,20 +1,18 @@ import importlib import argparse -import gc import math import os import toml from multiprocessing import Value from tqdm import tqdm + import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library import deepspeed_utils +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + from accelerate.utils import set_seed import diffusers from diffusers import DDPMScheduler @@ -35,9 +33,16 @@ from library.custom_train_functions import ( apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, + apply_masked_loss, ) import library.original_unet as original_unet from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) imagenet_templates_small = [ "a photo of a {}", @@ -96,12 +101,13 @@ def train(args): if args.output_name is None: args.output_name = args.token_string use_template = args.use_object_template or args.use_style_template + setup_logging(args, reset=True) train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None: - print( + 
logger.warning(
             "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
         )
     assert (
@@ -116,7 +122,7 @@ def train(args):
     tokenizer = train_util.load_tokenizer(args)
 
     # acceleratorを準備する
-    print("prepare accelerator")
+    logger.info("prepare accelerator")
     accelerator = train_util.prepare_accelerator(args)
 
     # mixed precisionに対応した型を用意しておき適宜castする
@@ -129,7 +135,7 @@ def train(args):
     if args.init_word is not None:
         init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
         if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
-            print(
+            logger.warning(
                 f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
             )
     else:
@@ -143,7 +149,7 @@ def train(args):
     ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
 
     token_ids = tokenizer.convert_tokens_to_ids(token_strings)
-    print(f"tokens are added: {token_ids}")
+    logger.info(f"tokens are added: {token_ids}")
     assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
     assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
 
@@ -171,7 +177,7 @@ def train(args):
 
     tokenizer.add_tokens(token_strings_XTI)
     token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
-    print(f"tokens are added (XTI): {token_ids_XTI}")
+    logger.info(f"tokens are added (XTI): {token_ids_XTI}")
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
 
@@ -180,7 +186,7 @@ def train(args):
     if init_token_ids is not None:
         for i, token_id in enumerate(token_ids_XTI):
             token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]]
-            # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
+            # logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
 
     # load weights
     if args.weights is not None:
@@ -188,22 +194,22 @@ def train(args):
         assert len(token_ids) == len(
             embeddings
         ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
-        # print(token_ids, embeddings.size())
+        # logger.info(token_ids, embeddings.size())
         for token_id, embedding in zip(token_ids_XTI, embeddings):
             token_embeds[token_id] = embedding
-            # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
-        print(f"weighs loaded")
+            # logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
+        logger.info("weights loaded")
 
-    print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
+    logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
 
     # データセットを準備する
-    blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
+    blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False))
     if args.dataset_config is not None:
-        print(f"Load dataset config from {args.dataset_config}")
+        logger.info(f"Load dataset config from {args.dataset_config}")
         user_config = config_util.load_user_config(args.dataset_config)
         ignored = ["train_data_dir", "reg_data_dir", "in_json"]
         if any(getattr(args, attr) is not None 
@@ -188,22 +194,22 @@ def train(args):
         assert len(token_ids) == len(
             embeddings
         ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
-        # print(token_ids, embeddings.size())
+        # logger.info(token_ids, embeddings.size())
         for token_id, embedding in zip(token_ids_XTI, embeddings):
             token_embeds[token_id] = embedding
-            # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
-        print(f"weighs loaded")
+            # logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
+        logger.info(f"weights loaded")
 
-    print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
+    logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
 
     # データセットを準備する
-    blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
+    blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False))
     if args.dataset_config is not None:
-        print(f"Load dataset config from {args.dataset_config}")
+        logger.info(f"Load dataset config from {args.dataset_config}")
         user_config = config_util.load_user_config(args.dataset_config)
         ignored = ["train_data_dir", "reg_data_dir", "in_json"]
         if any(getattr(args, attr) is not None for attr in ignored):
-            print(
+            logger.info(
                 "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                     ", ".join(ignored)
                 )
@@ -211,14 +217,14 @@ def train(args):
     else:
         use_dreambooth_method = args.in_json is None
         if use_dreambooth_method:
-            print("Use DreamBooth method.")
+            logger.info("Use DreamBooth method.")
             user_config = {
                 "datasets": [
                     {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
                 ]
             }
         else:
-            print("Train with captions.")
+            logger.info("Train with captions.")
             user_config = {
                 "datasets": [
                     {
@@ -242,7 +248,7 @@ def train(args):
 
     # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
     if use_template:
-        print(f"use template for training captions. is object: {args.use_object_template}")
+        logger.info(f"use template for training captions. is object: {args.use_object_template}")
         templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
         replace_to = " ".join(token_strings)
         captions = []
@@ -266,7 +272,7 @@ def train(args):
         train_util.debug_dataset(train_dataset_group, show_input_ids=True)
         return
     if len(train_dataset_group) == 0:
-        print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
+        logger.error("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
         return
 
     if cache_latents:
@@ -288,9 +294,7 @@ def train(args):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory_on_device(accelerator.device)
 
         accelerator.wait_for_everyone()
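The `clean_memory_on_device(accelerator.device)` call above replaces the CUDA-only cleanup with a device-aware helper. A minimal sketch of what such a helper plausibly does, inferred only from the behavior visible here (an assumption, not a copy of the repository's implementation):

```python
import gc

import torch


def clean_memory_on_device(device: torch.device) -> None:
    """Collect Python garbage, then release cached allocator memory on `device`."""
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
        # recent PyTorch releases expose an MPS equivalent of empty_cache
        torch.mps.empty_cache()
```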
@@ -299,13 +303,13 @@ def train(args):
         text_encoder.gradient_checkpointing_enable()
 
     # 学習に必要なクラスを準備する
-    print("prepare optimizer, data loader etc.")
+    logger.info("prepare optimizer, data loader etc.")
     trainable_params = text_encoder.get_input_embeddings().parameters()
     _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
     # dataloaderを準備する
-    # DataLoaderのプロセス数:0はメインプロセスになる
-    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
+    # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
+    n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset_group,
         batch_size=1,
@@ -320,7 +324,9 @@ def train(args):
         args.max_train_steps = args.max_train_epochs * math.ceil(
             len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
         )
-        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+        logger.info(
+            f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
+        )
 
     # データセット側にも学習ステップを送信
     train_dataset_group.set_max_train_steps(args.max_train_steps)
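A quick sanity check of the `max_train_steps` override above, with illustrative numbers: 1000 batches per epoch across 2 processes with gradient accumulation of 4 gives ceil(1000 / 2 / 4) = 125 optimizer steps per epoch, so 8 epochs become 1000 steps.

```python
import math

# Illustrative values only
max_train_epochs = 8      # --max_train_epochs
batches_per_epoch = 1000  # len(train_dataloader)
num_processes = 2         # accelerator.num_processes
grad_accum = 4            # --gradient_accumulation_steps

steps_per_epoch = math.ceil(batches_per_epoch / num_processes / grad_accum)  # 125
max_train_steps = max_train_epochs * steps_per_epoch
print(max_train_steps)  # 8 * 125 = 1000
```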
@@ -333,11 +339,8 @@ def train(args):
         text_encoder, optimizer, train_dataloader, lr_scheduler
     )
 
-    # transform DDP after prepare
-    text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
-
     index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
-    # print(len(index_no_updates), torch.sum(index_no_updates))
+    # logger.info(len(index_no_updates), torch.sum(index_no_updates))
 
     orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
 
     # Freeze all parameters except for the token embeddings in text encoder
@@ -375,15 +378,17 @@ def train(args):
 
     # 学習する
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-    print("running training / 学習開始")
-    print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
-    print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
-    print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
-    print(f"  num epochs / epoch数: {num_train_epochs}")
-    print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
-    print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
-    print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
-    print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
+    logger.info("running training / 学習開始")
+    logger.info(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+    logger.info(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
+    logger.info(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+    logger.info(f"  num epochs / epoch数: {num_train_epochs}")
+    logger.info(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
+    logger.info(
+        f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
+    )
+    logger.info(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+    logger.info(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
 
     progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
     global_step = 0
@@ -397,16 +402,21 @@ def train(args):
 
     if accelerator.is_main_process:
         init_kwargs = {}
+        if args.wandb_run_name:
+            init_kwargs["wandb"] = {"name": args.wandb_run_name}
        if args.log_tracker_config is not None:
            init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+        accelerator.init_trackers(
+            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+        )
 
     # function for saving/removing
     def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
         os.makedirs(args.output_dir, exist_ok=True)
         ckpt_file = os.path.join(args.output_dir, ckpt_name)
 
-        print(f"\nsaving checkpoint: {ckpt_file}")
+        logger.info("")
+        logger.info(f"saving checkpoint: {ckpt_file}")
         save_weights(ckpt_file, embs, save_dtype)
         if args.huggingface_repo_id is not None:
             huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -414,12 +424,13 @@ def train(args):
     def remove_model(old_ckpt_name):
         old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
         if os.path.exists(old_ckpt_file):
-            print(f"removing old checkpoint: {old_ckpt_file}")
+            logger.info(f"removing old checkpoint: {old_ckpt_file}")
             os.remove(old_ckpt_file)
 
     # training loop
     for epoch in range(num_train_epochs):
-        print(f"\nepoch {epoch+1}/{num_train_epochs}")
+        logger.info("")
+        logger.info(f"epoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
 
         text_encoder.train()
@@ -431,7 +442,7 @@ def train(args):
             with accelerator.accumulate(text_encoder):
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
-                        latents = batch["latents"].to(accelerator.device)
+                        latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                     else:
                         # latentに変換
                         latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -450,7 +461,7 @@ def train(args):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
 
                 # Predict the noise residual
                 with accelerator.autocast():
@@ -462,14 +473,16 @@ def train(args):
                 else:
                     target = noise
 
-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                if args.masked_loss:
+                    loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])
 
                 loss_weights = batch["loss_weights"]  # 各sampleごとのweight
                 loss = loss * loss_weights
 
                 if args.min_snr_gamma:
-                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
                 if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
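The `conditional_loss` call above swaps the hard-coded MSE for a loss selected by `--loss_type`, with `huber_c` now produced per batch by `get_noise_noisy_latents_and_timesteps`; note that `apply_masked_loss` runs while the loss is still per-pixel, before `loss.mean([1, 2, 3])`. A minimal sketch of such a selector, assuming an `"l2"`/`"huber"` naming and a Pseudo-Huber formulation (both are assumptions, not copied from `train_util`):

```python
import torch
import torch.nn.functional as F


def conditional_loss(pred, target, reduction="none", loss_type="l2", huber_c=0.1):
    if loss_type == "l2":
        return F.mse_loss(pred, target, reduction=reduction)
    if loss_type == "huber":
        # Smooth Pseudo-Huber: ~L2 for small residuals, ~L1 for large ones;
        # elementwise this equals 2*c^2*(sqrt(1 + (x/c)^2) - 1).
        loss = 2 * huber_c * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c)
        if reduction == "mean":
            return loss.mean()
        if reduction == "sum":
            return loss.sum()
        return loss
    raise NotImplementedError(f"unknown loss_type: {loss_type}")
```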
@@ -578,7 +591,7 @@ def train(args):
 
     accelerator.end_training()
 
-    if args.save_state and is_main_process:
+    if is_main_process and (args.save_state or args.save_state_on_train_end):
         train_util.save_state_on_train_end(args, accelerator)
 
     updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
@@ -589,7 +602,7 @@ def train(args):
         ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
         save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
 
-        print("model saved.")
+        logger.info("model saved.")
 
 
 def save_weights(file, updated_embs, save_dtype):
@@ -650,9 +663,12 @@ def load_weights(file):
 
 def setup_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
 
+    add_logging_arguments(parser)
     train_util.add_sd_models_arguments(parser)
     train_util.add_dataset_arguments(parser, True, True, False)
     train_util.add_training_arguments(parser, True)
+    train_util.add_masked_loss_arguments(parser)
+    deepspeed_utils.add_deepspeed_arguments(parser)
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser, False)
@@ -665,7 +681,9 @@ def setup_parser() -> argparse.ArgumentParser:
         help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)",
     )
 
-    parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
+    parser.add_argument(
+        "--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み"
+    )
     parser.add_argument(
         "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
     )
@@ -675,7 +693,9 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--token_string",
         type=str,
         default=None,
         help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
     )
-    parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
+    parser.add_argument(
+        "--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可"
+    )
     parser.add_argument(
         "--use_object_template",
         action="store_true",
@@ -694,6 +714,7 @@ if __name__ == "__main__":
     parser = setup_parser()
 
     args = parser.parse_args()
+    train_util.verify_command_line_training_args(args)
     args = train_util.read_config_from_file(args, parser)
 
     train(args)
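One closing note on the earlier `apply_snr_weight` change: the extra `args.v_parameterization` argument matters because Min-SNR-γ (Hang et al., 2023) weights each timestep's loss by min(SNR, γ)/SNR under ε-prediction but by min(SNR, γ)/(SNR + 1) under v-prediction. A sketch under those assumptions; the repository's own version (presumably in `library/custom_train_functions.py`) may differ in detail:

```python
import torch


def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
    # Per-sample SNR(t) = alpha_bar_t / (1 - alpha_bar_t)
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(loss.device)[timesteps]
    snr = alphas_cumprod / (1.0 - alphas_cumprod)
    min_snr = torch.minimum(snr, torch.full_like(snr, gamma))
    weight = min_snr / (snr + 1.0) if v_prediction else min_snr / snr
    return loss * weight  # loss is the per-sample loss, shape [batch]
```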