Compare commits

...

82 Commits

Author SHA1 Message Date
Kohya S.
966e9d7f6b Merge pull request #2254 from kohya-ss/dev
Merge the changes from the sd3 branch into main
2026-01-19 22:00:25 +09:00
Kohya S.
2a2760e702 Merge pull request #1374 from kohya-ss/sd3
support SD3
2026-01-19 21:50:22 +09:00
Kohya S.
b996440c5f Doc update sd3 branch documentation (#2253)
* doc: move sample prompt file documentation, and remove history for branch

* doc: remove outdated FLUX.1 and SD3 training information from README

* doc: update README and training documentation for clarity and structure
2026-01-19 21:38:46 +09:00
Kohya S.
a9af52692a feat: add pyramid noise and noise offset options to generation script (#2252)
* feat: add pyramid noise and noise offset options to generation script

* fix: fix to work with SD1.5 models

* doc: update to match with latest gen_img.py

* doc: update README to clarify script capabilities and remove deprecated sections
2026-01-18 16:56:48 +09:00
Kohya S.
c6bc632ec6 fix: metadata dataset degradation and make it work (#2186)
* fix: support dataset with metadata

* feat: support another tagger model

* fix: improve handling of image size and caption/tag processing in FineTuningDataset

* fix: enhance metadata loading to support JSONL format in FineTuningDataset

* feat: enhance image loading and processing in ImageLoadingPrepDataset with batch support and output options

* fix: improve image path handling and memory management in dataset classes

* Update finetune/tag_images_by_wd14_tagger.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix: add return type annotation for process_tag_replacement function and ensure tags are returned

* feat: add artist category threshold for tagging

* doc: add comment for clarification

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-01-18 15:17:07 +09:00
Kohya S.
f7f971f50d Merge pull request #2251 from kohya-ss/fix-pytest-for-lumina
fix(tests): add ip_noise_gamma args for MockArgs in pytest
2026-01-18 15:09:47 +09:00
Kohya S
c4be615f69 fix(tests): add ip_noise_gamma args for MockArgs in pytest 2026-01-18 15:05:57 +09:00
Kohya S.
e06e063970 Merge pull request #2225 from urlesistiana/sd3_lumina2_ts_fix
fix: lumina 2 timesteps handling
2026-01-18 14:39:04 +09:00
Kohya S.
94e3dbebea Merge pull request #2246 from kozistr/deps/pytorch-optimizer
Bump `pytorch-optimizer` version to v3.9.0
2025-12-21 22:51:32 +09:00
kozistr
95a65b89a5 build(deps): bump pytorch-optimizer to v3.9.0 2025-12-21 15:53:47 +09:00
Kohya S.
a5a162044c Merge pull request #2226 from kohya-ss/fix-hunyuan-image-batch-gen-error
fix: error on batch generation closes #2209
2025-10-15 21:57:45 +09:00
Kohya S
a33cad714e fix: error on batch generation closes #2209 2025-10-15 21:57:11 +09:00
urlesistiana
f7fc7ddda2 fix #2201: lumina 2 timesteps handling 2025-10-13 16:08:28 +08:00
Kohya S.
5e366acda4 Merge pull request #2003 from laolongboy/sd3-dev
Fix missing parameters in model conversion script
2025-10-01 21:03:12 +09:00
Kohya S
5462a6bb24 Merge branch 'dev' into sd3 2025-09-29 21:02:02 +09:00
Kohya S
63711390a0 Merge branch 'main' into dev 2025-09-29 20:56:07 +09:00
Kohya S.
206adb6438 Merge pull request #2216 from kohya-ss/fix-sdxl-textual-inversion-training-disable-mmap
fix: disable_mmap_safetensors not defined in SDXL TI training
2025-09-29 20:55:02 +09:00
Kohya S
60bfa97b19 fix: disable_mmap_safetensors not defined in SDXL TI training 2025-09-29 20:52:48 +09:00
Kohya S.
f0c767e0f2 Merge pull request #2213 from kohya-ss/doc-hunyuan-image-training-text-encoder-cpu-note
docs: enhance text encoder CPU usage instructions for HunyuanImage-2.…
2025-09-28 18:32:11 +09:00
kohya-ss
a0c26a0efa docs: enhance text encoder CPU usage instructions for HunyuanImage-2.1 training 2025-09-28 18:21:25 +09:00
Kohya S.
67d0621313 Merge pull request #2212 from kohya-ss/fix-hunyuan-image-sample-generation
fix: HunyuanImage-2.1 sample generation fails
2025-09-28 18:12:04 +09:00
Kohya S
6a826d21b1 feat: add new parameters for sample image inference configuration 2025-09-28 18:06:17 +09:00
Kohya S.
4c197a538b Merge pull request #2207 from kohya-ss/fix-flux-extract-lora-metadata-failed
fix: update metadata construction to include model_config for flux
2025-09-24 21:19:27 +09:00
Kohya S
4b79d73504 fix: update metadata construction to include model_config for flux 2025-09-24 21:15:37 +09:00
Kohya S.
121853ca2a Merge pull request #2198 from kohya-ss/feat-hunyuan-image-2.1-inference
feat: support HunyuanImage-2.1
2025-09-23 19:11:50 +09:00
Kohya S
58df9dffa4 doc: update README with HunyuanImage-2.1 LoRA training details and requirements 2025-09-23 18:59:02 +09:00
Kohya S
31f7df3b3a doc: add --network_train_unet_only option for HunyuanImage-2.1 training 2025-09-23 18:53:36 +09:00
Kohya S.
753c794549 Update hunyuan_image_train_network.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-09-21 13:30:22 +09:00
Kohya S.
e7b89826c5 Update library/custom_offloading_utils.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-09-21 13:29:58 +09:00
Kohya S
806d535ef1 fix: block-wise scaling is overwritten by per-tensor scaling 2025-09-21 13:10:41 +09:00
Kohya S
3876343fad fix: remove print statement for guidance rescale in AdaptiveProjectedGuidance 2025-09-21 13:09:38 +09:00
Kohya S
040d976597 feat: add guidance rescale options for Adaptive Projected Guidance in inference 2025-09-21 13:03:14 +09:00
Kohya S
9621d9d637 feat: add Adaptive Projected Guidance parameters and noise rescaling 2025-09-21 12:34:40 +09:00
Kohya S
e7b8e9a778 doc: add --vae_chunk_size option for training and inference 2025-09-21 11:13:26 +09:00
Kohya S
f41e9e2b58 feat: add vae_chunk_size argument for memory-efficient VAE decoding and processing 2025-09-21 11:09:37 +09:00
Kohya S
8f20c37949 feat: add --text_encoder_cpu option to reduce VRAM usage by running text encoders on CPU for training 2025-09-20 20:26:20 +09:00
Kohya S
b090d15f7d feat: add multi backend attention and related update for HI2.1 models and scripts 2025-09-20 19:45:33 +09:00
Kohya S
f834b2e0d4 fix: --fp8_vl to work 2025-09-18 23:46:18 +09:00
Kohya S
f6b4bdc83f feat: block-wise fp8 quantization 2025-09-18 21:20:54 +09:00
Kohya S
2ce506e187 fix: fp8 casting not working 2025-09-18 21:20:08 +09:00
Kohya S
f5b004009e fix: correct tensor indexing in HunyuanVAE2D class for blending and encoding functions 2025-09-17 21:54:25 +09:00
Kohya S
cbe2a9da45 feat: add conversion script for LoRA models to ComfyUI format with reverse option 2025-09-16 21:48:47 +09:00
kohya-ss
f318ddaeea docs: update HunyuanImage-2.1 training guide with model download instructions and VRAM optimization settings (by Claude) 2025-09-16 21:18:01 +09:00
kohya-ss
39458ec0e3 fix: update default values for guidance_scale, image_size, infer_steps, and flow_shift in argument parser 2025-09-16 21:17:21 +09:00
Kohya S
2732be0b29 Merge branch 'feat-hunyuan-image-2.1-inference' of https://github.com/kohya-ss/sd-scripts into feat-hunyuan-image-2.1-inference 2025-09-14 20:49:24 +09:00
Kohya S
1a73b5e8a5 feat: add script to convert LoRA format to ComfyUI format 2025-09-14 20:49:20 +09:00
kohya-ss
e04b9f0497 docs: add LoRA training guide for HunyuanImage-2.1 model (by Gemini CLI) 2025-09-13 22:06:10 +09:00
Kohya S
29b0500e70 fix: restore files section in _typos.toml for exclusion configuration 2025-09-13 21:18:50 +09:00
Kohya S
4e2a80a6ca refactor: update imports to use safetensors_utils for memory-efficient operations 2025-09-13 21:07:11 +09:00
Kohya S
d831c88832 fix: sample generation doesn't work with block swap 2025-09-13 21:06:04 +09:00
Kohya S
bae7fa74eb Merge branch 'sd3' into feat-hunyuan-image-2.1-inference 2025-09-13 20:13:58 +09:00
Kohya S.
f5d44fd487 Merge pull request #2200 from kohya-ss/feat-faster-safetensors-load
feat: Speeding up loading .safetensors files
2025-09-13 20:09:03 +09:00
Kohya S
4568631b43 docs: update README to reflect improved loading speed of .safetensors files 2025-09-13 20:05:39 +09:00
Kohya S.
e1c666e97f Update library/safetensors_utils.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-09-13 20:03:55 +09:00
Kohya S
8783f8aed3 feat: faster safetensors load and split safetensor utils 2025-09-13 19:51:38 +09:00
Kohya S
9a61d61b22 feat: avoid unet type casting when fp8_scaled 2025-09-12 22:18:29 +09:00
Kohya S
7a651efd4d feat: add 'tak' to recognized words and update block swap method to support backward pass 2025-09-12 22:00:41 +09:00
Kohya S
aa0af24d01 Merge branch 'sd3' into feat-hunyuan-image-2.1-inference 2025-09-12 21:41:12 +09:00
Kohya S
209c02dbb6 feat: HunyuanImage LoRA training 2025-09-12 21:40:42 +09:00
Kohya S.
419a9c4af4 Merge pull request #2192 from kohya-ss/doc-update-for-latest-features
Doc update for latest features
2025-09-12 20:28:42 +09:00
Kohya S
cbc9e1a3b1 feat: add byt5 to the list of recognized words in typos configuration 2025-09-11 22:27:08 +09:00
Kohya S
a0f0afbb46 fix: revert constructor signature update 2025-09-11 22:27:00 +09:00
Kohya S
7f983c558d feat: block swap for inference and initial impl for HunyuanImage LoRA (not working) 2025-09-11 22:15:22 +09:00
Kohya S
5149be5a87 feat: initial commit for HunyuanImage-2.1 inference 2025-09-11 12:54:12 +09:00
kohya-ss
ee8e670765 Merge branch 'sd3' into doc-update-for-latest-features 2025-09-09 12:42:09 +09:00
Kohya S.
f8337726cf Merge pull request #2196 from rockerBOO/validation-dataset-subset
Fix validation dataset documentation to not use subsets
2025-09-09 12:38:23 +09:00
rockerBOO
fe4c18934c blocks_to_swap is supported for validation loss now 2025-09-08 14:28:55 -04:00
rockerBOO
78685b9c5f Move general settings to top to make more clear the validation bits 2025-09-08 14:18:50 -04:00
rockerBOO
ef4397963b Fix validation dataset documentation to not use subsets 2025-09-08 14:16:33 -04:00
kohya-ss
0bb0d91615 doc: update introduction and clarify command line option priorities in config README 2025-09-06 19:52:54 +09:00
Kohya S.
952f9ce7be Update docs/train_textual_inversion.md
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-09-04 19:46:04 +09:00
kohya-ss
884fc8c7f5 doc: remove SD3/FLUX.1 training guide 2025-09-04 18:40:21 +09:00
kohya-ss
ddfb38e501 doc: add documentation for Textual Inversion training scripts 2025-09-04 18:39:52 +09:00
kohya-ss
6c82327dc8 doc: remove Japanese section on Gradual Latent options from gen_img README 2025-09-01 21:32:50 +09:00
kohya-ss
9984868154 doc: update README to include support for SDXL models and additional command-line options for gen_img.py 2025-09-01 21:32:24 +09:00
kohya-ss
142d0be180 doc: add comprehensive fine-tuning guide for various model architectures 2025-09-01 12:36:51 +09:00
kohya-ss
c38b07d0da doc: add validation loss documentation for model training 2025-08-31 21:39:47 +09:00
kohya-ss
80710134d5 doc: add Sage Attention and sample batch size options to Lumina training guide 2025-08-31 21:19:28 +09:00
kohya-ss
fe81d40202 doc: refactor structure for improved readability and maintainability 2025-08-31 21:14:45 +09:00
kohya-ss
989448afdd doc: enhance SD3/SDXL LoRA training guide 2025-08-31 19:19:10 +09:00
Kohya S.
884e07d73e Merge pull request #2191 from kohya-ss/fix-chroma-training-withtout-te-cache
fix: chroma LoRA training without Text Encode caching
2025-08-30 09:31:11 +09:00
laolongboy
e64dc05c2a Supplement the input parameters to correctly convert the flux model to BFL format; fixes #1996 2025-03-24 23:33:25 +08:00
61 changed files with 10963 additions and 2497 deletions

View File

@@ -1,21 +1,40 @@
## リポジトリについて
Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。
# sd-scripts
[README in English](./README.md) ←更新情報はこちらにあります
[English](./README.md) / [日本語](./README-ja.md)
開発中のバージョンはdevブランチにあります。最新の変更点はdevブランチをご確認ください。
## 目次
FLUX.1およびSD3/SD3.5対応はsd3ブランチで行っています。それらの学習を行う場合はsd3ブランチをご利用ください。
<details>
<summary>クリックすると展開します</summary>
GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています英語ですのであわせてご覧ください。bmaltais氏に感謝します。
- [はじめに](#はじめに)
- [スポンサー](#スポンサー)
- [スポンサー募集のお知らせ](#スポンサー募集のお知らせ)
- [更新履歴](#更新履歴)
- [サポートモデル](#サポートモデル)
- [機能](#機能)
- [ドキュメント](#ドキュメント)
- [学習ドキュメント(英語および日本語)](#学習ドキュメント英語および日本語)
- [その他のドキュメント](#その他のドキュメント)
- [旧ドキュメント(日本語)](#旧ドキュメント日本語)
- [AIコーディングエージェントを使う開発者の方へ](#aiコーディングエージェントを使う開発者の方へ)
- [Windows環境でのインストール](#windows環境でのインストール)
- [Windowsでの動作に必要なプログラム](#windowsでの動作に必要なプログラム)
- [インストール手順](#インストール手順)
- [requirements.txtとPyTorchについて](#requirementstxtとpytorchについて)
- [xformersのインストールオプション](#xformersのインストールオプション)
- [Linux/WSL2環境でのインストール](#linuxwsl2環境でのインストール)
- [DeepSpeedのインストール実験的、LinuxまたはWSL2のみ](#deepspeedのインストール実験的linuxまたはwsl2のみ)
- [アップグレード](#アップグレード)
- [PyTorchのアップグレード](#pytorchのアップグレード)
- [謝意](#謝意)
- [ライセンス](#ライセンス)
以下のスクリプトがあります。
</details>
* DreamBooth、U-NetおよびText Encoderの学習をサポート
* fine-tuning、同上
* LoRAの学習をサポート
* 画像生成
* モデル変換Stable Diffision ckpt/safetensorsとDiffusersの相互変換
## はじめに
Stable Diffusion等の画像生成モデルの学習、モデルによる画像生成、その他のスクリプトを入れたリポジトリです。
### スポンサー
@@ -29,26 +48,117 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma
このプロジェクトがお役に立ったなら、ご支援いただけると嬉しく思います。 [GitHub Sponsors](https://github.com/sponsors/kohya-ss/)で受け付けています。
## 使用法について
### 更新履歴
- **Version 0.10.0 (2026-01-19):**
- `sd3`ブランチを`main`ブランチにマージしました。このバージョンからFLUX.1およびSD3/SD3.5等のモデルが`main`ブランチでサポートされます。
- ドキュメントにはまだ不備があるため、お気づきの点はIssue等でお知らせください。
- `sd3`ブランチは当面、`dev`ブランチと同期して開発ブランチとして維持します。
### サポートモデル
* **Stable Diffusion 1.x/2.x**
* **SDXL**
* **SD3/SD3.5**
* **FLUX.1**
* **LUMINA**
* **HunyuanImage-2.1**
### 機能
* LoRA学習
* fine-tuningDreamBoothHunyuanImage-2.1以外のモデル
* Textual Inversion学習SD/SDXL
* 画像生成
* その他、モデル変換やタグ付け、LoRAマージなどのユーティリティ
## ドキュメント
### 学習ドキュメント(英語および日本語)
日本語は折りたたまれているか、別のドキュメントにあります。
* [LoRA学習の概要](./docs/train_network.md)
* [データセット設定](./docs/config_README-ja.md) / [英語版](./docs/config_README-en.md)
* [高度な学習オプション](./docs/train_network_advanced.md)
* [SDXL学習](./docs/sdxl_train_network.md)
* [SD3学習](./docs/sd3_train_network.md)
* [FLUX.1学習](./docs/flux_train_network.md)
* [LUMINA学習](./docs/lumina_train_network.md)
* [HunyuanImage-2.1学習](./docs/hunyuan_image_train_network.md)
* [Fine-tuning](./docs/fine_tune.md)
* [Textual Inversion学習](./docs/train_textual_inversion.md)
* [ControlNet-LLLite学習](./docs/train_lllite_README-ja.md) / [英語版](./docs/train_lllite_README.md)
* [Validation](./docs/validation.md)
* [マスク損失学習](./docs/masked_loss_README-ja.md) / [英語版](./docs/masked_loss_README.md)
### その他のドキュメント
* [画像生成スクリプト](./docs/gen_img_README-ja.md) / [英語版](./docs/gen_img_README.md)
* [WD14 Taggerによる画像タグ付け](./docs/wd14_tagger_README-ja.md) / [英語版](./docs/wd14_tagger_README-en.md)
### 旧ドキュメント(日本語)
* [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど
* [データセット設定](./docs/config_README-ja.md)
* [SDXL学習](./docs/train_SDXL-en.md) (英語版)
* [DreamBoothの学習について](./docs/train_db_README-ja.md)
* [fine-tuningのガイド](./docs/fine_tune_README_ja.md):
* [LoRAの学習について](./docs/train_network_README-ja.md)
* [Textual Inversionの学習について](./docs/train_ti_README-ja.md)
* [画像生成スクリプト](./docs/gen_img_README-ja.md)
* note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad)
## Windowsでの動作に必要なプログラム
## AIコーディングエージェントを使う開発者の方へ
Python 3.10.6およびGitが必要です。
This repository provides recommended instructions to help AI agents like Claude and Gemini understand our project context and coding standards.
- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
- git: https://git-scm.com/download/win
To use them, you need to opt-in by creating your own configuration file in the project root.
Python 3.10.x、3.11.x、3.12.xでも恐らく動作しますが、3.10.6でテストしています。
**Quick Setup:**
1. Create a `CLAUDE.md` and/or `GEMINI.md` file in the project root.
2. Add the following line to your `CLAUDE.md` to import the repository's recommended prompt:
```markdown
@./.ai/claude.prompt.md
```
or for Gemini:
```markdown
@./.ai/gemini.prompt.md
```
3. You can now add your own personal instructions below the import line (e.g., `Always respond in Japanese.`).
This approach ensures that you have full control over the instructions given to your agent while benefiting from the shared project context. Your `CLAUDE.md` and `GEMINI.md` are already listed in `.gitignore`, so they won't be committed to the repository.
このリポジトリでは、AIコーディングエージェントClaude、Geminiなどがプロジェクトのコンテキストやコーディング標準を理解できるようにするための推奨プロンプトを提供しています。
それらを使用するには、プロジェクトディレクトリに設定ファイルを作成して明示的に有効にする必要があります。
**簡単なセットアップ手順:**
1. プロジェクトルートに `CLAUDE.md` や `GEMINI.md` ファイルを作成します。
2. `CLAUDE.md` に以下の行を追加して、リポジトリの推奨プロンプトをインポートします。
```markdown
@./.ai/claude.prompt.md
```
またはGeminiの場合:
```markdown
@./.ai/gemini.prompt.md
```
3. インポート行の下に、独自の指示を追加できます(例:`常に日本語で応答してください。`)。
この方法により、エージェントに与える指示を各開発者が管理しつつ、リポジトリの推奨コンテキストを活用できます。`CLAUDE.md` および `GEMINI.md` は `.gitignore` に登録されているため、リポジトリにコミットされることはありません。
## Windows環境でのインストール
### Windowsでの動作に必要なプログラム
Python 3.10.xおよびGitが必要です。
- Python 3.10.x: https://www.python.org/downloads/windows/ からWindows installer (64-bit)をダウンロード
- git: https://git-scm.com/download/win から最新版をダウンロード
Python 3.11.x、3.12.xでも恐らく動作します未テスト
PowerShellを使う場合、venvを使えるようにするためには以下の手順でセキュリティ設定を変更してください。
venvに限らずスクリプトの実行が可能になりますので注意してください。
@@ -57,11 +167,7 @@ PowerShellを使う場合、venvを使えるようにするためには以下の
- 「Set-ExecutionPolicy Unrestricted」と入力し、Yと答えます。
- 管理者のPowerShellを閉じます。
## Windows環境でのインストール
スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.2以降でも恐らく動作します。
なお、python -m venvの行で「python」とだけ表示された場合、py -m venvのようにpythonをpyに変更してください。
### インストール手順
PowerShellを使う場合、通常の管理者ではないPowerShellを開き以下を順に実行します。
@@ -72,20 +178,19 @@ cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124
pip install --upgrade -r requirements.txt
pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118
accelerate config
```
コマンドプロンプトでも同一です。
注:`bitsandbytes==0.44.0``prodigyopt==1.0``lion-pytorch==0.0.6``requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。
なお、python -m venvの行で「python」とだけ表示された場合、py -m venvのようにpythonをpyに変更してください。
この例では PyTorch および xfomers は2.1.2CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121` としてください
注:`bitsandbytes`、`prodigyopt`、`lion-pytorch` は `requirements.txt` に含まれています
PyTorch 2.2以降を用いる場合は、`torch==2.1.2``torchvision==0.16.2` 、および `xformers==0.0.23.post1` を適宜変更してください。
この例ではCUDA 12.4版をインストールします。異なるバージョンのCUDAを使用する場合は、適切なバージョンのPyTorchをインストールしてください。たとえばCUDA 12.1版の場合は `pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu121` としてください。
accelerate configの質問には以下のように答えてください。bf16で学習する場合、最後の質問にはbf16と答えてください。
@@ -102,6 +207,38 @@ accelerate configの質問には以下のように答えてください。bf1
※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``に「0」と答えてください。id `0`のGPUが使われます。
### requirements.txtとPyTorchについて
PyTorchは環境によってバージョンが異なるため、requirements.txtには含まれていません。前述のインストール手順を参考に、環境に合わせてPyTorchをインストールしてください。
スクリプトはPyTorch 2.6.0でテストしています。PyTorch 2.6.0以降が必要です。
RTX 50シリーズGPUの場合、PyTorch 2.8.0とCUDA 12.8/12.9を使用してください。`requirements.txt`はこのバージョンでも動作します。
### xformersのインストールオプション
xformersをインストールするには、仮想環境を有効にした状態で以下のコマンドを実行してください。
```bash
pip install xformers --index-url https://download.pytorch.org/whl/cu124
```
必要に応じてCUDAバージョンを変更してください。一部のGPUアーキテクチャではxformersが利用できない場合があります。
## Linux/WSL2環境でのインストール
LinuxまたはWSL2環境でのインストール手順はWindows環境とほぼ同じです。`venv\Scripts\activate` の部分を `source venv/bin/activate` に変更してください。
※NVIDIAドライバやCUDAツールキットなどは事前にインストールしておいてください。
### DeepSpeedのインストール実験的、LinuxまたはWSL2のみ
DeepSpeedをインストールするには、仮想環境を有効にした状態で以下のコマンドを実行してください。
```bash
pip install deepspeed==0.16.7
```
## アップグレード
新しいリリースがあった場合、以下のコマンドで更新できます。
@@ -115,6 +252,10 @@ pip install --use-pep517 --upgrade -r requirements.txt
コマンドが成功すれば新しいバージョンが使用できます。
### PyTorchのアップグレード
PyTorchをアップグレードする場合は、[Windows環境でのインストール](#windows環境でのインストール)のセクションの`pip install`コマンドを参考にしてください。
## 謝意
LoRAの実装は[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を基にしたものです。感謝申し上げます。
@@ -130,49 +271,3 @@ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora)
[bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
## その他の情報
### LoRAの名称について
`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
デフォルトではLoRA-LierLaが使われます。LoRA-C3Lierを使う場合は `--network_args` に `conv_dim` を指定してください。
<!--
LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additional-networks)、またはAUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。
LoRA-C3Lierを使いWeb UIで生成するには拡張を使用してください。
-->
### 学習中のサンプル画像生成
プロンプトファイルは例えば以下のようになります。
```
# prompt 1
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
# prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
```
`#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。
* `--n` ネガティブプロンプト(次のオプションまで)
* `--w` 生成画像の幅を指定
* `--h` 生成画像の高さを指定
* `--d` 生成画像のシード値を指定
* `--l` 生成画像のCFGスケールを指定。FLUX.1モデルでは、デフォルトは `1.0` でCFGなしを意味します。Chromaモデルでは、CFGを有効にするために `4.0` 程度に設定してください
* `--g` 埋め込みガイダンス付きモデルFLUX.1)の埋め込みガイダンススケールを指定、デフォルトは `3.5`。Chromaモデルでは `0.0` に設定してください
* `--s` 生成時のステップ数を指定
`( )` や `[ ]` などの重みづけも動作します。

1379
README.md

File diff suppressed because it is too large Load Diff

View File

@@ -29,7 +29,9 @@ koo="koo"
yos="yos"
wn="wn"
hime="hime"
OT="OT"
byt="byt"
tak="tak"
[files]
extend-exclude = ["_typos.toml", "venv"]

View File

@@ -1,9 +1,6 @@
Original Source by kohya-ss
First version: A.I Translation by Model: NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO, editing by Darkstorm2150
First version:
A.I Translation by Model: NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO, editing by Darkstorm2150
Some parts are manually added.
Document is updated and maintained manually.
# Config Readme
@@ -267,10 +264,10 @@ The following command line argument options are ignored if a configuration file
* `--reg_data_dir`
* `--in_json`
The following command line argument options are given priority over the configuration file options if both are specified simultaneously. In most cases, they have the same names as the corresponding options in the configuration file.
For the command line options listed below, if an option is specified in both the command line arguments and the configuration file, the value from the configuration file will be given priority. Unless otherwise noted, the option names are the same.
| Command Line Argument Option | Prioritized Configuration File Option |
| ------------------------------- | ------------------------------------- |
| Command Line Argument Option | Corresponding Configuration File Option |
| ------------------------------- | --------------------------------------- |
| `--bucket_no_upscale` | |
| `--bucket_reso_steps` | |
| `--caption_dropout_every_n_epochs` | |

347
docs/fine_tune.md Normal file
View File

@@ -0,0 +1,347 @@
# Fine-tuning Guide
This document explains how to perform fine-tuning on various model architectures using the `*_train.py` scripts.
<details>
<summary>日本語</summary>
# Fine-tuning ガイド
このドキュメントでは、`*_train.py` スクリプトを用いた、各種モデルアーキテクチャのFine-tuningの方法について解説します。
</details>
### Difference between Fine-tuning and LoRA tuning
This repository supports two methods for additional model training: **Fine-tuning** and **LoRA (Low-Rank Adaptation)**. Each method has distinct features and advantages.
**Fine-tuning** is a method that retrains all (or most) of the weights of a pre-trained model.
- **Pros**: It can improve the overall expressive power of the model and is suitable for learning styles or concepts that differ significantly from the original model.
- **Cons**:
- It requires a large amount of VRAM and computational cost.
- The saved file size is large (same as the original model).
- It is prone to "overfitting," where the model loses the diversity of the original model if over-trained.
- **Corresponding scripts**: Scripts named `*_train.py`, such as `sdxl_train.py`, `sd3_train.py`, `flux_train.py`, and `lumina_train.py`.
**LoRA tuning** is a method that freezes the model's weights and only trains a small additional network called an "adapter."
- **Pros**:
- It allows for fast training with low VRAM and computational cost.
- It is considered resistant to overfitting because it trains fewer weights.
- The saved file (LoRA network) is very small, ranging from tens to hundreds of MB, making it easy to manage.
- Multiple LoRAs can be used in combination.
- **Cons**: Since it does not train the entire model, it may not achieve changes as significant as fine-tuning.
- **Corresponding scripts**: Scripts named `*_train_network.py`, such as `sdxl_train_network.py`, `sd3_train_network.py`, and `flux_train_network.py`.
| Feature | Fine-tuning | LoRA tuning |
|:---|:---|:---|
| **Training Target** | All model weights | Additional network (adapter) only |
| **VRAM/Compute Cost**| High | Low |
| **Training Time** | Long | Short |
| **File Size** | Large (several GB) | Small (few MB to hundreds of MB) |
| **Overfitting Risk** | High | Low |
| **Suitable Use Case** | Major style changes, concept learning | Adding specific characters or styles |
Generally, it is recommended to start with **LoRA tuning** if you want to add a specific character or style. **Fine-tuning** is a valid option for more fundamental style changes or aiming for a high-quality model.
<details>
<summary>日本語</summary>
### Fine-tuningとLoRA学習の違い
このリポジトリでは、モデルの追加学習手法として**Fine-tuning**と**LoRA (Low-Rank Adaptation)**学習の2種類をサポートしています。それぞれの手法には異なる特徴と利点があります。
**Fine-tuning**は、事前学習済みモデルの重み全体(または大部分)を再学習する手法です。
- **利点**: モデル全体の表現力を向上させることができ、元のモデルから大きく変化した画風やコンセプトの学習に適しています。
- **欠点**:
- 学習には多くのVRAMと計算コストが必要です。
- 保存されるファイルサイズが大きくなります(元のモデルと同じサイズ)。
- 学習させすぎると、元のモデルが持っていた多様性が失われる「過学習overfitting」に陥りやすい傾向があります。
- **対応スクリプト**: `sdxl_train.py`, `sd3_train.py`, `flux_train.py`, `lumina_train.py` など、`*_train.py` という命名規則のスクリプトが対応します。
**LoRA学習**は、モデルの重みは凍結(固定)したまま、「アダプター」と呼ばれる小さな追加ネットワークのみを学習する手法です。
- **利点**:
- 少ないVRAMと計算コストで高速に学習できます。
- 学習する重みが少ないため、過学習に強いとされています。
- 保存されるファイルLoRAネットワークは数十〜数百MBと非常に小さく、管理が容易です。
- 複数のLoRAを組み合わせて使用することも可能です。
- **欠点**: モデル全体を学習するわけではないため、Fine-tuningほどの大きな変化は期待できない場合があります。
- **対応スクリプト**: `sdxl_train_network.py`, `sd3_train_network.py`, `flux_train_network.py` など、`*_train_network.py` という命名規則のスクリプトが対応します。
| 特徴 | Fine-tuning | LoRA学習 |
|:---|:---|:---|
| **学習対象** | モデルの全重み | 追加ネットワーク(アダプター)のみ |
| **VRAM/計算コスト**| 大 | 小 |
| **学習時間** | 長 | 短 |
| **ファイルサイズ** | 大数GB | 小数MB〜数百MB |
| **過学習リスク** | 高 | 低 |
| **適した用途** | 大規模な画風変更、コンセプト学習 | 特定のキャラ、画風の追加学習 |
一般的に、特定のキャラクターや画風を追加したい場合は**LoRA学習**から試すことが推奨されます。より根本的な画風の変更や、高品質なモデルを目指す場合は**Fine-tuning**が有効な選択肢となります。
</details>
---
### Fine-tuning for each architecture
Fine-tuning updates the entire weights of the model, so it has different options and considerations than LoRA tuning. This section describes the fine-tuning scripts for major architectures.
The basic command structure is common to all architectures.
```bash
accelerate launch --mixed_precision bf16 {script_name}.py \
--pretrained_model_name_or_path <path_to_model> \
--dataset_config <path_to_config.toml> \
--output_dir <output_directory> \
--output_name <model_output_name> \
--save_model_as safetensors \
--max_train_steps 10000 \
--learning_rate 1e-5 \
--optimizer_type AdamW8bit
```
<details>
<summary>日本語</summary>
### 各アーキテクチャのFine-tuning
Fine-tuningはモデルの重み全体を更新するため、LoRA学習とは異なるオプションや考慮事項があります。ここでは主要なアーキテクチャごとのFine-tuningスクリプトについて説明します。
基本的なコマンドの構造は、どのアーキテクチャでも共通です。
```bash
accelerate launch --mixed_precision bf16 {script_name}.py \
--pretrained_model_name_or_path <path_to_model> \
--dataset_config <path_to_config.toml> \
--output_dir <output_directory> \
--output_name <model_output_name> \
--save_model_as safetensors \
--max_train_steps 10000 \
--learning_rate 1e-5 \
--optimizer_type AdamW8bit
```
</details>
#### SDXL (`sdxl_train.py`)
Performs fine-tuning for SDXL models. It is possible to train both the U-Net and the Text Encoders.
**Key Options:**
- `--train_text_encoder`: Includes the weights of the Text Encoders (CLIP ViT-L and OpenCLIP ViT-bigG) in the training. Effective for significant style changes or strongly learning specific concepts.
- `--learning_rate_te1`, `--learning_rate_te2`: Set individual learning rates for each Text Encoder.
- `--block_lr`: Divides the U-Net into 23 blocks and sets a different learning rate for each block. This allows for advanced adjustments, such as strengthening or weakening the learning of specific layers. (Not available in LoRA tuning).
**Command Example:**
```bash
accelerate launch --mixed_precision bf16 sdxl_train.py \
--pretrained_model_name_or_path "sd_xl_base_1.0.safetensors" \
--dataset_config "dataset_config.toml" \
--output_dir "output" \
--output_name "sdxl_finetuned" \
--train_text_encoder \
--learning_rate 1e-5 \
--learning_rate_te1 5e-6 \
--learning_rate_te2 2e-6
```
<details>
<summary>日本語</summary>
#### SDXL (`sdxl_train.py`)
SDXLモデルのFine-tuningを行います。U-NetとText Encoderの両方を学習させることが可能です。
**主要なオプション:**
- `--train_text_encoder`: Text EncoderCLIP ViT-LとOpenCLIP ViT-bigGの重みを学習対象に含めます。画風を大きく変えたい場合や、特定の概念を強く学習させたい場合に有効です。
- `--learning_rate_te1`, `--learning_rate_te2`: それぞれのText Encoderに個別の学習率を設定します。
- `--block_lr`: U-Netを23個のブロックに分割し、ブロックごとに異なる学習率を設定できます。特定の層の学習を強めたり弱めたりする高度な調整が可能です。LoRA学習では利用できません
**コマンド例:**
```bash
accelerate launch --mixed_precision bf16 sdxl_train.py \
--pretrained_model_name_or_path "sd_xl_base_1.0.safetensors" \
--dataset_config "dataset_config.toml" \
--output_dir "output" \
--output_name "sdxl_finetuned" \
--train_text_encoder \
--learning_rate 1e-5 \
--learning_rate_te1 5e-6 \
--learning_rate_te2 2e-6
```
</details>
#### SD3 (`sd3_train.py`)
Performs fine-tuning for Stable Diffusion 3 Medium models. SD3 consists of three Text Encoders (CLIP-L, CLIP-G, T5-XXL) and a MMDiT (equivalent to U-Net), which can be targeted for training.
**Key Options:**
- `--train_text_encoder`: Enables training for CLIP-L and CLIP-G.
- `--train_t5xxl`: Enables training for T5-XXL. T5-XXL is a very large model and requires a lot of VRAM for training.
- `--blocks_to_swap`: A memory optimization feature to reduce VRAM usage. It swaps some blocks of the MMDiT to CPU memory during training. Useful for using larger batch sizes in low VRAM environments. (Also available in LoRA tuning).
- `--num_last_block_to_freeze`: Freezes the weights of the last N blocks of the MMDiT, excluding them from training. Useful for maintaining model stability while focusing on learning in the lower layers.
**Command Example:**
```bash
accelerate launch --mixed_precision bf16 sd3_train.py \
--pretrained_model_name_or_path "sd3_medium.safetensors" \
--dataset_config "dataset_config.toml" \
--output_dir "output" \
--output_name "sd3_finetuned" \
--train_text_encoder \
--learning_rate 4e-6 \
--blocks_to_swap 10
```
<details>
<summary>日本語</summary>
#### SD3 (`sd3_train.py`)
Stable Diffusion 3 MediumモデルのFine-tuningを行います。SD3は3つのText EncoderCLIP-L, CLIP-G, T5-XXLとMMDiTU-Netに相当で構成されており、これらを学習対象にできます。
**主要なオプション:**
- `--train_text_encoder`: CLIP-LとCLIP-Gの学習を有効にします。
- `--train_t5xxl`: T5-XXLの学習を有効にします。T5-XXLは非常に大きなモデルのため、学習には多くのVRAMが必要です。
- `--blocks_to_swap`: VRAM使用量を削減するためのメモリ最適化機能です。MMDiTの一部のブロックを学習中にCPUメモリに退避スワップさせます。VRAMが少ない環境で大きなバッチサイズを使いたい場合に有効です。LoRA学習でも利用可能
- `--num_last_block_to_freeze`: MMDiTの最後のNブロックの重みを凍結し、学習対象から除外します。モデルの安定性を保ちつつ、下位層を中心に学習させたい場合に有効です。
**コマンド例:**
```bash
accelerate launch --mixed_precision bf16 sd3_train.py \
--pretrained_model_name_or_path "sd3_medium.safetensors" \
--dataset_config "dataset_config.toml" \
--output_dir "output" \
--output_name "sd3_finetuned" \
--train_text_encoder \
--learning_rate 4e-6 \
--blocks_to_swap 10
```
</details>
#### FLUX.1 (`flux_train.py`)
Performs fine-tuning for FLUX.1 models. FLUX.1 is internally composed of two Transformer blocks (Double Blocks, Single Blocks).
**Key Options:**
- `--blocks_to_swap`: Similar to SD3, this feature swaps Transformer blocks to the CPU for memory optimization.
- `--blockwise_fused_optimizers`: An experimental feature that aims to streamline training by applying individual optimizers to each block.
**Command Example:**
```bash
accelerate launch --mixed_precision bf16 flux_train.py \
--pretrained_model_name_or_path "FLUX.1-dev.safetensors" \
--dataset_config "dataset_config.toml" \
--output_dir "output" \
--output_name "flux1_finetuned" \
--learning_rate 1e-5 \
--blocks_to_swap 18
```
<details>
<summary>日本語</summary>
#### FLUX.1 (`flux_train.py`)
FLUX.1モデルのFine-tuningを行います。FLUX.1は内部的に2つのTransformerブロックDouble Blocks, Single Blocksで構成されています。
**主要なオプション:**
- `--blocks_to_swap`: SD3と同様に、メモリ最適化のためにTransformerブロックをCPUにスワップする機能です。
- `--blockwise_fused_optimizers`: 実験的な機能で、各ブロックに個別のオプティマイザを適用し、学習を効率化することを目指します。
**コマンド例:**
```bash
accelerate launch --mixed_precision bf16 flux_train.py \
--pretrained_model_name_or_path "FLUX.1-dev.safetensors" \
--dataset_config "dataset_config.toml" \
--output_dir "output" \
--output_name "flux1_finetuned" \
--learning_rate 1e-5 \
--blocks_to_swap 18
```
</details>
#### Lumina (`lumina_train.py`)
Performs fine-tuning for Lumina-Next DiT models.
**Key Options:**
- `--use_flash_attn`: Enables Flash Attention to speed up computation.
- `lumina_train.py` is relatively new, and many of its options are shared with other scripts. Training can be performed following the basic command pattern.
**Command Example:**
```bash
accelerate launch --mixed_precision bf16 lumina_train.py \
--pretrained_model_name_or_path "Lumina-Next-DiT-B.safetensors" \
--dataset_config "dataset_config.toml" \
--output_dir "output" \
--output_name "lumina_finetuned" \
--learning_rate 1e-5
```
<details>
<summary>日本語</summary>
#### Lumina (`lumina_train.py`)
Lumina-Next DiTモデルのFine-tuningを行います。
**主要なオプション:**
- `--use_flash_attn`: Flash Attentionを有効にし、計算を高速化します。
- `lumina_train.py`は比較的新しく、オプションは他のスクリプトと共通化されている部分が多いです。基本的なコマンドパターンに従って学習を行えます。
**コマンド例:**
```bash
accelerate launch --mixed_precision bf16 lumina_train.py \
--pretrained_model_name_or_path "Lumina-Next-DiT-B.safetensors" \
--dataset_config "dataset_config.toml" \
--output_dir "output" \
--output_name "lumina_finetuned" \
--learning_rate 1e-5
```
</details>
---
### Differences between Fine-tuning and LoRA tuning per architecture
| Architecture | Key Features/Options Specific to Fine-tuning | Main Differences from LoRA tuning |
|:---|:---|:---|
| **SDXL** | `--block_lr` | Only fine-tuning allows for granular control over the learning rate for each U-Net block. |
| **SD3** | `--train_text_encoder`, `--train_t5xxl`, `--num_last_block_to_freeze` | Only fine-tuning can train the entire Text Encoders. LoRA only trains the adapter parts. |
| **FLUX.1** | `--blockwise_fused_optimizers` | Since fine-tuning updates the entire model's weights, more experimental optimizer options are available. |
| **Lumina** | (Few specific options) | Basic training options are common, but fine-tuning differs in that it updates the entire model's foundation. |
<details>
<summary>日本語</summary>
### アーキテクチャごとのFine-tuningとLoRA学習の違い
| アーキテクチャ | Fine-tuning特有の主要機能・オプション | LoRA学習との主な違い |
|:---|:---|:---|
| **SDXL** | `--block_lr` | U-Netのブロックごとに学習率を細かく制御できるのはFine-tuningのみです。 |
| **SD3** | `--train_text_encoder`, `--train_t5xxl`, `--num_last_block_to_freeze` | Text Encoder全体を学習対象にできるのはFine-tuningです。LoRAではアダプター部分のみ学習します。 |
| **FLUX.1** | `--blockwise_fused_optimizers` | Fine-tuningではモデル全体の重みを更新するため、より実験的なオプティマイザの選択肢が用意されています。 |
| **Lumina** | (特有のオプションは少ない) | 基本的な学習オプションは共通ですが、Fine-tuningはモデルの基盤全体を更新する点で異なります。 |
</details>

View File

@@ -550,24 +550,34 @@ You can calculate validation loss during training using a validation dataset to
To set up validation, add a `validation_split` and optionally `validation_seed` to your dataset configuration TOML file.
```toml
[[datasets]]
validation_seed = 42 # [Optional] Validation seed, otherwise uses training seed for validation split .
enable_bucket = true
resolution = [1024, 1024]
validation_seed = 42 # [Optional] Validation seed, otherwise uses training seed for validation split .
[[datasets]]
[[datasets.subsets]]
# This directory will use 100% of the images for training
image_dir = "path/to/image/directory"
validation_split = 0.1 # Split between 0.0 and 1.0 where 1.0 will use the full subset as a validation dataset
[[datasets]]
validation_split = 0.1 # Split between 0.0 and 1.0 where 1.0 will use the full subset as a validation dataset
[[datasets.subsets]]
# This directory will split 10% to validation and 90% to training
image_dir = "path/to/image/second-directory"
[[datasets]]
validation_split = 1.0 # Will use this full subset as a validation subset.
[[datasets.subsets]]
# This directory will use the 100% to validation and 0% to training
image_dir = "path/to/image/full_validation"
validation_split = 1.0 # Will use this full subset as a validation subset.
```
**Notes:**
* Validation loss calculation uses fixed timestep sampling and random seeds to reduce loss variation due to randomness for more stable evaluation.
* Currently, validation loss is not supported when using `--blocks_to_swap` or Schedule-Free optimizers (`AdamWScheduleFree`, `RAdamScheduleFree`, `ProdigyScheduleFree`).
* Currently, validation loss is not supported when using Schedule-Free optimizers (`AdamWScheduleFree`, `RAdamScheduleFree`, `ProdigyScheduleFree`).
<details>
<summary>日本語</summary>
@@ -631,6 +641,40 @@ interpolation_type = "lanczos" # Example: Use Lanczos interpolation
</details>
### 7.3. Other Training Options / その他の学習オプション
- **`--controlnet_model_name_or_path`**: Specifies the path to a ControlNet model compatible with FLUX.1. This allows for training a LoRA that works in conjunction with ControlNet. This is an advanced feature and requires a compatible ControlNet model.
- **`--loss_type`**: Specifies the loss function for training. The default is `l2`.
- `l1`: L1 loss.
- `l2`: L2 loss (mean squared error).
- `huber`: Huber loss.
- `smooth_l1`: Smooth L1 loss.
- **`--huber_schedule`**, **`--huber_c`**, **`--huber_scale`**: These are parameters for Huber loss. They are used when `--loss_type` is set to `huber` or `smooth_l1`.
- **`--t5xxl_max_token_length`**: Specifies the maximum token length for the T5-XXL text encoder. For details, refer to the [`sd3_train_network.md` guide](sd3_train_network.md).
- **`--weighting_scheme`**, **`--logit_mean`**, **`--logit_std`**, **`--mode_scale`**: These options allow you to adjust the loss weighting for each timestep. For details, refer to the [`sd3_train_network.md` guide](sd3_train_network.md).
- **`--fused_backward_pass`**: Fuses the backward pass and optimizer step to reduce VRAM usage. For details, refer to the [`sdxl_train_network.md` guide](sdxl_train_network.md).
<details>
<summary>日本語</summary>
- **`--controlnet_model_name_or_path`**: FLUX.1互換のControlNetモデルへのパスを指定します。これにより、ControlNetと連携して動作するLoRAを学習できます。これは高度な機能であり、互換性のあるControlNetモデルが必要です。
- **`--loss_type`**: 学習に用いる損失関数を指定します。デフォルトは `l2` です。
- `l1`: L1損失。
- `l2`: L2損失平均二乗誤差
- `huber`: Huber損失。
- `smooth_l1`: Smooth L1損失。
- **`--huber_schedule`**, **`--huber_c`**, **`--huber_scale`**: これらはHuber損失のパラメータです。`--loss_type` が `huber` または `smooth_l1` の場合に使用されます。
- **`--t5xxl_max_token_length`**: T5-XXLテキストエンコーダの最大トークン長を指定します。詳細は [`sd3_train_network.md` ガイド](sd3_train_network.md) を参照してください。
- **`--weighting_scheme`**, **`--logit_mean`**, **`--logit_std`**, **`--mode_scale`**: これらのオプションは、各タイムステップの損失の重み付けを調整するために使用されます。詳細は [`sd3_train_network.md` ガイド](sd3_train_network.md) を参照してください。
- **`--fused_backward_pass`**: バックワードパスとオプティマイザステップを融合してVRAM使用量を削減します。詳細は [`sdxl_train_network.md` ガイド](sdxl_train_network.md) を参照してください。
</details>
## 8. Related Tools / 関連ツール
Several related scripts are provided for models trained with `flux_train_network.py` and to assist with the training process:

View File

@@ -1,30 +1,24 @@
SD 1.xおよび2.xのモデル、当リポジトリで学習したLoRA、ControlNetv1.0のみ動作確認などに対応した、Diffusersベースの推論(画像生成)スクリプトです。コマンドラインから用います。
SD 1.x、2.x、およびSDXLのモデル、当リポジトリで学習したLoRA、ControlNet、ControlNet-LLLiteなどに対応した、独自の推論(画像生成)スクリプトです。コマンドラインから用います。
# 概要
* Diffusers (v0.10.2) ベースの推論(画像生成)スクリプト。
* SD 1.xおよび2.x (base/v-parameterization)モデルに対応。
* 独自の推論(画像生成)スクリプト。
* SD 1.x2.x (base/v-parameterization)、およびSDXLモデルに対応。
* txt2img、img2img、inpaintingに対応。
* 対話モード、およびファイルからのプロンプト読み込み、連続生成に対応。
* プロンプト1行あたりの生成枚数を指定可能。
* 全体の繰り返し回数を指定可能。
* `fp16`だけでなく`bf16`にも対応。
* xformersに対応し高速生成が可能
* xformersにより省メモリ生成を行いますが、Automatic 1111氏のWeb UIほど最適化していないため、512*512の画像生成でおおむね6GB程度のVRAMを使用します。
* xformers、SDPAScaled Dot-Product Attentionに対応。
* プロンプトの225トークンへの拡張。ネガティブプロンプト、重みづけに対応。
* Diffusersの各種samplerに対応Web UIよりもsampler数は少ないです
* Diffusersの各種samplerに対応。
* Text Encoderのclip skip最後からn番目の層の出力を用いるに対応。
* VAEの別途読み込み。
* CLIP Guided Stable Diffusion、VGG16 Guided Stable Diffusion、Highres. fix、upscale対応。
* Highres. fixはWeb UIの実装を全く確認していない独自実装のため、出力結果は異なるかもしれません
* LoRA対応。適用率指定、複数LoRA同時利用、重みのマージに対応。
* Text EncoderとU-Netで別の適用率を指定することはできません
* Attention Coupleに対応。
* ControlNet v1.0に対応。
* VAEの別途読み込み、VAEのバッチ処理やスライスによる省メモリ化に対応
* Highres. fix独自実装およびGradual Latent、upscale対応。
* LoRA、DyLoRA対応。適用率指定、複数LoRA同時利用、重みのマージに対応
* Attention Couple、Regional LoRAに対応。
* ControlNet (v1.0/v1.1)、ControlNet-LLLiteに対応
* 途中でモデルを切り替えることはできませんが、バッチファイルを組むことで対応できます。
* 個人的に欲しくなった機能をいろいろ追加。
機能追加時にすべてのテストを行っているわけではないため、以前の機能に影響が出て一部機能が動かない可能性があります。何か問題があればお知らせください。
# 基本的な使い方
@@ -33,18 +27,20 @@ SD 1.xおよび2.xのモデル、当リポジトリで学習したLoRA、Control
以下のように入力してください。
```batchfile
python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先> --xformers --fp16 --interactive
python gen_img.py --ckpt <モデル名> --outdir <画像出力先> --xformers --fp16 --interactive
```
`--ckpt`オプションにモデルStable Diffusionのcheckpointファイル、またはDiffusersのモデルフォルダ`--outdir`オプションに画像の出力先フォルダを指定します。
`--xformers`オプションでxformersの使用を指定しますxformersを使わない場合は外してください`--fp16`オプションでfp16精度での推論を行います。RTX 30系のGPUでは `--bf16`オプションでbf16bfloat16での推論を行うこともできます。
`--xformers`オプションでxformersの使用を指定します。`--fp16`オプションでfp16精度での推論を行います。RTX 30系以降のGPUでは `--bf16`オプションでbf16bfloat16での推論を行うこともできます。
`--interactive`オプションで対話モードを指定しています。
Stable Diffusion 2.0(またはそこからの追加学習モデル)を使う場合は`--v2`オプションを追加してください。v-parameterizationを使うモデル`768-v-ema.ckpt`およびそこからの追加学習モデル)を使う場合はさらに`--v_parameterization`を追加してください。
`--v2`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます
SDXLモデルを使う場合は`--sdxl`オプションを追加してください
`--v2``--sdxl`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。
`Type prompt:`と表示されたらプロンプトを入力してください。
@@ -59,7 +55,7 @@ Stable Diffusion 2.0(またはそこからの追加学習モデル)を使う
以下のように入力します実際には1行で入力します
```batchfile
python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
python gen_img.py --ckpt <モデル名> --outdir <画像出力先>
--xformers --fp16 --images_per_prompt <生成枚数> --prompt "<プロンプト>"
```
@@ -72,7 +68,7 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
以下のように入力します。
```batchfile
python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
python gen_img.py --ckpt <モデル名> --outdir <画像出力先>
--xformers --fp16 --from_file <プロンプトファイル名>
```
@@ -96,13 +92,29 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
- `--ckpt <モデル名>`:モデル名を指定します。`--ckpt`オプションは必須です。Stable Diffusionのcheckpointファイル、またはDiffusersのモデルフォルダ、Hugging FaceのモデルIDを指定できます。
- `--v1`Stable Diffusion 1.x系のモデルを使う場合に指定します。これがデフォルトの動作です。
- `--v2`Stable Diffusion 2.x系のモデルを使う場合に指定します。1.x系の場合には指定不要です。
- `--sdxl`Stable Diffusion XLモデルを使う場合に指定します。
- `--v_parameterization`v-parameterizationを使うモデルを使う場合に指定します`768-v-ema.ckpt`およびそこからの追加学習モデル、Waifu Diffusion v1.5など)。
`--v2`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。
`--v2``--sdxl`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。
- `--vae`使用するVAEを指定します。未指定時はモデル内のVAEを使用します。
- `--zero_terminal_snr`noise schedulerのbetasを修正して、zero terminal SNRを強制します。
- `--pyramid_noise_prob`:ピラミッドノイズを適用する確率を指定します。
- `--pyramid_noise_discount_range`:ピラミッドノイズの割引率の範囲を指定します。
- `--noise_offset_prob`:ノイズオフセットを適用する確率を指定します。
- `--noise_offset_range`:ノイズオフセットの範囲を指定します。
- `--vae`:使用する VAE を指定します。未指定時はモデル内の VAE を使用します。
- `--tokenizer_cache_dir`:トークナイザーのキャッシュディレクトリを指定します(オフライン利用のため)。
## 画像生成と出力
@@ -112,6 +124,10 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
- `--from_file <プロンプトファイル名>`プロンプトが記述されたファイルを指定します。1行1プロンプトで記述してください。なお画像サイズやguidance scaleはプロンプトオプション後述で指定できます。
- `--from_module <モジュールファイル>`Pythonモジュールからプロンプトを読み込みます。モジュールは`get_prompter(args, pipe, networks)`関数を実装している必要があります。
- `--prompter_module_args`prompterモジュールに渡す追加の引数を指定します。
- `--W <画像幅>`:画像の幅を指定します。デフォルトは`512`です。
- `--H <画像高さ>`:画像の高さを指定します。デフォルトは`512`です。
@@ -120,30 +136,59 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
- `--scale <ガイダンススケール>`unconditionalガイダンススケールを指定します。デフォルトは`7.5`です。
- `--sampler <サンプラー名>`:サンプラーを指定します。デフォルトは`ddim`です。Diffusersで提供されているddim、pndm、dpmsolver、dpmsolver+++、lms、euler、euler_a、が指定可能です後ろの三つはk_lms、k_euler、k_euler_aでも指定できます
- `--sampler <サンプラー名>`:サンプラーを指定します。デフォルトは`ddim`です。
`ddim`, `pndm`, `lms`, `euler`, `euler_a`, `heun`, `dpm_2`, `dpm_2_a`, `dpmsolver`, `dpmsolver++`, `dpmsingle`, `k_lms`, `k_euler`, `k_euler_a`, `k_dpm_2`, `k_dpm_2_a` が指定可能です。
- `--outdir <画像出力先フォルダ>`:画像の出力先を指定します。
- `--images_per_prompt <生成枚数>`プロンプト1件当たりの生成枚数を指定します。デフォルトは`1`です。
- `--clip_skip <スキップ数>`CLIPの後ろから何番目の層を使うかを指定します。省略時は最後の層を使います。
- `--clip_skip <スキップ数>`CLIPの後ろから何番目の層を使うかを指定します。デフォルトはSD1/2の場合1、SDXLの場合2です。
- `--max_embeddings_multiples <倍数>`CLIPの入出力長をデフォルト75の何倍にするかを指定します。未指定時は75のままです。たとえば3を指定すると入出力長が225になります。
- `--negative_scale` : uncoditioningのguidance scaleを個別に指定します。[gcem156氏のこちらの記事](https://note.com/gcem156/n/ne9a53e4a6f43)を参考に実装したものです。
- `--emb_normalize_mode`embedding正規化モードを指定します。"original"(デフォルト)、"abs"、"none"から選択できます。プロンプトの重みの正規化方法に影響します。
- `--force_scheduler_zero_steps_offset`:スケジューラのステップオフセットを、スケジューラ設定の `steps_offset` の値に関わらず強制的にゼロにします。
## SDXL固有のオプション
SDXL モデル(`--sdxl`フラグ付き)を使用する場合、追加のコンディショニングオプションが利用できます:
- `--original_height`SDXL コンディショニング用の元の高さを指定します。これはモデルの対象解像度の理解に影響します。
- `--original_width`SDXL コンディショニング用の元の幅を指定します。これはモデルの対象解像度の理解に影響します。
- `--original_height_negative`SDXL ネガティブコンディショニング用の元の高さを指定します。
- `--original_width_negative`SDXL ネガティブコンディショニング用の元の幅を指定します。
- `--crop_top`SDXL コンディショニング用のクロップ上オフセットを指定します。
- `--crop_left`SDXL コンディショニング用のクロップ左オフセットを指定します。
## メモリ使用量や生成速度の調整
- `--batch_size <バッチサイズ>`:バッチサイズを指定します。デフォルトは`1`です。バッチサイズが大きいとメモリを多く消費しますが、生成速度が速くなります。
- `--vae_batch_size <VAEのバッチサイズ>`VAEのバッチサイズを指定します。デフォルトはバッチサイズと同じです。
- `--vae_batch_size <VAEのバッチサイズ>`VAEのバッチサイズを指定します。デフォルトはバッチサイズと同じです。1未満の値を指定すると、バッチサイズに対する比率として扱われます。
VAEのほうがメモリを多く消費するため、デイジング後stepが100%になった後でメモリ不足になる場合があります。このような場合にはVAEのバッチサイズを小さくしてください。
- `--vae_slices <スライス数>`VAE処理時に画像をスライスに分割してVRAM使用量を削減します。Noneデフォルトで分割なし。16や32のような値が推奨されます。有効にすると処理が遅くなりますが、VRAM使用量が少なくなります。
- `--no_half_vae`VAE処理でfp16/bf16精度の使用を防ぎます。代わりにfp32を使用します。VAE関連の問題やアーティファクトが発生した場合に使用してください。
- `--xformers`xformersを使う場合に指定します。
- `--fp16`fp16単精度での推論を行います。`fp16``bf16`をどちらも指定しない場合はfp32単精度での推論を行います。
- `--sdpa`最適化のためにPyTorch 2のscaled dot-product attentionを使用します。
- `--bf16`bf16bfloat16での推論を行います。RTX 30系のGPUでのみ指定可能です。`--bf16`オプションはRTX 30系以外のGPUではエラーになります。`fp16`よりも`bf16`のほうが推論結果がNaNになる真っ黒の画像になる可能性が低いようです
- `--diffusers_xformers`Diffusers経由でxformersを使用しますHypernetworksと互換性がありません
- `--fp16`fp16半精度での推論を行います。`fp16``bf16`をどちらも指定しない場合はfp32単精度での推論を行います。
- `--bf16`bf16bfloat16での推論を行います。RTX 30系以降のGPUでのみ指定可能です。`--bf16`オプションはRTX 30系以外のGPUではエラーになります。SDXLでは`fp16`よりも`bf16`のほうが推論結果がNaNになる真っ黒の画像になる可能性が低いようです。
## 追加ネットワークLoRA等の使用
@@ -157,12 +202,18 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
- `--network_pre_calc`:使用する追加ネットワークの重みを生成ごとにあらかじめ計算します。プロンプトオプションの`--am`が使用できます。LoRA未使用時と同じ程度まで生成は高速化されますが、生成前に重みを計算する時間が必要で、またメモリ使用量も若干増加します。Regional LoRA使用時は無効になります 。
- `--network_regional_mask_max_color_codes`リージョナルマスクに使用する色コードの最大数を指定します。指定されていない場合、マスクはチャンネルごとに適用されます。Regional LoRAと組み合わせて、マスク内の色で定義できるリージョン数を制御するために使用されます。
- `--network_args`key=value形式でネットワークモジュールに渡す追加引数を指定します。例: `--network_args "alpha=1.0,dropout=0.1"`
- `--network_merge_n_models`:ネットワークマージを使用する場合、マージするモデル数を指定します(全ての読み込み済みネットワークをマージする代わりに)。
# 主なオプションの指定例
次は同一プロンプトで64枚をバッチサイズ4で一括生成する例です。
```batchfile
python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
python gen_img.py --ckpt model.ckpt --outdir outputs
--xformers --fp16 --W 512 --H 704 --scale 12.5 --sampler k_euler_a
--steps 32 --batch_size 4 --images_per_prompt 64
--prompt "beautiful flowers --n monochrome"
@@ -171,7 +222,7 @@ python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
次はファイルに書かれたプロンプトを、それぞれ10枚ずつ、バッチサイズ4で一括生成する例です。
```batchfile
python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
python gen_img.py --ckpt model.ckpt --outdir outputs
--xformers --fp16 --W 512 --H 704 --scale 12.5 --sampler k_euler_a
--steps 32 --batch_size 4 --images_per_prompt 10
--from_file prompts.txt
@@ -180,7 +231,7 @@ python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
Textual Inversion後述およびLoRAの使用例です。
```batchfile
python gen_img_diffusers.py --ckpt model.safetensors
python gen_img.py --ckpt model.safetensors
--scale 8 --steps 48 --outdir txt2img --xformers
--W 512 --H 768 --fp16 --sampler k_euler_a
--textual_inversion_embeddings goodembed.safetensors negprompt.pt
@@ -216,6 +267,22 @@ python gen_img_diffusers.py --ckpt model.safetensors
- `--am`:追加ネットワークの重みを指定します。コマンドラインからの指定を上書きします。複数の追加ネットワークを使用する場合は`--am 0.8,0.5,0.3`のように __カンマ区切りで__ 指定します。
- `--ow`SDXLのoriginal_widthを指定します。
- `--oh`SDXLのoriginal_heightを指定します。
- `--nw`SDXLのoriginal_width_negativeを指定します。
- `--nh`SDXLのoriginal_height_negativeを指定します。
- `--ct`SDXLのcrop_topを指定します。
- `--cl`SDXLのcrop_leftを指定します。
- `--c`CLIPプロンプトを指定します。
- `--f`:生成ファイル名を指定します。
※これらのオプションを指定すると、バッチサイズよりも小さいサイズでバッチが実行される場合があります(これらの値が異なると一括生成できないため)。(あまり気にしなくて大丈夫ですが、ファイルからプロンプトを読み込み生成する場合は、これらの値が同一のプロンプトを並べておくと効率が良くなります。)
例:
@@ -225,6 +292,21 @@ python gen_img_diffusers.py --ckpt model.safetensors
![image](https://user-images.githubusercontent.com/52813779/235343446-25654172-fff4-4aaf-977a-20d262b51676.png)
# プロンプトのワイルドカード (Dynamic Prompts)
Dynamic Prompts (Wildcard) 記法に対応しています。Web UIの拡張機能等と完全に同じではありませんが、以下の機能が利用可能です。
- `{A|B|C}` : A, B, C の中からランダムに1つを選択します。
- `{e$$A|B|C}` : A, B, C のすべてを順に利用します(全列挙)。プロンプト内に複数の `{e$$...}` がある場合、すべての組み合わせが生成されます。
- 例:`{e$$red|blue} flower, {e$$1girl|2girls}``red flower, 1girl`, `red flower, 2girls`, `blue flower, 1girl`, `blue flower, 2girls` の4枚が生成されます。
- `{n$$A|B|C}` : A, B, C の中から n 個をランダムに選択して結合します。
- 例:`{2$$A|B|C}``A, B``B, C` など。
- `{n-m$$A|B|C}` : A, B, C の中から n 個から m 個をランダムに選択して結合します。
- `{$$sep$$A|B|C}` : 選択された項目を sep で結合します(デフォルトは `, `)。
- 例:`{2$$ and $$A|B|C}``A and B` など。
これらは組み合わせて利用可能です。
# img2img
## オプション
@@ -235,12 +317,14 @@ python gen_img_diffusers.py --ckpt model.safetensors
- `--sequential_file_name`:ファイル名を連番にするかどうかを指定します。指定すると生成されるファイル名が`im_000001.png`からの連番になります。
- `--use_original_file_name`:指定すると生成ファイル名がオリジナルのファイル名と同じになります
- `--use_original_file_name`:指定すると生成ファイル名がオリジナルのファイル名の前に追加されますimg2imgモード用
- `--clip_vision_strength`指定した強度でimg2img用のCLIP Vision Conditioningを有効にします。CLIP Visionモデルを使用して入力画像からのコンディショニングを強化します。
## コマンドラインからの実行例
```batchfile
python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
python gen_img.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
--outdir outputs --xformers --fp16 --scale 12.5 --sampler k_euler --steps 32
--image_path template.png --strength 0.8
--prompt "1girl, cowboy shot, brown hair, pony tail, brown eyes,
@@ -281,10 +365,6 @@ img2img時にコマンドラインオプションの`--W`と`--H`で生成画像
モデルとして、当リポジトリで学習したTextual Inversionモデル、およびWeb UIで学習したTextual Inversionモデル画像埋め込みは非対応を利用できます
## Extended Textual Inversion
`--textual_inversion_embeddings`の代わりに`--XTI_embeddings`オプションを指定してください。使用法は`--textual_inversion_embeddings`と同じです。
## Highres. fix
AUTOMATIC1111氏のWeb UIにある機能の類似機能です独自実装のためもしかしたらいろいろ異なるかもしれません。最初に小さめの画像を生成し、その画像を元にimg2imgすることで、画像全体の破綻を防ぎつつ大きな解像度の画像を生成します。
@@ -299,6 +379,8 @@ img2imgと併用できません。
- `--highres_fix_steps`1st stageの画像のステップ数を指定します。デフォルトは`28`です。
- `--highres_fix_strength`1st stageのimg2img時のstrengthを指定します。省略時は`--strength`と同じ値になります。
- `--highres_fix_save_1st`1st stageの画像を保存するかどうかを指定します。
- `--highres_fix_latents_upscaling`指定すると2nd stageの画像生成時に1st stageの画像をlatentベースでupscalingしますbilinearのみ対応。未指定時は画像をLANCZOS4でupscalingします。
@@ -306,12 +388,14 @@ img2imgと併用できません。
- `--highres_fix_upscaler`2nd stageに任意のupscalerを利用します。現在は`--highres_fix_upscaler tools.latent_upscaler` のみ対応しています。
- `--highres_fix_upscaler_args``--highres_fix_upscaler`で指定したupscalerに渡す引数を指定します。
`tools.latent_upscaler`の場合は、`--highres_fix_upscaler_args "weights=D:\Work\SD\Models\others\etc\upscaler-v1-e100-220.safetensors"`のように重みファイルを指定します。
`tools.latent_upscaler`の場合は、`--highres_fix_upscaler_args "weights=D:\Work\SD\Models\others\etc\upscaler-v1-e100-220.safetensors"`のように重みファイルを指定します。
- `--highres_fix_disable_control_net`Highres fixの2nd stageでControlNetを無効にします。デフォルトでは、ControlNetは両ステージで使用されます。
コマンドラインの例です。
```batchfile
python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
python gen_img.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
--n_iter 1 --scale 7.5 --W 1024 --H 1024 --batch_size 1 --outdir ../txt2img
--steps 48 --sampler ddim --fp16
--xformers
@@ -319,6 +403,34 @@ python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
--highres_fix_scale 0.5 --highres_fix_steps 28 --strength 0.5
```
## Deep Shrink
Deep Shrinkは、異なるタイムステップで異なる深度のUNetを使用して生成プロセスを最適化する技術です。生成品質と効率を向上させることができます。
以下のオプションがあります:
- `--ds_depth_1`第1フェーズでこの深度のDeep Shrinkを有効にします。有効な値は0から8です。
- `--ds_timesteps_1`このタイムステップまでDeep Shrink深度1を適用します。デフォルトは650です。
- `--ds_depth_2`Deep Shrinkの第2フェーズの深度を指定します。
- `--ds_timesteps_2`このタイムステップまでDeep Shrink深度2を適用します。デフォルトは650です。
- `--ds_ratio`Deep Shrinkでのダウンサンプリングの比率を指定します。デフォルトは0.5です。
これらのパラメータはプロンプトオプションでも指定できます:
- `--dsd1`プロンプトからDeep Shrink深度1を指定します。
- `--dst1`プロンプトからDeep Shrinkタイムステップ1を指定します。
- `--dsd2`プロンプトからDeep Shrink深度2を指定します。
- `--dst2`プロンプトからDeep Shrinkタイムステップ2を指定します。
- `--dsr`プロンプトからDeep Shrink比率を指定します。
## ControlNet
現在はControlNet 1.0のみ動作確認しています。プリプロセスはCannyのみサポートしています。
@@ -333,19 +445,33 @@ python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
- `--control_net_preps`ControlNetのプリプロセスを指定します。`--control_net_models`と同様に複数指定可能です。現在はcannyのみ対応しています。対象モデルでプリプロセスを使用しない場合は `none` を指定します。
cannyの場合 `--control_net_preps canny_63_191`のように、閾値1と2を'_'で区切って指定できます。
- `--control_net_weights`ControlNetの適用時の重みを指定します`1.0`で通常、`0.5`なら半分の影響力で適用)。`--control_net_models`と同様に複数指定可能です。
- `--control_net_multipliers`ControlNetの適用時の重みを指定します`1.0`で通常、`0.5`なら半分の影響力で適用)。`--control_net_models`と同様に複数指定可能です。
- `--control_net_ratios`ControlNetを適用するstepの範囲を指定します。`0.5`の場合は、step数の半分までControlNetを適用します。`--control_net_models`と同様に複数指定可能です。
コマンドラインの例です。
```batchfile
python gen_img_diffusers.py --ckpt model_ckpt --scale 8 --steps 48 --outdir txt2img --xformers
python gen_img.py --ckpt model_ckpt --scale 8 --steps 48 --outdir txt2img --xformers
--W 512 --H 768 --bf16 --sampler k_euler_a
--control_net_models diff_control_sd15_canny.safetensors --control_net_weights 1.0
--control_net_models diff_control_sd15_canny.safetensors --control_net_multipliers 1.0
--guide_image_path guide.png --control_net_ratios 1.0 --interactive
```
## ControlNet-LLLite
ControlNet-LLLiteは、類似の誘導目的に使用できるControlNetの軽量な代替手段です。
以下のオプションがあります:
- `--control_net_lllite_models`ControlNet-LLLiteモデルファイルを指定します。
- `--control_net_multipliers`ControlNet-LLLiteの倍率を指定します重みに類似
- `--control_net_ratios`ControlNet-LLLiteを適用するステップの比率を指定します。
注意ControlNetとControlNet-LLLiteは同時に使用できません。
## Attention Couple + Reginal LoRA
プロンプトをいくつかの部分に分割し、それぞれのプロンプトを画像内のどの領域に適用するかを指定できる機能です。個別のオプションはありませんが、`mask_path`とプロンプトで指定します。
@@ -370,70 +496,6 @@ ControlNetと組み合わせることも可能です細かい位置指定に
LoRAを指定すると、`--network_weights`で指定した複数のLoRAがそれぞれANDの各部分に対応します。現在の制約として、LoRAの数はANDの部分の数と同じである必要があります。
## CLIP Guided Stable Diffusion
DiffusersのCommunity Examplesの[こちらのcustom pipeline](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#clip-guided-stable-diffusion)からソースをコピー、変更したものです。
通常のプロンプトによる生成指定に加えて、追加でより大規模のCLIPでプロンプトのテキストの特徴量を取得し、生成中の画像の特徴量がそのテキストの特徴量に近づくよう、生成される画像をコントロールします私のざっくりとした理解です。大きめのCLIPを使いますのでVRAM使用量はかなり増加しVRAM 8GBでは512*512でも厳しいかもしれません、生成時間も掛かります。
なお選択できるサンプラーはDDIM、PNDM、LMSのみとなります。
`--clip_guidance_scale`オプションにどの程度、CLIPの特徴量を反映するかを数値で指定します。先のサンプルでは100になっていますので、そのあたりから始めて増減すると良いようです。
デフォルトではプロンプトの先頭75トークン重みづけの特殊文字を除くがCLIPに渡されます。プロンプトの`--c`オプションで、通常のプロンプトではなく、CLIPに渡すテキストを別に指定できますたとえばCLIPはDreamBoothのidentifier識別子や「1girl」などのモデル特有の単語は認識できないと思われますので、それらを省いたテキストが良いと思われます
コマンドラインの例です。
```batchfile
python gen_img_diffusers.py --ckpt v1-5-pruned-emaonly.ckpt --n_iter 1
--scale 2.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img --steps 36
--sampler ddim --fp16 --opt_channels_last --xformers --images_per_prompt 1
--interactive --clip_guidance_scale 100
```
## CLIP Image Guided Stable Diffusion
テキストではなくCLIPに別の画像を渡し、その特徴量に近づくよう生成をコントロールする機能です。`--clip_image_guidance_scale`オプションで適用量の数値を、`--guide_image_path`オプションでguideに使用する画像ファイルまたはフォルダを指定してください。
コマンドラインの例です。
```batchfile
python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
--n_iter 1 --scale 7.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img
--steps 80 --sampler ddim --fp16 --opt_channels_last --xformers
--images_per_prompt 1 --interactive --clip_image_guidance_scale 100
--guide_image_path YUKA160113420I9A4104_TP_V.jpg
```
### VGG16 Guided Stable Diffusion
指定した画像に近づくように画像生成する機能です。通常のプロンプトによる生成指定に加えて、追加でVGG16の特徴量を取得し、生成中の画像が指定したガイド画像に近づくよう、生成される画像をコントロールします。img2imgでの使用をお勧めします通常の生成では画像がぼやけた感じになります。CLIP Guided Stable Diffusionの仕組みを流用した独自の機能です。またアイデアはVGGを利用したスタイル変換から拝借しています。
なお選択できるサンプラーはDDIM、PNDM、LMSのみとなります。
`--vgg16_guidance_scale`オプションにどの程度、VGG16特徴量を反映するかを数値で指定します。試した感じでは100くらいから始めて増減すると良いようです。`--guide_image_path`オプションでguideに使用する画像ファイルまたはフォルダを指定してください。
複数枚の画像を一括でimg2img変換し、元画像をガイド画像とする場合、`--guide_image_path`と`--image_path`に同じ値を指定すればOKです。
コマンドラインの例です。
```batchfile
python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt
--n_iter 1 --scale 5.5 --steps 60 --outdir ../txt2img
--xformers --sampler ddim --fp16 --W 512 --H 704
--batch_size 1 --images_per_prompt 1
--prompt "picturesque, 1girl, solo, anime face, skirt, beautiful face
--n lowres, bad anatomy, bad hands, error, missing fingers,
cropped, worst quality, low quality, normal quality,
jpeg artifacts, blurry, 3d, bad face, monochrome --d 1"
--strength 0.8 --image_path ..\src_image
--vgg16_guidance_scale 100 --guide_image_path ..\src_image
```
`--vgg16_guidance_layerPで特徴量取得に使用するVGG16のレイヤー番号を指定できますデフォルトは20でconv4-2のReLUです。上の層ほど画風を表現し、下の層ほどコンテンツを表現するといわれています。
![image](https://user-images.githubusercontent.com/52813779/235343813-3c1f0d7a-4fb3-4274-98e4-b92d76b551df.png)
# その他のオプション
- `--no_preview` : 対話モードでプレビュー画像を表示しません。OpenCVがインストールされていない場合や、出力されたファイルを直接確認する場合に指定してください。
@@ -450,34 +512,22 @@ python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt
- `--opt_channels_last` : 推論時にテンソルのチャンネルを最後に配置します。場合によっては高速化されることがあります。
- `--network_show_meta` : 追加ネットワークのメタデータを表示します。
- `--shuffle_prompts`:繰り返し時にプロンプトの順序をシャッフルします。`--from_file`で複数のプロンプトを使用する場合に便利です。
- `--network_show_meta`:追加ネットワークのメタデータを表示します。
---
# About Gradual Latent
Gradual Latent is a Hires fix that gradually increases the size of the latent. `gen_img.py`, `sdxl_gen_img.py`, and `gen_img_diffusers.py` have the following options.
- `--gradual_latent_timesteps`: Specifies the timestep to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. Please try around 750 at first.
- `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which means it starts with half the default latent size.
- `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0.
- `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps.
Each option can also be specified with prompt options, `--glt`, `--glr`, `--gls`, `--gle`.
__Please specify `euler_a` for the sampler.__ Because the source code of the sampler is modified. It will not work with other samplers.
It is more effective with SD 1.5. It is quite subtle with SDXL.
# Gradual Latent について
latentのサイズを徐々に大きくしていくHires fixです。`gen_img.py` 、``sdxl_gen_img.py``gen_img_diffusers.py` に以下のオプションが追加されています。
latentのサイズを徐々に大きくしていくHires fixです。
- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。750 くらいから始めてみてください。
- `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。
- `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。
- `--gradual_latent_ratio_every_n_steps`: latentのサイズを大きくする間隔を指定します。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。
- `--gradual_latent_s_noise`Gradual LatentのS_noiseパラメータを指定します。デフォルトは1.0です。
- `--gradual_latent_unsharp_params`Gradual Latentのアンシャープマスクパラメータをksize,sigma,strength,target-x形式で指定しますtarget-x: 1=True, 0=False。推奨値`3,0.5,0.5,1`または`3,1.0,1.0,0`。
それぞれのオプションは、プロンプトオプション、`--glt`、`--glr`、`--gls`、`--gle` でも指定できます。

View File

@@ -4,31 +4,22 @@ This is an inference (image generation) script that supports SD 1.x and 2.x mode
# Overview
* Inference (image generation) script.
* Supports SD 1.x and 2.x (base/v-parameterization) models.
* Supports SD 1.x, 2.x (base/v-parameterization), and SDXL models.
* Supports txt2img, img2img, and inpainting.
* Supports interactive mode, prompt reading from files, and continuous generation.
* The number of images generated per prompt line can be specified.
* The total number of repetitions can be specified.
* Supports not only `fp16` but also `bf16`.
* Supports xformers for high-speed generation.
* Although xformers are used for memory-saving generation, it is not as optimized as Automatic 1111's Web UI, so it uses about 6GB of VRAM for 512*512 image generation.
* Supports xformers and SDPA (Scaled Dot-Product Attention).
* Extension of prompts to 225 tokens. Supports negative prompts and weighting.
* Supports various samplers from Diffusers (fewer samplers than Web UI).
* Supports various samplers from Diffusers.
* Supports clip skip (uses the output of the nth layer from the end) of Text Encoder.
* Separate loading of VAE.
* Supports CLIP Guided Stable Diffusion, VGG16 Guided Stable Diffusion, Highres. fix, and upscale.
* Highres. fix is an original implementation that has not confirmed the Web UI implementation at all, so the output results may differ.
* LoRA support. Supports application rate specification, simultaneous use of multiple LoRAs, and weight merging.
* It is not possible to specify different application rates for Text Encoder and U-Net.
* Supports Attention Couple.
* Supports ControlNet v1.0.
* Supports Deep Shrink for optimizing generation at different depths.
* Supports Gradual Latent for progressive upscaling during generation.
* Supports CLIP Vision Conditioning for img2img.
* Separate loading of VAE, supports VAE batch processing and slicing for memory saving.
* Highres. fix (original implementation and Gradual Latent), upscale support.
* LoRA, DyLoRA support. Supports application rate specification, simultaneous use of multiple LoRAs, and weight merging.
* Supports Attention Couple, Regional LoRA.
* Supports ControlNet (v1.0/v1.1), ControlNet-LLLite.
* It is not possible to switch models midway, but it can be handled by creating a batch file.
* Various personally desired features have been added.
Since not all tests are performed when adding features, it is possible that previous features may be affected and some features may not work. Please let us know if you have any problems.
# Basic Usage
@@ -100,14 +91,30 @@ Specify from the command line.
- `--ckpt <model_name>`: Specifies the model name. The `--ckpt` option is mandatory. You can specify a Stable Diffusion checkpoint file, a Diffusers model folder, or a Hugging Face model ID.
- `--v1`: Specify when using Stable Diffusion 1.x series models. This is the default behavior.
- `--v2`: Specify when using Stable Diffusion 2.x series models. Not required for 1.x series.
- `--sdxl`: Specify when using Stable Diffusion XL models.
- `--v_parameterization`: Specify when using models that use v-parameterization (`768-v-ema.ckpt` and models with additional training from it, Waifu Diffusion v1.5, etc.).
If the `--v2` specification is incorrect, an error will occur when loading the model. If the `--v_parameterization` specification is incorrect, a brown image will be displayed.
If the `--v2` or `--sdxl` specification is incorrect, an error will occur when loading the model. If the `--v_parameterization` specification is incorrect, a brown image will be displayed.
- `--zero_terminal_snr`: Modifies the noise scheduler betas to enforce zero terminal SNR.
- `--pyramid_noise_prob`: Specifies the probability of applying pyramid noise.
- `--pyramid_noise_discount_range`: Specifies the discount range for pyramid noise.
- `--noise_offset_prob`: Specifies the probability of applying noise offset.
- `--noise_offset_range`: Specifies the range of noise offset.
- `--vae`: Specifies the VAE to use. If not specified, the VAE in the model will be used.
- `--tokenizer_cache_dir`: Specifies the cache directory for the tokenizer (for offline usage).
## Image Generation and Output
- `--interactive`: Operates in interactive mode. Images are generated when prompts are entered.
@@ -118,6 +125,8 @@ Specify from the command line.
- `--from_module <module_file>`: Loads prompts from a Python module. The module should implement a `get_prompter(args, pipe, networks)` function.
- `--prompter_module_args`: Specifies additional arguments to pass to the prompter module.
- `--W <image_width>`: Specifies the width of the image. The default is `512`.
- `--H <image_height>`: Specifies the height of the image. The default is `512`.
@@ -126,13 +135,14 @@ Specify from the command line.
- `--scale <guidance_scale>`: Specifies the unconditional guidance scale. The default is `7.5`.
- `--sampler <sampler_name>`: Specifies the sampler. The default is `ddim`. ddim, pndm, dpmsolver, dpmsolver+++, lms, euler, euler_a provided by Diffusers can be specified (the last three can also be specified as k_lms, k_euler, k_euler_a).
- `--sampler <sampler_name>`: Specifies the sampler. The default is `ddim`.
`ddim`, `pndm`, `lms`, `euler`, `euler_a`, `heun`, `dpm_2`, `dpm_2_a`, `dpmsolver`, `dpmsolver++`, `dpmsingle`, `k_lms`, `k_euler`, `k_euler_a`, `k_dpm_2`, `k_dpm_2_a` can be specified.
- `--outdir <image_output_destination_folder>`: Specifies the output destination for images.
- `--images_per_prompt <number_of_images_to_generate>`: Specifies the number of images to generate per prompt. The default is `1`.
- `--clip_skip <number_of_skips>`: Specifies which layer from the end of CLIP to use. If omitted, the last layer is used.
- `--clip_skip <number_of_skips>`: Specifies which layer from the end of CLIP to use. Default is 1 for SD1/2, 2 for SDXL.
- `--max_embeddings_multiples <multiplier>`: Specifies how many times the CLIP input/output length should be multiplied by the default (75). If not specified, it remains 75. For example, specifying 3 makes the input/output length 225.
@@ -140,6 +150,24 @@ Specify from the command line.
- `--emb_normalize_mode`: Specifies the embedding normalization mode. Options are "original" (default), "abs", and "none". This affects how prompt weights are normalized.
- `--force_scheduler_zero_steps_offset`: Forces the scheduler step offset to zero regardless of the `steps_offset` value in the scheduler configuration.
## SDXL-Specific Options
When using SDXL models (with `--sdxl` flag), additional conditioning options are available:
- `--original_height`: Specifies the original height for SDXL conditioning. This affects the model's understanding of the target resolution.
- `--original_width`: Specifies the original width for SDXL conditioning. This affects the model's understanding of the target resolution.
- `--original_height_negative`: Specifies the original height for SDXL negative conditioning.
- `--original_width_negative`: Specifies the original width for SDXL negative conditioning.
- `--crop_top`: Specifies the crop top offset for SDXL conditioning.
- `--crop_left`: Specifies the crop left offset for SDXL conditioning.
## Adjusting Memory Usage and Generation Speed
- `--batch_size <batch_size>`: Specifies the batch size. The default is `1`. A larger batch size consumes more memory but speeds up generation.
@@ -149,12 +177,14 @@ Specify from the command line.
- `--vae_slices <number_of_slices>`: Splits the image into slices for VAE processing to reduce VRAM usage. None (default) for no splitting. Values like 16 or 32 are recommended. Enabling this is slower but uses less VRAM.
- `--no_half_vae`: Prevents using fp16/bf16 precision for VAE processing. Uses fp32 instead.
- `--no_half_vae`: Prevents using fp16/bf16 precision for VAE processing. Uses fp32 instead. Use this if you encounter VAE-related issues or artifacts.
- `--xformers`: Specify when using xformers.
- `--sdpa`: Use scaled dot-product attention in PyTorch 2 for optimization.
- `--diffusers_xformers`: Use xformers via Diffusers (note: incompatible with Hypernetworks).
- `--fp16`: Performs inference in fp16 (single precision). If neither `fp16` nor `bf16` is specified, inference is performed in fp32 (single precision).
- `--bf16`: Performs inference in bf16 (bfloat16). Can only be specified for RTX 30 series GPUs. The `--bf16` option will cause an error on GPUs other than the RTX 30 series. It seems that `bf16` is less likely to result in NaN (black image) inference results than `fp16`.
@@ -173,6 +203,10 @@ Specify from the command line.
- `--network_regional_mask_max_color_codes`: Specifies the maximum number of color codes to use for regional masks. If not specified, masks are applied by channel. Used with Regional LoRA to control the number of regions that can be defined by colors in the mask.
- `--network_args`: Specifies additional arguments to pass to the network module in key=value format. For example: `--network_args "alpha=1.0,dropout=0.1"`.
- `--network_merge_n_models`: When using network merging, specifies the number of models to merge (instead of merging all loaded networks).
# Examples of Main Option Specifications
The following is an example of batch generating 64 images with the same prompt and a batch size of 4.
@@ -232,6 +266,22 @@ Please put spaces before and after the prompt option specification `--n`.
- `--am`: Specifies the weight of the additional network. Overrides the command line specification. If using multiple additional networks, specify them separated by __commas__, like `--am 0.8,0.5,0.3`.
- `--ow`: Specifies original_width for SDXL.
- `--oh`: Specifies original_height for SDXL.
- `--nw`: Specifies original_width_negative for SDXL.
- `--nh`: Specifies original_height_negative for SDXL.
- `--ct`: Specifies crop_top for SDXL.
- `--cl`: Specifies crop_left for SDXL.
- `--c`: Specifies the CLIP prompt.
- `--f`: Specifies the generated file name.
- `--glt`: Specifies the timestep to start increasing the size of the latent for Gradual Latent. Overrides the command line specification.
- `--glr`: Specifies the initial size of the latent for Gradual Latent as a ratio. Overrides the command line specification.
@@ -249,6 +299,21 @@ Example:
![image](https://user-images.githubusercontent.com/52813779/235343446-25654172-fff4-4aaf-977a-20d262b51676.png)
# Wildcards in Prompts (Dynamic Prompts)
Dynamic Prompts (Wildcard) notation is supported. While not exactly the same as the Web UI extension, the following features are available.
- `{A|B|C}` : Randomly selects one from A, B, or C.
- `{e$$A|B|C}` : Uses all of A, B, and C in order (enumeration). If there are multiple `{e$$...}` in the prompt, all combinations will be generated.
- Example: `{e$$red|blue} flower, {e$$1girl|2girls}` -> Generates 4 images: `red flower, 1girl`, `red flower, 2girls`, `blue flower, 1girl`, `blue flower, 2girls`.
- `{n$$A|B|C}` : Randomly selects n items from A, B, C and combines them.
- Example: `{2$$A|B|C}` -> `A, B` or `B, C`, etc.
- `{n-m$$A|B|C}` : Randomly selects between n and m items from A, B, C and combines them.
- `{$$sep$$A|B|C}` : Combines selected items with `sep` (default is `, `).
- Example: `{2$$ and $$A|B|C}` -> `A and B`, etc.
These can be used in combination.
# img2img
## Options
@@ -259,7 +324,7 @@ Example:
- `--sequential_file_name`: Specifies whether to make file names sequential. If specified, the generated file names will be sequential starting from `im_000001.png`.
- `--use_original_file_name`: If specified, the generated file name will be the same as the original file name.
- `--use_original_file_name`: If specified, the generated file name will be prepended with the original file name (for img2img mode).
- `--clip_vision_strength`: Enables CLIP Vision Conditioning for img2img with the specified strength. Uses the CLIP Vision model to enhance conditioning from the input image.
@@ -307,10 +372,6 @@ Specify the embeddings to use with the `--textual_inversion_embeddings` option (
As models, you can use Textual Inversion models trained with this repository and Textual Inversion models trained with Web UI (image embedding is not supported).
## Extended Textual Inversion
Specify the `--XTI_embeddings` option instead of `--textual_inversion_embeddings`. Usage is the same as `--textual_inversion_embeddings`.
## Highres. fix
This is a similar feature to the one in AUTOMATIC1111's Web UI (it may differ in various ways as it is an original implementation). It first generates a smaller image and then uses that image as a base for img2img to generate a large resolution image while preventing the entire image from collapsing.
@@ -375,6 +436,16 @@ These parameters can also be specified through prompt options:
- `--dsr`: Specifies Deep Shrink ratio from the prompt.
*Additional prompt options for Gradual Latent (requires `euler_a` sampler):*
- `--glt`: Specifies the timestep to start increasing the size of the latent for Gradual Latent. Overrides the command line specification.
- `--glr`: Specifies the initial size of the latent for Gradual Latent as a ratio. Overrides the command line specification.
- `--gls`: Specifies the ratio to increase the size of the latent for Gradual Latent. Overrides the command line specification.
- `--gle`: Specifies the interval to increase the size of the latent for Gradual Latent. Overrides the command line specification.
## ControlNet
Currently, only ControlNet 1.0 has been confirmed to work. Only Canny is supported for preprocessing.
@@ -440,70 +511,6 @@ It can also be combined with ControlNet (combination with ControlNet is recommen
If LoRA is specified, multiple LoRAs specified with `--network_weights` will correspond to each part of AND. As a current constraint, the number of LoRAs must be the same as the number of AND parts.
## CLIP Guided Stable Diffusion
The source code is copied and modified from [this custom pipeline](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#clip-guided-stable-diffusion) in Diffusers' Community Examples.
In addition to the normal prompt-based generation specification, it additionally acquires the text features of the prompt with a larger CLIP and controls the generated image so that the features of the image being generated approach those text features (this is my rough understanding). Since a larger CLIP is used, VRAM usage increases considerably (it may be difficult even for 512*512 with 8GB of VRAM), and generation time also increases.
Note that the selectable samplers are DDIM, PNDM, and LMS only.
Specify how much to reflect the CLIP features numerically with the `--clip_guidance_scale` option. In the previous sample, it is 100, so it seems good to start around there and increase or decrease it.
By default, the first 75 tokens of the prompt (excluding special weighting characters) are passed to CLIP. With the `--c` option in the prompt, you can specify the text to be passed to CLIP separately from the normal prompt (for example, it is thought that CLIP cannot recognize DreamBooth identifiers or model-specific words like "1girl", so text excluding them is considered good).
Command line example:
```batchfile
python gen_img.py --ckpt v1-5-pruned-emaonly.ckpt --n_iter 1 \
--scale 2.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img --steps 36 \
--sampler ddim --fp16 --opt_channels_last --xformers --images_per_prompt 1 \
--interactive --clip_guidance_scale 100
```
## CLIP Image Guided Stable Diffusion
This is a feature that passes another image to CLIP instead of text and controls generation to approach its features. Specify the numerical value of the application amount with the `--clip_image_guidance_scale` option and the image (file or folder) to use for guidance with the `--guide_image_path` option.
Command line example:
```batchfile
python gen_img.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt\
--n_iter 1 --scale 7.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img \
--steps 80 --sampler ddim --fp16 --opt_channels_last --xformers \
--images_per_prompt 1 --interactive --clip_image_guidance_scale 100 \
--guide_image_path YUKA160113420I9A4104_TP_V.jpg
```
### VGG16 Guided Stable Diffusion
This is a feature that generates images to approach a specified image. In addition to the normal prompt-based generation specification, it additionally acquires the features of VGG16 and controls the generated image so that the image being generated approaches the specified guide image. It is recommended to use it with img2img (images tend to be blurred in normal generation). This is an original feature that reuses the mechanism of CLIP Guided Stable Diffusion. The idea is also borrowed from style transfer using VGG.
Note that the selectable samplers are DDIM, PNDM, and LMS only.
Specify how much to reflect the VGG16 features numerically with the `--vgg16_guidance_scale` option. From what I've tried, it seems good to start around 100 and increase or decrease it. Specify the image (file or folder) to use for guidance with the `--guide_image_path` option.
When batch converting multiple images with img2img and using the original images as guide images, it is OK to specify the same value for `--guide_image_path` and `--image_path`.
Command line example:
```batchfile
python gen_img.py --ckpt wd-v1-3-full-pruned-half.ckpt \
--n_iter 1 --scale 5.5 --steps 60 --outdir ../txt2img \
--xformers --sampler ddim --fp16 --W 512 --H 704 \
--batch_size 1 --images_per_prompt 1 \
--prompt "picturesque, 1girl, solo, anime face, skirt, beautiful face \
--n lowres, bad anatomy, bad hands, error, missing fingers, \
cropped, worst quality, low quality, normal quality, \
jpeg artifacts, blurry, 3d, bad face, monochrome --d 1" \
--strength 0.8 --image_path ..\\src_image\
--vgg16_guidance_scale 100 --guide_image_path ..\\src_image \
```
You can specify the VGG16 layer number used for feature acquisition with `--vgg16_guidance_layerP` (default is 20, which is ReLU of conv4-2). It is said that upper layers express style and lower layers express content.
![image](https://user-images.githubusercontent.com/52813779/235343813-3c1f0d7a-4fb3-4274-98e4-b92d76b551df.png)
# Other Options
- `--no_preview`: Does not display preview images in interactive mode. Specify this if OpenCV is not installed or if you want to check the output files directly.
@@ -536,25 +543,10 @@ Gradual Latent is a Hires fix that gradually increases the size of the latent.
- `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0.
- `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps.
- `--gradual_latent_s_noise`: Specifies the s_noise parameter for Gradual Latent. Default is 1.0.
- `--gradual_latent_unsharp_params`: Specifies unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). Values like `3,0.5,0.5,1` or `3,1.0,1.0,0` are recommended.
- `--gradual_latent_unsharp_params`: Specifies unsharp mask parameters for Gradual Latent in the format: ksize,sigma,strength,target-x (target-x: 1=True, 0=False). Recommended values: `3,0.5,0.5,1` or `3,1.0,1.0,0`.
Each option can also be specified with prompt options, `--glt`, `--glr`, `--gls`, `--gle`.
__Please specify `euler_a` for the sampler.__ Because the source code of the sampler is modified. It will not work with other samplers.
It is more effective with SD 1.5. It is quite subtle with SDXL.
# Gradual Latent について (Japanese section - kept for reference)
latentのサイズを徐々に大きくしていくHires fixです。`gen_img.py` 、``sdxl_gen_img.py``gen_img.py` に以下のオプションが追加されています。
- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。750 くらいから始めてみてください。
- `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。
- `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。
- `--gradual_latent_ratio_every_n_steps`: latentのサイズを大きくする間隔を指定します。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。
それぞれのオプションは、プロンプトオプション、`--glt``--glr``--gls``--gle` でも指定できます。
サンプラーに手を加えているため、__サンプラーに `euler_a` を指定してください。__ 他のサンプラーでは動作しません。
SD 1.5 のほうが効果があります。SDXL ではかなり微妙です。

View File

@@ -0,0 +1,525 @@
Status: reviewed
# LoRA Training Guide for HunyuanImage-2.1 using `hunyuan_image_train_network.py` / `hunyuan_image_train_network.py` を用いたHunyuanImage-2.1モデルのLoRA学習ガイド
This document explains how to train LoRA models for the HunyuanImage-2.1 model using `hunyuan_image_train_network.py` included in the `sd-scripts` repository.
<details>
<summary>日本語</summary>
このドキュメントでは、`sd-scripts`リポジトリに含まれる`hunyuan_image_train_network.py`を使用して、HunyuanImage-2.1モデルに対するLoRA (Low-Rank Adaptation) モデルを学習する基本的な手順について解説します。
</details>
## 1. Introduction / はじめに
`hunyuan_image_train_network.py` trains additional networks such as LoRA on the HunyuanImage-2.1 model, which uses a transformer-based architecture (DiT) different from Stable Diffusion. Two text encoders, Qwen2.5-VL and byT5, and a dedicated VAE are used.
This guide assumes you know the basics of LoRA training. For common options see [train_network.py](train_network.md) and [sdxl_train_network.py](sdxl_train_network.md).
**Prerequisites:**
* The repository is cloned and the Python environment is ready.
* A training dataset is prepared. See the dataset configuration guide.
<details>
<summary>日本語</summary>
`hunyuan_image_train_network.py`はHunyuanImage-2.1モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。HunyuanImage-2.1はStable Diffusionとは異なるDiT (Diffusion Transformer) アーキテクチャを持つ画像生成モデルであり、このスクリプトを使用することで、特定のキャラクターや画風を再現するLoRAモデルを作成できます。
このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sdxl_train_network.py`](sdxl_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。
**前提条件:**
* `sd-scripts`リポジトリのクローンとPython環境のセットアップが完了していること。
* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](config_README-ja.md)を参照してください)
</details>
## 2. Differences from `train_network.py` / `train_network.py` との違い
`hunyuan_image_train_network.py` is based on `train_network.py` but adapted for HunyuanImage-2.1. Main differences include:
* **Target model:** HunyuanImage-2.1 model.
* **Model structure:** HunyuanImage-2.1 uses a Transformer-based architecture (DiT). It uses two text encoders (Qwen2.5-VL and byT5) and a dedicated VAE.
* **Required arguments:** Additional arguments for the DiT model, Qwen2.5-VL, byT5, and VAE model files.
* **Incompatible options:** Some Stable Diffusion-specific arguments (e.g., `--v2`, `--clip_skip`, `--max_token_length`) are not used.
* **HunyuanImage-2.1-specific arguments:** Additional arguments for specific training parameters like flow matching.
<details>
<summary>日本語</summary>
`hunyuan_image_train_network.py``train_network.py`をベースに、HunyuanImage-2.1モデルに対応するための変更が加えられています。主な違いは以下の通りです。
* **対象モデル:** HunyuanImage-2.1モデルを対象とします。
* **モデル構造:** HunyuanImage-2.1はDiTベースのアーキテクチャを持ちます。Text EncoderとしてQwen2.5-VLとbyT5の二つを使用し、専用のVAEを使用します。
* **必須の引数:** DiTモデル、Qwen2.5-VL、byT5、VAEの各モデルファイルを指定する引数が追加されています。
* **一部引数の非互換性:** Stable Diffusion向けの引数の一部例: `--v2`, `--clip_skip`, `--max_token_length`)は使用されません。
* **HunyuanImage-2.1特有の引数:** Flow Matchingなど、特有の学習パラメータを指定する引数が追加されています。
</details>
## 3. Preparation / 準備
Before starting training you need:
1. **Training script:** `hunyuan_image_train_network.py`
2. **HunyuanImage-2.1 DiT model file:** Base DiT model `.safetensors` file.
3. **Text Encoder model files:**
- Qwen2.5-VL model file (`--text_encoder`).
- byT5 model file (`--byt5`).
4. **VAE model file:** HunyuanImage-2.1-compatible VAE model `.safetensors` file (`--vae`).
5. **Dataset definition file (.toml):** TOML format file describing training dataset configuration.
### Downloading Required Models
To train HunyuanImage-2.1 models, you need to download the following model files:
- **DiT Model**: Download from the [Tencent HunyuanImage-2.1](https://huggingface.co/tencent/HunyuanImage-2.1/) repository. Use `dit/hunyuanimage2.1.safetensors`.
- **Text Encoders and VAE**: Download from the [Comfy-Org/HunyuanImage_2.1_ComfyUI](https://huggingface.co/Comfy-Org/HunyuanImage_2.1_ComfyUI) repository:
- Qwen2.5-VL: `split_files/text_encoders/qwen_2.5_vl_7b.safetensors`
- byT5: `split_files/text_encoders/byt5_small_glyphxl_fp16.safetensors`
- VAE: `split_files/vae/hunyuan_image_2.1_vae_fp16.safetensors`
<details>
<summary>日本語</summary>
学習を開始する前に、以下のファイルが必要です。
1. **学習スクリプト:** `hunyuan_image_train_network.py`
2. **HunyuanImage-2.1 DiTモデルファイル:** 学習のベースとなるDiTモデルの`.safetensors`ファイル。
3. **Text Encoderモデルファイル:**
- Qwen2.5-VLモデルファイル (`--text_encoder`)。
- byT5モデルファイル (`--byt5`)。
4. **VAEモデルファイル:** HunyuanImage-2.1に対応するVAEモデルの`.safetensors`ファイル (`--vae`)。
5. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。詳細は[データセット設定ガイド](config_README-ja.md)を参照してください)。
**必要なモデルのダウンロード**
HunyuanImage-2.1モデルを学習するためには、以下のモデルファイルをダウンロードする必要があります:
- **DiTモデル**: [Tencent HunyuanImage-2.1](https://huggingface.co/tencent/HunyuanImage-2.1/) リポジトリから `dit/hunyuanimage2.1.safetensors` をダウンロードします。
- **Text EncoderとVAE**: [Comfy-Org/HunyuanImage_2.1_ComfyUI](https://huggingface.co/Comfy-Org/HunyuanImage_2.1_ComfyUI) リポジトリから以下をダウンロードします:
- Qwen2.5-VL: `split_files/text_encoders/qwen_2.5_vl_7b.safetensors`
- byT5: `split_files/text_encoders/byt5_small_glyphxl_fp16.safetensors`
- VAE: `split_files/vae/hunyuan_image_2.1_vae_fp16.safetensors`
</details>
## 4. Running the Training / 学習の実行
Run `hunyuan_image_train_network.py` from the terminal with HunyuanImage-2.1 specific arguments. Here's a basic command example:
```bash
accelerate launch --num_cpu_threads_per_process 1 hunyuan_image_train_network.py \
--pretrained_model_name_or_path="<path to HunyuanDiT model>" \
--text_encoder="<path to Qwen2.5-VL model>" \
--byt5="<path to byT5 model>" \
--vae="<path to VAE model>" \
--dataset_config="my_hunyuan_dataset_config.toml" \
--output_dir="<output directory>" \
--output_name="my_hunyuan_lora" \
--save_model_as=safetensors \
--network_module=networks.lora_hunyuan_image \
--network_dim=16 \
--network_alpha=1 \
--network_train_unet_only \
--learning_rate=1e-4 \
--optimizer_type="AdamW8bit" \
--lr_scheduler="constant" \
--attn_mode="torch" \
--split_attn \
--max_train_epochs=10 \
--save_every_n_epochs=1 \
--mixed_precision="bf16" \
--gradient_checkpointing \
--model_prediction_type="raw" \
--discrete_flow_shift=5.0 \
--blocks_to_swap=18 \
--cache_text_encoder_outputs \
--cache_latents
```
**HunyuanImage-2.1 training does not support LoRA modules for Text Encoders, so `--network_train_unet_only` is required.**
<details>
<summary>日本語</summary>
学習は、ターミナルから`hunyuan_image_train_network.py`を実行することで開始します。基本的なコマンドラインの構造は`train_network.py`と同様ですが、HunyuanImage-2.1特有の引数を指定する必要があります。
コマンドラインの例は英語のドキュメントを参照してください。
</details>
### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説
The script adds HunyuanImage-2.1 specific arguments. For common arguments (like `--output_dir`, `--output_name`, `--network_module`, etc.), see the [`train_network.py` guide](train_network.md).
#### Model-related [Required]
* `--pretrained_model_name_or_path="<path to HunyuanDiT model>"` **[Required]**
- Specifies the path to the base DiT model `.safetensors` file.
* `--text_encoder="<path to Qwen2.5-VL model>"` **[Required]**
- Specifies the path to the Qwen2.5-VL Text Encoder model file. Should be `bfloat16`.
* `--byt5="<path to byT5 model>"` **[Required]**
- Specifies the path to the byT5 Text Encoder model file. Should be `float16`.
* `--vae="<path to VAE model>"` **[Required]**
- Specifies the path to the HunyuanImage-2.1-compatible VAE model `.safetensors` file.
#### HunyuanImage-2.1 Training Parameters
* `--network_train_unet_only` **[Required]**
- Specifies that only the DiT model will be trained. LoRA modules for Text Encoders are not supported.
* `--discrete_flow_shift=<float>`
- Specifies the shift value for the scheduler used in Flow Matching. Default is `5.0`.
* `--model_prediction_type=<choice>`
- Specifies what the model predicts. Choose from `raw`, `additive`, `sigma_scaled`. Default and recommended is `raw`.
* `--timestep_sampling=<choice>`
- Specifies the sampling method for timesteps (noise levels) during training. Choose from `sigma`, `uniform`, `sigmoid`, `shift`, `flux_shift`. Default is `sigma`.
* `--sigmoid_scale=<float>`
- Scale factor when `timestep_sampling` is set to `sigmoid`, `shift`, or `flux_shift`. Default is `1.0`.
#### Memory/Speed Related
* `--attn_mode=<choice>`
- Specifies the attention implementation to use. Options are `torch`, `xformers`, `flash`, `sageattn`. Default is `torch` (use scaled dot product attention). Each library must be installed separately other than `torch`. If using `xformers`, also specify `--split_attn` if the batch size is more than 1.
* `--split_attn`
- Splits the batch during attention computation to process one item at a time, reducing VRAM usage by avoiding attention mask computation. Can improve speed when using `torch`. Required when using `xformers` with batch size greater than 1.
* `--fp8_scaled`
- Enables training the DiT model in scaled FP8 format. This can significantly reduce VRAM usage (can run with as little as 8GB VRAM when combined with `--blocks_to_swap`), but the training results may vary. This is a newer alternative to the unsupported `--fp8_base` option. See [Musubi Tuner's documentation](https://github.com/kohya-ss/musubi-tuner/blob/main/docs/advanced_config.md#fp8-weight-optimization-for-models--%E3%83%A2%E3%83%87%E3%83%AB%E3%81%AE%E9%87%8D%E3%81%BF%E3%81%AEfp8%E3%81%B8%E3%81%AE%E6%9C%80%E9%81%A9%E5%8C%96) for details.
* `--fp8_vl`
- Use FP8 for the VLM (Qwen2.5-VL) text encoder.
* `--text_encoder_cpu`
- Runs the text encoders on CPU to reduce VRAM usage. This is useful when VRAM is insufficient (less than 12GB). Encoding one text may take a few minutes (depending on CPU). It is highly recommended to use this option with `--cache_text_encoder_outputs_to_disk` to avoid repeated encoding every time training starts. **In addition, increasing `--num_cpu_threads_per_process` in the `accelerate launch` command, like `--num_cpu_threads_per_process=8` or `16`, can speed up encoding in some environments.**
* `--blocks_to_swap=<integer>` **[Experimental Feature]**
- Setting to reduce VRAM usage by swapping parts of the model (Transformer blocks) between CPU and GPU. Specify the number of blocks to swap as an integer (e.g., `18`). Larger values reduce VRAM usage but decrease training speed. Adjust according to your GPU's VRAM capacity. Can be used with `gradient_checkpointing`.
* `--cache_text_encoder_outputs`
- Caches the outputs of Qwen2.5-VL and byT5. This reduces memory usage.
* `--cache_latents`, `--cache_latents_to_disk`
- Caches the outputs of VAE. Similar functionality to [sdxl_train_network.py](sdxl_train_network.md).
* `--vae_chunk_size=<integer>`
- Enables chunked processing in the VAE to reduce VRAM usage during encoding and decoding. Specify the chunk size as an integer (e.g., `16`). Larger values use more VRAM but are faster. Default is `None` (no chunking). This option is useful when VRAM is limited (e.g., 8GB or 12GB).
<details>
<summary>日本語</summary>
[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のHunyuanImage-2.1特有の引数を指定します。共通の引数(`--output_dir`, `--output_name`, `--network_module`, `--network_dim`, `--network_alpha`, `--learning_rate`など)については、上記ガイドを参照してください。
コマンドラインの例と詳細な引数の説明は英語のドキュメントを参照してください。
</details>
## 5. Using the Trained Model / 学習済みモデルの利用
After training, a LoRA model file is saved in `output_dir` and can be used in inference environments supporting HunyuanImage-2.1.
<details>
<summary>日本語</summary>
学習が完了すると、指定した`output_dir`にLoRAモデルファイル例: `my_hunyuan_lora.safetensors`が保存されます。このファイルは、HunyuanImage-2.1モデルに対応した推論環境で使用できます。
</details>
## 6. Advanced Settings / 高度な設定
### 6.1. VRAM Usage Optimization / VRAM使用量の最適化
HunyuanImage-2.1 is a large model, so GPUs without sufficient VRAM require optimization.
#### Recommended Settings by GPU Memory
Based on testing with the pull request, here are recommended VRAM optimization settings:
| GPU Memory | Recommended Settings |
|------------|---------------------|
| 40GB+ VRAM | Standard settings (no special optimization needed) |
| 24GB VRAM | `--fp8_scaled --blocks_to_swap 9` |
| 12GB VRAM | `--fp8_scaled --blocks_to_swap 32` |
| 8GB VRAM | `--fp8_scaled --blocks_to_swap 37` |
#### Key VRAM Reduction Options
- **`--fp8_scaled`**: Enables training the DiT in scaled FP8 format. This is the recommended FP8 option for HunyuanImage-2.1, replacing the unsupported `--fp8_base` option. Essential for <40GB VRAM environments.
- **`--fp8_vl`**: Use FP8 for the VLM (Qwen2.5-VL) text encoder.
- **`--blocks_to_swap <number>`**: Swaps blocks between CPU and GPU to reduce VRAM usage. Higher numbers save more VRAM but reduce training speed. Up to 37 blocks can be swapped for HunyuanImage-2.1.
- **`--cpu_offload_checkpointing`**: Offloads gradient checkpoints to CPU. Can reduce VRAM usage but decreases training speed. Cannot be used with `--blocks_to_swap`.
- **Using Adafactor optimizer**: Can reduce VRAM usage more than 8bit AdamW:
```
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0
```
<details>
<summary>日本語</summary>
HunyuanImage-2.1は大きなモデルであるため、十分なVRAMを持たないGPUでは工夫が必要です。
#### GPU別推奨設定
Pull Requestのテスト結果に基づく推奨VRAM最適化設定
| GPU Memory | 推奨設定 |
|------------|---------|
| 40GB+ VRAM | 標準設定(特別な最適化不要) |
| 24GB VRAM | `--fp8_scaled --blocks_to_swap 9` |
| 12GB VRAM | `--fp8_scaled --blocks_to_swap 32` |
| 8GB VRAM | `--fp8_scaled --blocks_to_swap 37` |
主要なVRAM削減オプション
- `--fp8_scaled`: DiTをスケールされたFP8形式で学習推奨されるFP8オプション、40GB VRAM未満の環境では必須
- `--fp8_vl`: VLMテキストエンコーダにFP8を使用
- `--blocks_to_swap`: CPUとGPU間でブロックをスワップ最大37ブロック
- `--cpu_offload_checkpointing`: 勾配チェックポイントをCPUにオフロード
- Adafactorオプティマイザの使用
</details>
### 6.2. Important HunyuanImage-2.1 LoRA Training Settings / HunyuanImage-2.1 LoRA学習の重要な設定
HunyuanImage-2.1 training has several settings that can be specified with arguments:
#### Timestep Sampling Methods
The `--timestep_sampling` option specifies how timesteps (0-1) are sampled:
- `sigma`: Sigma-based like SD3 (Default)
- `uniform`: Uniform random
- `sigmoid`: Sigmoid of normal distribution random
- `shift`: Sigmoid value of normal distribution random with shift.
- `flux_shift`: Shift sigmoid value of normal distribution random according to resolution.
#### Model Prediction Processing
The `--model_prediction_type` option specifies how to interpret and process model predictions:
- `raw`: Use as-is **[Recommended, Default]**
- `additive`: Add to noise input
- `sigma_scaled`: Apply sigma scaling
#### Recommended Settings
Based on experiments, the default settings work well:
```
--model_prediction_type raw --discrete_flow_shift 5.0
```
<details>
<summary>日本語</summary>
HunyuanImage-2.1の学習には、引数で指定できるいくつかの設定があります。詳細な説明とコマンドラインの例は英語のドキュメントを参照してください。
主要な設定オプション:
- タイムステップのサンプリング方法(`--timestep_sampling`
- モデル予測の処理方法(`--model_prediction_type`
- 推奨設定の組み合わせ
</details>
### 6.3. Regular Expression-based Rank/LR Configuration / 正規表現によるランク・学習率の指定
You can specify ranks (dims) and learning rates for LoRA modules using regular expressions. This allows for more flexible and fine-grained control.
These settings are specified via the `network_args` argument.
* `network_reg_dims`: Specify ranks for modules matching a regular expression. The format is a comma-separated string of `pattern=rank`.
* Example: `--network_args "network_reg_dims=attn.*.q_proj=4,attn.*.k_proj=4"`
* `network_reg_lrs`: Specify learning rates for modules matching a regular expression. The format is a comma-separated string of `pattern=lr`.
* Example: `--network_args "network_reg_lrs=down_blocks.1=1e-4,up_blocks.2=2e-4"`
**Notes:**
* To find the correct module names for the patterns, you may need to inspect the model structure.
* Settings via `network_reg_dims` and `network_reg_lrs` take precedence over the global `--network_dim` and `--learning_rate` settings.
* If a module name matches multiple patterns, the setting from the last matching pattern in the string will be applied.
<details>
<summary>日本語</summary>
正規表現を用いて、LoRAのモジュールごとにランクdimや学習率を指定することができます。これにより、柔軟できめ細やかな制御が可能になります。
これらの設定は `network_args` 引数で指定します。
* `network_reg_dims`: 正規表現にマッチするモジュールに対してランクを指定します。
* `network_reg_lrs`: 正規表現にマッチするモジュールに対して学習率を指定します。
**注意点:**
* パターンのための正確なモジュール名を見つけるには、モデルの構造を調べる必要があるかもしれません。
* `network_reg_dims` および `network_reg_lrs` での設定は、全体設定である `--network_dim` や `--learning_rate` よりも優先されます。
* あるモジュール名が複数のパターンにマッチした場合、文字列の中で後方にあるパターンの設定が適用されます。
</details>
### 6.4. Multi-Resolution Training / マルチ解像度トレーニング
You can define multiple resolutions in the dataset configuration file, with different batch sizes for each resolution.
**Note:** This feature is available, but it is **not recommended** as the HunyuanImage-2.1 base model was not trained with multi-resolution capabilities. Using it may lead to unexpected results.
Configuration file example:
```toml
[general]
shuffle_caption = true
caption_extension = ".txt"
[[datasets]]
batch_size = 2
enable_bucket = true
resolution = [1024, 1024]
[[datasets.subsets]]
image_dir = "path/to/image/directory"
num_repeats = 1
[[datasets]]
batch_size = 1
enable_bucket = true
resolution = [1280, 768]
[[datasets.subsets]]
image_dir = "path/to/another/directory"
num_repeats = 1
```
<details>
<summary>日本語</summary>
データセット設定ファイルで複数の解像度を定義できます。各解像度に対して異なるバッチサイズを指定することができます。
**注意:** この機能は利用可能ですが、HunyuanImage-2.1のベースモデルはマルチ解像度で学習されていないため、**非推奨**です。使用すると予期しない結果になる可能性があります。
設定ファイルの例は英語のドキュメントを参照してください。
</details>
### 6.5. Validation / 検証
You can calculate validation loss during training using a validation dataset to evaluate model generalization performance. This feature works the same as in other training scripts. For details, please refer to the [Validation Guide](validation.md).
<details>
<summary>日本語</summary>
学習中に検証データセットを使用して損失 (Validation Loss) を計算し、モデルの汎化性能を評価できます。この機能は他の学習スクリプトと同様に動作します。詳細は[検証ガイド](validation.md)を参照してください。
</details>
## 7. Other Training Options / その他の学習オプション
- **`--ip_noise_gamma`**: Use `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` to adjust Input Perturbation noise gamma values during training. See Stable Diffusion 3 training options for details.
- **`--loss_type`**: Specifies the loss function for training. The default is `l2`.
- `l1`: L1 loss.
- `l2`: L2 loss (mean squared error).
- `huber`: Huber loss.
- `smooth_l1`: Smooth L1 loss.
- **`--huber_schedule`**, **`--huber_c`**, **`--huber_scale`**: These are parameters for Huber loss. They are used when `--loss_type` is `huber` or `smooth_l1`.
- **`--weighting_scheme`**, **`--logit_mean`**, **`--logit_std`**, **`--mode_scale`**: These options allow you to adjust the loss weighting for each timestep. For details, refer to the [`sd3_train_network.md` guide](sd3_train_network.md).
- **`--fused_backward_pass`**: Fuses the backward pass and optimizer step to reduce VRAM usage.
<details>
<summary>日本語</summary>
- **`--ip_noise_gamma`**: Input Perturbationイズのガンマ値を調整します。
- **`--loss_type`**: 学習に用いる損失関数を指定します。
- **`--huber_schedule`**, **`--huber_c`**, **`--huber_scale`**: Huber損失のパラメータです。
- **`--weighting_scheme`**, **`--logit_mean`**, **`--logit_std`**, **`--mode_scale`**: 各タイムステップの損失の重み付けを調整します。
- **`--fused_backward_pass`**: バックワードパスとオプティマイザステップを融合してVRAM使用量を削減します。
</details>
## 8. Using the Inference Script / 推論スクリプトの使用法
The `hunyuan_image_minimal_inference.py` script allows you to generate images using trained LoRA models. Here's a basic usage example:
```bash
python hunyuan_image_minimal_inference.py \
--dit "<path to hunyuanimage2.1.safetensors>" \
--text_encoder "<path to qwen_2.5_vl_7b.safetensors>" \
--byt5 "<path to byt5_small_glyphxl_fp16.safetensors>" \
--vae "<path to hunyuan_image_2.1_vae_fp16.safetensors>" \
--lora_weight "<path to your trained LoRA>" \
--lora_multiplier 1.0 \
--attn_mode "torch" \
--prompt "A cute cartoon penguin in a snowy landscape" \
--image_size 2048 2048 \
--infer_steps 50 \
--guidance_scale 3.5 \
--flow_shift 5.0 \
--seed 542017 \
--save_path "output_image.png"
```
**Key Options:**
- `--fp8_scaled`: Use scaled FP8 format for reduced VRAM usage during inference
- `--blocks_to_swap`: Swap blocks to CPU to reduce VRAM usage
- `--image_size`: Resolution in **height width** (inference is most stable at 2560x1536, 2304x1792, 2048x2048, 1792x2304, 1536x2560 according to the official repo)
- `--guidance_scale`: CFG scale (default: 3.5)
- `--flow_shift`: Flow matching shift parameter (default: 5.0)
- `--text_encoder_cpu`: Run the text encoders on CPU to reduce VRAM usage
- `--vae_chunk_size`: Chunk size for VAE decoding to reduce memory usage (default: None, no chunking). 16 is recommended if enabled.
- `--apg_start_step_general` and `--apg_start_step_ocr`: Start steps for APG (Adaptive Projected Guidance) if using APG during inference. `5` and `38` are the official recommended values for 50 steps. If this value exceeds `--infer_steps`, APG will not be applied.
- `--guidance_rescale`: Rescales the guidance for steps before APG starts. Default is `0.0` (no rescaling). If you use this option, a value around `0.5` might be good starting point.
- `--guidance_rescale_apg`: Rescales the guidance for APG. Default is `0.0` (no rescaling). This option doesn't seem to have a large effect, but if you use it, a value around `0.5` might be a good starting point.
`--split_attn` is not supported (since inference is done one at a time). `--fp8_vl` is not supported, please use CPU for the text encoder if VRAM is insufficient.
<details>
<summary>日本語</summary>
`hunyuan_image_minimal_inference.py`スクリプトを使用して、学習したLoRAモデルで画像を生成できます。基本的な使用例は英語のドキュメントを参照してください。
**主要なオプション:**
- `--fp8_scaled`: VRAM使用量削減のためのスケールFP8形式
- `--blocks_to_swap`: VRAM使用量削減のためのブロックスワップ
- `--image_size`: 解像度2048x2048で最も安定
- `--guidance_scale`: CFGスケール推奨: 3.5
- `--flow_shift`: Flow Matchingシフトパラメータデフォルト: 5.0
- `--text_encoder_cpu`: テキストエンコーダをCPUで実行してVRAM使用量削減
- `--vae_chunk_size`: VAEデコーディングのチャンクサイズデフォルト: None、チャンク処理なし。有効にする場合は16を推奨。
- `--apg_start_step_general` と `--apg_start_step_ocr`: 推論中にAPGを使用する場合の開始ステップ。50ステップの場合、公式推奨値はそれぞれ5と38です。この値が`--infer_steps`を超えると、APGは適用されません。
- `--guidance_rescale`: APG開始前のステップに対するガイダンスのリスケーリング。デフォルトは0.0リスケーリングなし。使用する場合、0.5程度から始めて調整してください。
- `--guidance_rescale_apg`: APGに対するガイダンスのリスケーリング。デフォルトは0.0リスケーリングなし。このオプションは大きな効果はないようですが、使用する場合は0.5程度から始めて調整してください。
`--split_attn`はサポートされていません1件ずつ推論するため。`--fp8_vl`もサポートされていません。VRAMが不足する場合はテキストエンコーダをCPUで実行してください。
</details>
## 9. Related Tools / 関連ツール
### `networks/convert_hunyuan_image_lora_to_comfy.py`
A script to convert LoRA models to ComfyUI-compatible format. The formats differ slightly, so conversion is necessary. You can convert from the sd-scripts format to ComfyUI format with:
```bash
python networks/convert_hunyuan_image_lora_to_comfy.py path/to/source.safetensors path/to/destination.safetensors
```
Using the `--reverse` option allows conversion in the opposite direction (ComfyUI format to sd-scripts format). However, reverse conversion is only possible for LoRAs converted by this script. LoRAs created with other training tools cannot be converted.
<details>
<summary>日本語</summary>
**`networks/convert_hunyuan_image_lora_to_comfy.py`**
LoRAモデルをComfyUI互換形式に変換するスクリプト。わずかに形式が異なるため、変換が必要です。以下の指定で、sd-scriptsの形式からComfyUI形式に変換できます。
```bash
python networks/convert_hunyuan_image_lora_to_comfy.py path/to/source.safetensors path/to/destination.safetensors
```
`--reverse`オプションを付けると、逆変換ComfyUI形式からsd-scripts形式も可能です。ただし、逆変換ができるのはこのスクリプトで変換したLoRAに限ります。他の学習ツールで作成したLoRAは変換できません。
</details>
## 10. Others / その他
`hunyuan_image_train_network.py` includes many features common with `train_network.py`, such as sample image generation (`--sample_prompts`, etc.) and detailed optimizer settings. For these features, refer to the [`train_network.py` guide](train_network.md#5-other-features--その他の機能) or the script help (`python hunyuan_image_train_network.py --help`).
<details>
<summary>日本語</summary>
`hunyuan_image_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python hunyuan_image_train_network.py --help`) を参照してください。
</details>

View File

@@ -170,6 +170,8 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
* `--model_prediction_type=<choice>` Model prediction processing method. Options: `raw`, `additive`, `sigma_scaled`. Default `raw`. **Recommended: `raw`**
* `--system_prompt=<string>` System prompt to prepend to all prompts. Recommended: `"You are an assistant designed to generate high-quality images based on user prompts."` or `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`
* `--use_flash_attn` Use Flash Attention. Requires `pip install flash-attn` (may not be supported in all environments). If installed correctly, it speeds up training.
* `--use_sage_attn` Use Sage Attention for the model.
* `--sample_batch_size=<integer>` Batch size to use for sampling, defaults to `--training_batch_size` value. Sample batches are bucketed by width, height, guidance scale, and seed.
* `--sigmoid_scale=<float>` Scale factor for sigmoid timestep sampling. Default `1.0`.
#### Memory and Speed / メモリ・速度関連
@@ -216,6 +218,8 @@ For Lumina Image 2.0, you can specify different dimensions for various component
* `--model_prediction_type=<choice>` モデル予測の処理方法を指定します。`raw`, `additive`, `sigma_scaled`から選択します。デフォルトは`raw`です。**推奨: `raw`**
* `--system_prompt=<string>` 全てのプロンプトに前置するシステムプロンプトを指定します。推奨: `"You are an assistant designed to generate high-quality images based on user prompts."` または `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`
* `--use_flash_attn` Flash Attentionを使用します。`pip install flash-attn`でインストールが必要です(環境によってはサポートされていません)。正しくインストールされている場合は、指定すると学習が高速化されます。
* `--use_sage_attn` Sage Attentionを使用します。
* `--sample_batch_size=<integer>` サンプリングに使用するバッチサイズ。デフォルトは `--training_batch_size` の値です。サンプルバッチは、幅、高さ、ガイダンススケール、シードによってバケット化されます。
* `--sigmoid_scale=<float>` sigmoidタイムステップサンプリングのスケール係数を指定します。デフォルトは`1.0`です。
#### メモリ・速度関連

View File

@@ -1,5 +1,3 @@
Status: reviewed
# LoRA Training Guide for Stable Diffusion 3/3.5 using `sd3_train_network.py` / `sd3_train_network.py` を用いたStable Diffusion 3/3.5モデルのLoRA学習ガイド
This document explains how to train LoRA (Low-Rank Adaptation) models for Stable Diffusion 3 (SD3) and Stable Diffusion 3.5 (SD3.5) using `sd3_train_network.py` in the `sd-scripts` repository.
@@ -18,7 +16,6 @@ This guide assumes you already understand the basics of LoRA training. For commo
<details>
<summary>日本語</summary>
ステータス:内容を一通り確認した
`sd3_train_network.py`は、Stable Diffusion 3/3.5モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。SD3は、MMDiT (Multi-Modal Diffusion Transformer) と呼ばれる新しいアーキテクチャを採用しており、従来のStable Diffusionモデルとは構造が異なります。このスクリプトを使用することで、SD3/3.5モデルに特化したLoRAモデルを作成できます。
@@ -98,7 +95,7 @@ accelerate launch --num_cpu_threads_per_process 1 sd3_train_network.py \
--save_every_n_epochs=1 \
--mixed_precision="fp16" \
--gradient_checkpointing \
--weighting_scheme="sigma_sqrt" \
--weighting_scheme="uniform" \
--blocks_to_swap=32
```
@@ -106,6 +103,7 @@ accelerate launch --num_cpu_threads_per_process 1 sd3_train_network.py \
<details>
<summary>日本語</summary>
学習は、ターミナルから`sd3_train_network.py`を実行することで開始します。基本的なコマンドラインの構造は`train_network.py`と同様ですが、SD3/3.5特有の引数を指定する必要があります。
以下に、基本的なコマンドライン実行例を示します。
@@ -131,11 +129,12 @@ accelerate launch --num_cpu_threads_per_process 1 sd3_train_network.py
--save_every_n_epochs=1
--mixed_precision="fp16"
--gradient_checkpointing
--weighting_scheme="sigma_sqrt"
--weighting_scheme="uniform"
--blocks_to_swap=32
```
※実際には1行で書くか、適切な改行文字`\` または `^`)を使用してください。
</details>
### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説
@@ -157,11 +156,19 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
* `--enable_scaled_pos_embed` **[SD3.5][experimental]** Scale positional embeddings when training with multiple resolutions.
* `--training_shift=<float>` Shift applied to the timestep distribution. Default `1.0`.
* `--weighting_scheme=<choice>` Weighting method for loss by timestep. Default `uniform`.
* `--logit_mean`, `--logit_std`, `--mode_scale` Parameters for `logit_normal` or `mode` weighting.
* `--logit_mean=<float>` Mean value for `logit_normal` weighting scheme. Default `0.0`.
* `--logit_std=<float>` Standard deviation for `logit_normal` weighting scheme. Default `1.0`.
* `--mode_scale=<float>` Scale factor for `mode` weighting scheme. Default `1.29`.
#### Memory and Speed / メモリ・速度関連
* `--blocks_to_swap=<integer>` **[experimental]** Swap a number of Transformer blocks between CPU and GPU. More blocks reduce VRAM but slow training. Cannot be used with `--cpu_offload_checkpointing`.
* `--cache_text_encoder_outputs` Caches the outputs of the text encoders to reduce VRAM usage and speed up training. This is particularly effective for SD3, which uses three text encoders. Recommended when not training the text encoder LoRA. For more details, see the [`sdxl_train_network.py` guide](sdxl_train_network.md).
* `--cache_text_encoder_outputs_to_disk` Caches the text encoder outputs to disk when the above option is enabled.
* `--t5xxl_device=<device>` **[not supported yet]** Specifies the device for T5-XXL model. If not specified, uses accelerator's device.
* `--t5xxl_dtype=<dtype>` **[not supported yet]** Specifies the dtype for T5-XXL model. If not specified, uses default dtype from mixed precision.
* `--save_clip` **[not supported yet]** Saves CLIP models to checkpoint (unified checkpoint format not yet supported).
* `--save_t5xxl` **[not supported yet]** Saves T5-XXL model to checkpoint (unified checkpoint format not yet supported).
#### Incompatible or Deprecated Options / 非互換・非推奨の引数
@@ -169,6 +176,7 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
<details>
<summary>日本語</summary>
[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のSD3/3.5特有の引数を指定します。共通の引数については、上記ガイドを参照してください。
#### モデル関連
@@ -189,34 +197,159 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
* `--enable_scaled_pos_embed` **[SD3.5向け][実験的機能]** マルチ解像度学習時に解像度に応じてPositional Embeddingをスケーリングします。
* `--training_shift=<float>` タイムステップ分布を調整するためのシフト値です。デフォルトは`1.0`です。
* `--weighting_scheme=<choice>` タイムステップに応じた損失の重み付け方法を指定します。デフォルトは`uniform`です。
* `--logit_mean`, `--logit_std`, `--mode_scale` `logit_normal`または`mode`使用時のパラメータです。
* `--logit_mean=<float>` `logit_normal`重み付けスキームの平均値です。デフォルトは`0.0`です。
* `--logit_std=<float>` `logit_normal`重み付けスキームの標準偏差です。デフォルトは`1.0`です。
* `--mode_scale=<float>` `mode`重み付けスキームのスケール係数です。デフォルトは`1.29`です。
#### メモリ・速度関連
* `--blocks_to_swap=<integer>` **[実験的機能]** TransformerブロックをCPUとGPUでスワップしてVRAMを節約します。`--cpu_offload_checkpointing`とは併用できません。
* `--cache_text_encoder_outputs` Text Encoderの出力をキャッシュし、VRAM使用量削減と学習高速化を図ります。SD3は3つのText Encoderを持つため特に効果的です。Text EncoderのLoRAを学習しない場合に推奨されます。詳細は[`sdxl_train_network.py`のガイド](sdxl_train_network.md)を参照してください。
* `--cache_text_encoder_outputs_to_disk` 上記オプションと併用し、Text Encoderの出力をディスクにキャッシュします。
* `--t5xxl_device=<device>` **[未サポート]** T5-XXLモデルのデバイスを指定します。指定しない場合はacceleratorのデバイスを使用します。
* `--t5xxl_dtype=<dtype>` **[未サポート]** T5-XXLモデルのdtypeを指定します。指定しない場合はデフォルトのdtypemixed precisionからを使用します。
* `--save_clip` **[未サポート]** CLIPモデルをチェックポイントに保存します統合チェックポイント形式は未サポート
* `--save_t5xxl` **[未サポート]** T5-XXLモデルをチェックポイントに保存します統合チェックポイント形式は未サポート
#### 非互換・非推奨の引数
* `--v2`, `--v_parameterization`, `--clip_skip` Stable Diffusion v1/v2向けの引数のため、SD3/3.5学習では使用されません。
</details>
### 4.2. Starting Training / 学習の開始
After setting the required arguments, run the command to begin training. The overall flow and how to check logs are the same as in the [train_network.py guide](train_network.md#32-starting-the-training--学習の開始).
## 5. Using the Trained Model / 学習済みモデルの利用
<details>
<summary>日本語</summary>
必要な引数を設定したら、コマンドを実行して学習を開始します。全体の流れやログの確認方法は、[train_network.pyのガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。
</details>
## 5. LoRA Target Modules / LoRAの学習対象モジュール
When training LoRA with `sd3_train_network.py`, the following modules are targeted by default:
* **MMDiT (replaces U-Net)**:
* `qkv` (Query, Key, Value) matrices and `proj_out` (output projection) in the attention blocks.
* **final_layer**:
* The output layer at the end of MMDiT.
By using `--network_args`, you can apply more detailed controls, such as setting different ranks (dimensions) for each module.
### Specify rank for each layer in SD3 LoRA / 各層のランクを指定する
You can specify the rank for each layer in SD3 by specifying the following network_args. If you specify `0`, LoRA will not be applied to that layer.
When network_args is not specified, the default value (`network_dim`) is applied, same as before.
|network_args|target layer|
|---|---|
|context_attn_dim|attn in context_block|
|context_mlp_dim|mlp in context_block|
|context_mod_dim|adaLN_modulation in context_block|
|x_attn_dim|attn in x_block|
|x_mlp_dim|mlp in x_block|
|x_mod_dim|adaLN_modulation in x_block|
`"verbose=True"` is also available for debugging. It shows the rank of each layer.
example:
```
--network_args "context_attn_dim=2" "context_mlp_dim=3" "context_mod_dim=4" "x_attn_dim=5" "x_mlp_dim=6" "x_mod_dim=7" "verbose=True"
```
You can apply LoRA to the conditioning layers of SD3 by specifying `emb_dims` in network_args. When specifying, be sure to specify 6 numbers in `[]` as a comma-separated list.
example:
```
--network_args "emb_dims=[2,3,4,5,6,7]"
```
Each number corresponds to `context_embedder`, `t_embedder`, `x_embedder`, `y_embedder`, `final_layer_adaLN_modulation`, `final_layer_linear`. The above example applies LoRA to all conditioning layers, with rank 2 for `context_embedder`, 3 for `t_embedder`, 4 for `context_embedder`, 5 for `y_embedder`, 6 for `final_layer_adaLN_modulation`, and 7 for `final_layer_linear`.
If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,4,0,0]` applies LoRA only to `context_embedder` and `y_embedder`.
### Specify blocks to train in SD3 LoRA training
You can specify the blocks to train in SD3 LoRA training by specifying `train_block_indices` in network_args. The indices are 0-based. The default (when omitted) is to train all blocks. The indices are specified as a list of integers or a range of integers, like `0,1,5,8` or `0,1,4-5,7`.
The number of blocks depends on the model. The valid range is 0-(the number of blocks - 1). `all` is also available to train all blocks, `none` is also available to train no blocks.
example:
```
--network_args "train_block_indices=1,2,6-8"
```
<details>
<summary>日本語</summary>
`sd3_train_network.py`でLoRAを学習させる場合、デフォルトでは以下のモジュールが対象となります。
* **MMDiT (U-Netの代替)**:
* Attentionブロック内の`qkv`Query, Key, Value行列と、`proj_out`出力Projection
* **final_layer**:
* MMDiTの最後にある出力層。
`--network_args` を使用することで、モジュールごとに異なるランク(次元数)を設定するなど、より詳細な制御が可能です。
### SD3 LoRAで各層のランクを指定する
各層のランクを指定するには、`--network_args`オプションを使用します。`0`を指定すると、その層にはLoRAが適用されません。
network_argsが指定されない場合、デフォルト値`network_dim`)が適用されます。
|network_args|target layer|
|---|---|
|context_attn_dim|attn in context_block|
|context_mlp_dim|mlp in context_block|
|context_mod_dim|adaLN_modulation in context_block|
|x_attn_dim|attn in x_block|
|x_mlp_dim|mlp in x_block|
|x_mod_dim|adaLN_modulation in x_block|
`"verbose=True"`を指定すると、各層のランクが表示されます。
例:
```bash
--network_args "context_attn_dim=2" "context_mlp_dim=3" "context_mod_dim=4" "x_attn_dim=5" "x_mlp_dim=6" "x_mod_dim=7" "verbose=True"
```
また、`emb_dims`を指定することで、SD3の条件付け層にLoRAを適用することもできます。指定する際は、必ず`[]`内にカンマ区切りで6つの数字を指定してください。
```bash
--network_args "emb_dims=[2,3,4,5,6,7]"
```
各数字は、`context_embedder``t_embedder``x_embedder``y_embedder``final_layer_adaLN_modulation``final_layer_linear`に対応しています。上記の例では、すべての条件付け層にLoRAを適用し、`context_embedder`に2、`t_embedder`に3、`x_embedder`に4、`y_embedder`に5、`final_layer_adaLN_modulation`に6、`final_layer_linear`に7のランクを設定しています。
`0`を指定すると、その層にはLoRAが適用されません。例えば、`[4,0,0,4,0,0]`と指定すると、`context_embedder``y_embedder`のみにLoRAが適用されます。
</details>
## 6. Using the Trained Model / 学習済みモデルの利用
When training finishes, a LoRA model file (e.g. `my_sd3_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support SD3/3.5, such as ComfyUI.
## 6. Others / その他
<details>
<summary>日本語</summary>
学習が完了すると、指定した`output_dir`にLoRAモデルファイル例: `my_sd3_lora.safetensors`が保存されます。このファイルは、SD3/3.5モデルに対応した推論環境(例: ComfyUIなどで使用できます。
</details>
## 7. Others / その他
`sd3_train_network.py` shares many features with `train_network.py`, such as sample image generation (`--sample_prompts`, etc.) and detailed optimizer settings. For these, see the [train_network.py guide](train_network.md#5-other-features--その他の機能) or run `python sd3_train_network.py --help`.
<details>
<summary>日本語</summary>
必要な引数を設定し、コマンドを実行すると学習が開始されます。基本的な流れやログの確認方法は[`train_network.py`のガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。
学習が完了すると、指定した`output_dir`にLoRAモデルファイル例: `my_sd3_lora.safetensors`が保存されます。このファイルは、SD3/3.5モデルに対応した推論環境(例: ComfyUIなどで使用できます。
`sd3_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python sd3_train_network.py --help`) を参照してください。
</details>

View File

@@ -42,7 +42,7 @@ Before starting training, you will need the following files:
The dataset definition file (`.toml`) contains detailed settings such as the directory of images to use, repetition count, caption settings, resolution buckets (optional), etc.
For more details on how to write the dataset definition file, please refer to the [Dataset Configuration Guide](link/to/dataset/config/doc).
For more details on how to write the dataset definition file, please refer to the [Dataset Configuration Guide](./config_README-en.md).
In this guide, we will use a file named `my_dataset_config.toml` as an example.
@@ -56,9 +56,9 @@ In this guide, we will use a file named `my_dataset_config.toml` as an example.
**データセット定義ファイルについて**
データセット定義ファイル (`.toml`) には、使用する画像のディレクトリ、繰り返し回数、キャプションの設定、解像度バケツ(任意)などの詳細な設定を記述します。
データセット定義ファイル (`.toml`) には、使用する画像のディレクトリ、繰り返し回数、キャプションの設定、Aspect Ratio Bucketing(任意)などの詳細な設定を記述します。
データセット定義ファイルの詳しい書き方については、[データセット設定ガイド](link/to/dataset/config/doc)を参照してください。
データセット定義ファイルの詳しい書き方については、[データセット設定ガイド](./config_README-ja.md)を参照してください。
ここでは、例として `my_dataset_config.toml` という名前のファイルを使用することにします。
</details>
@@ -143,6 +143,16 @@ Next, we'll explain the main command-line arguments.
* Specifies the rank (dimension) of LoRA. Higher values increase expressiveness but also increase file size and computational cost. Values between 4 and 128 are commonly used. There is no default (module dependent).
* `--network_alpha=1`
* Specifies the alpha value for LoRA. This parameter is related to learning rate scaling. It is generally recommended to set it to about half the value of `network_dim`, but it can also be the same value as `network_dim`. The default is 1. Setting it to the same value as `network_dim` will result in behavior similar to older versions.
* `--network_args`
* Used to specify additional parameters specific to the LoRA module. For example, to use Conv2d (3x3) LoRA (LoRA-C3Lier), specify the following in `--network_args`. Use `conv_dim` to specify the rank for Conv2d (3x3) and `conv_alpha` for alpha.
```
--network_args "conv_dim=4" "conv_alpha=1"
```
If alpha is omitted as shown below, it defaults to 1.
```
--network_args "conv_dim=4"
```
#### Training Parameters / 学習パラメータ
@@ -222,6 +232,16 @@ Next, we'll explain the main command-line arguments.
* `--network_alpha=1`
* LoRA のアルファ値 (alpha) を指定します。学習率のスケーリングに関係するパラメータで、一般的には `network_dim` の半分程度の値を指定することが推奨されますが、`network_dim` と同じ値を指定する場合もあります。デフォルトは 1 です。`network_dim` と同じ値に設定すると、旧バージョンと同様の挙動になります。
* `--network_args`
* LoRA モジュールに特有の追加パラメータを指定するために使用します。例えば、Conv2d (3x3) の LoRA (LoRA-C3Lier) を使用する場合は`--network_args` に以下のように指定してください。`conv_dim` で Conv2d (3x3) の rank を、`conv_alpha` で alpha を指定します。
```
--network_args "conv_dim=4" "conv_alpha=1"
```
以下のように alpha を省略した時は1になります。
```
--network_args "conv_dim=4"
```
#### 学習パラメータ
* `--learning_rate=1e-4`
@@ -311,4 +331,37 @@ For these features, please refer to the script's help (`python train_network.py
* ネットワークの追加設定 (`--network_args` など)
これらの機能については、スクリプトのヘルプ (`python train_network.py --help`) やリポジトリ内の他のドキュメントを参照してください。
</details>
## 6. Additional Information / 追加情報
### Naming of LoRA
The LoRA supported by `train_network.py` has been named to avoid confusion. The documentation has been updated. The following are the names of LoRA types in this repository.
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers)
LoRA for Linear layers and Conv2d layers with 1x1 kernel
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers)
In addition to 1., LoRA for Conv2d layers with 3x3 kernel
LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network arg).
<details>
<summary>日本語</summary>
`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
デフォルトではLoRA-LierLaが使われます。LoRA-C3Lierを使う場合は `--network_args` に `conv_dim` を指定してください。
</details>

View File

@@ -1,5 +1,3 @@
Status: under review
# Advanced Settings: Detailed Guide for SDXL LoRA Training Script `sdxl_train_network.py` / 高度な設定: SDXL LoRA学習スクリプト `sdxl_train_network.py` 詳細ガイド
This document describes the advanced options available when training LoRA models for SDXL (Stable Diffusion XL) with `sdxl_train_network.py` in the `sd-scripts` repository. For the basics, please read [How to Use the LoRA Training Script `train_network.py`](train_network.md) and [How to Use the SDXL LoRA Training Script `sdxl_train_network.py`](sdxl_train_network.md).
@@ -102,9 +100,33 @@ Basic options are common with `train_network.py`.
* `--sample_every_n_steps=N` / `--sample_every_n_epochs=N`: Generates sample images every N steps/epochs.
* `--sample_at_first`: Generates sample images before training starts.
* `--sample_prompts=\"<prompt file>\"`: Specifies a file (`.txt`, `.toml`, `.json`) containing prompts for sample image generation. Format follows [gen_img_diffusers.py](gen_img_diffusers.py). See [documentation](gen_img_README-ja.md) for details.
* `--sample_prompts=\"<prompt file>\"`: Specifies a file (`.txt`, `.toml`, `.json`) containing prompts for sample image generation.
* `--sample_sampler=\"...\"`: Specifies the sampler (scheduler) for sample image generation. `euler_a`, `dpm++_2m_karras`, etc., are common. See `--help` for choices.
#### Format of Prompt File
A prompt file can contain multiple prompts with options, for example:
```
# prompt 1
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
# prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
```
Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.
* `--n` Negative prompt up to the next option. Ignored when CFG scale is `1.0`.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image. For FLUX.1 models, the default is `1.0`, which means no CFG. For Chroma models, set to around `4.0` to enable CFG.
* `--g` Specifies the embedded guidance scale for the models with embedded guidance (FLUX.1), the default is `3.5`. Set to `0.0` for Chroma models.
* `--s` Specifies the number of steps in the generation.
The prompt weighting such as `( )` and `[ ]` are working for SD/SDXL models, not working for other models like FLUX.1.
### 1.8. Logging & Tracking
* `--logging_dir=\"<log directory>\"`: Specifies the directory for TensorBoard and other logs. If not specified, logs are not output.
@@ -130,15 +152,61 @@ Basic options are common with `train_network.py`.
* `--huber_c=C` / `--huber_scale=S`: Parameters for `huber` or `smooth_l1` loss.
* `--masked_loss`: Limits loss calculation area based on a mask image. Requires specifying mask images (black and white) in `conditioning_data_dir` in dataset settings. See [About Masked Loss](masked_loss_README.md) for details.
### 1.10. Distributed Training and Others
### 1.10. Distributed Training and Other Training Related Options
* `--seed=N`: Specifies the random seed. Set this to ensure training reproducibility.
* `--max_token_length=N` (`75`, `150`, `225`): Maximum token length processed by Text Encoders. For SDXL, typically `75` (default), `150`, or `225`. Longer lengths can handle more complex prompts but increase VRAM usage.
* `--clip_skip=N`: Uses the output from N layers skipped from the final layer of Text Encoders. **Not typically used for SDXL**.
* `--lowram` / `--highvram`: Options for memory usage optimization. `--lowram` is for environments like Colab where RAM < VRAM, `--highvram` is for environments with ample VRAM.
* `--persistent_data_loader_workers` / `--max_data_loader_n_workers=N`: Settings for DataLoader worker processes. Affects wait time between epochs and memory usage.
* `--config_file=\"<config file>\"` / `--output_config`: Options to use/output a `.toml` file instead of command line arguments.
* `--config_file="<config file>"` / `--output_config`: Options to use/output a `.toml` file instead of command line arguments.
* **Accelerate/DeepSpeed related:** (`--ddp_timeout`, `--ddp_gradient_as_bucket_view`, `--ddp_static_graph`): Detailed settings for distributed training. Accelerate settings (`accelerate config`) are usually sufficient. DeepSpeed requires separate configuration.
* `--initial_epoch=<integer>` Sets the initial epoch number. `1` means first epoch (same as not specifying). Note: `initial_epoch`/`initial_step` doesn't affect the lr scheduler, which means lr scheduler will start from 0 without `--resume`.
* `--initial_step=<integer>` Sets the initial step number including all epochs. `0` means first step (same as not specifying). Overwrites `initial_epoch`.
* `--skip_until_initial_step` Skips training until `initial_step` is reached.
### 1.11. Console and Logging / コンソールとログ
* `--console_log_level`: Sets the logging level for the console output. Choose from `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`.
* `--console_log_file`: Redirects console logs to a specified file.
* `--console_log_simple`: Enables a simpler log format.
### 1.12. Hugging Face Hub Integration / Hugging Face Hub 連携
* `--huggingface_repo_id`: The repository name on Hugging Face Hub to upload the model to (e.g., `your-username/your-model`).
* `--huggingface_repo_type`: The type of repository on Hugging Face Hub. Usually `model`.
* `--huggingface_path_in_repo`: The path within the repository to upload files to.
* `--huggingface_token`: Your Hugging Face Hub authentication token.
* `--huggingface_repo_visibility`: Sets the visibility of the repository (`public` or `private`).
* `--resume_from_huggingface`: Resumes training from a state saved on Hugging Face Hub.
* `--async_upload`: Enables asynchronous uploading of models to the Hub, preventing it from blocking the training process.
* `--save_n_epoch_ratio`: Saves the model at a certain ratio of total epochs. For example, `5` will save at least 5 checkpoints throughout the training.
### 1.13. Advanced Attention Settings / 高度なAttention設定
* `--mem_eff_attn`: Use memory-efficient attention mechanism. This is an older implementation and `sdpa` or `xformers` are generally recommended.
* `--xformers`: Use xformers library for memory-efficient attention. Requires `pip install xformers`.
### 1.14. Advanced LR Scheduler Settings / 高度な学習率スケジューラ設定
* `--lr_scheduler_type`: Specifies a custom scheduler module.
* `--lr_scheduler_args`: Provides additional arguments to the custom scheduler (e.g., `"T_max=100"`).
* `--lr_decay_steps`: Sets the number of steps for the learning rate to decay.
* `--lr_scheduler_timescale`: The timescale for the inverse square root scheduler.
* `--lr_scheduler_min_lr_ratio`: Sets the minimum learning rate as a ratio of the initial learning rate for certain schedulers.
### 1.15. Differential Learning with LoRA / LoRAの差分学習
This technique involves merging a pre-trained LoRA into the base model before starting a new training session. This is useful for fine-tuning an existing LoRA or for learning the 'difference' from it.
* `--base_weights`: Path to one or more LoRA weight files to be merged into the base model before training begins.
* `--base_weights_multiplier`: A multiplier for the weights of the LoRA specified by `--base_weights`. You can specify multiple values if you provide multiple weights.
### 1.16. Other Miscellaneous Options / その他のオプション
* `--tokenizer_cache_dir`: Specifies a directory to cache the tokenizer, which is useful for offline training.
* `--scale_weight_norms`: Scales the weight norms of the LoRA modules. This can help prevent overfitting by controlling the magnitude of the weights. A value of `1.0` is a good starting point.
* `--disable_mmap_load_safetensors`: Disables memory-mapped loading for `.safetensors` files. This can speed up model loading in some environments like WSL.
## 2. Other Tips / その他のTips
@@ -165,8 +233,6 @@ Basic options are common with `train_network.py`.
<details>
<summary>日本語</summary>
---
# 高度な設定: SDXL LoRA学習スクリプト `sdxl_train_network.py` 詳細ガイド
このドキュメントでは、`sd-scripts` リポジトリに含まれる `sdxl_train_network.py` を使用した、SDXL (Stable Diffusion XL) モデルに対する LoRA (Low-Rank Adaptation) モデル学習の高度な設定オプションについて解説します。
@@ -333,10 +399,33 @@ SDXLは計算コストが高いため、キャッシュ機能が効果的です
* `--sample_at_first`
* 学習開始前にサンプル画像を生成します。
* `--sample_prompts="<プロンプトファイル>"`
* サンプル画像生成に使用するプロンプトを記述したファイル (`.txt`, `.toml`, `.json`) を指定します。書式は[gen\_img\_diffusers.py](gen_img_diffusers.py)に準じます。詳細は[ドキュメント](gen_img_README-ja.md)を参照してください。
* サンプル画像生成に使用するプロンプトを記述したファイル (`.txt`, `.toml`, `.json`) を指定します。
* `--sample_sampler="..."`
* サンプル画像生成時のサンプラー(スケジューラ)を指定します。`euler_a`, `dpm++_2m_karras` などが一般的です。選択肢は `--help` を参照してください。
#### プロンプトファイルの書式
プロンプトファイルは複数のプロンプトとオプションを含めることができます。例えば:
```
# prompt 1
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
# prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
```
`#`で始まる行はコメントです。生成画像のオプションはプロンプトの後に `--n` のように指定できます。以下のオプションが使用可能です。
* `--n` 次のオプションまでがネガティブプロンプトです。CFGスケールが `1.0` の場合は無視されます。
* `--w` 生成画像の幅を指定します。
* `--h` 生成画像の高さを指定します。
* `--d` 生成画像のシード値を指定します。
* `--l` 生成画像のCFGスケールを指定します。FLUX.1モデルでは、デフォルトは `1.0` でCFGなしを意味します。Chromaモデルでは、CFGを有効にするために `4.0` 程度に設定してください。
* `--g` 埋め込みガイダンス付きモデルFLUX.1)の埋め込みガイダンススケールを指定、デフォルトは `3.5`。Chromaモデルでは `0.0` に設定してください。
* `--s` 生成時のステップ数を指定します。
プロンプトの重み付け `( )``[ ]` はSD/SDXLモデルで動作し、FLUX.1など他のモデルでは動作しません。
### 1.8. Logging & Tracking 関連
* `--logging_dir="<ログディレクトリ>"`
@@ -381,7 +470,7 @@ SDXLは計算コストが高いため、キャッシュ機能が効果的です
* `--masked_loss`
* マスク画像に基づいてLoss計算領域を限定します。データセット設定で`conditioning_data_dir`にマスク画像(白黒)を指定する必要があります。詳細は[マスクロスについて](masked_loss_README.md)を参照してください。
### 1.10. 分散学習その他
### 1.10. 分散学習その他学習関連
* `--seed=N`
* 乱数シードを指定します。学習の再現性を確保したい場合に設定します。
@@ -397,9 +486,56 @@ SDXLは計算コストが高いため、キャッシュ機能が効果的です
* コマンドライン引数の代わりに`.toml`ファイルを使用/出力するオプション。
* **Accelerate/DeepSpeed関連:** (`--ddp_timeout`, `--ddp_gradient_as_bucket_view`, `--ddp_static_graph`)
* 分散学習時の詳細設定。通常はAccelerateの設定 (`accelerate config`) で十分です。DeepSpeedを使用する場合は、別途設定が必要です。
* `--initial_epoch=<integer>` 開始エポック番号を設定します。`1`で最初のエポック(未指定時と同じ)。注意:`initial_epoch`/`initial_step`はlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まります。
* `--initial_step=<integer>` 全エポックを含む開始ステップ番号を設定します。`0`で最初のステップ(未指定時と同じ)。`initial_epoch`を上書きします。
* `--skip_until_initial_step` `initial_step`に到達するまで学習をスキップします。
### 1.11. コンソールとログ
* `--console_log_level`: コンソール出力のログレベルを設定します。`DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`から選択します。
* `--console_log_file`: コンソールのログを指定されたファイルに出力します。
* `--console_log_simple`: よりシンプルなログフォーマットを有効にします。
### 1.12. Hugging Face Hub 連携
* `--huggingface_repo_id`: モデルをアップロードするHugging Face Hubのリポジトリ名 (例: `your-username/your-model`)。
* `--huggingface_repo_type`: Hugging Face Hubのリポジトリの種類。通常は`model`です。
* `--huggingface_path_in_repo`: リポジトリ内でファイルをアップロードするパス。
* `--huggingface_token`: Hugging Face Hubの認証トークン。
* `--huggingface_repo_visibility`: リポジトリの公開設定 (`public`または`private`)。
* `--resume_from_huggingface`: Hugging Face Hubに保存された状態から学習を再開します。
* `--async_upload`: Hubへのモデルの非同期アップロードを有効にし、学習プロセスをブロックしないようにします。
* `--save_n_epoch_ratio`: 総エポック数に対する特定の比率でモデルを保存します。例えば`5`を指定すると、学習全体で少なくとも5つのチェックポイントが保存されます。
### 1.13. 高度なAttention設定
* `--mem_eff_attn`: メモリ効率の良いAttentionメカニズムを使用します。これは古い実装であり、一般的には`sdpa``xformers`の使用が推奨されます。
* `--xformers`: メモリ効率の良いAttentionのためにxformersライブラリを使用します。`pip install xformers`が必要です。
### 1.14. 高度な学習率スケジューラ設定
* `--lr_scheduler_type`: カスタムスケジューラモジュールを指定します。
* `--lr_scheduler_args`: カスタムスケジューラに追加の引数を渡します (例: `"T_max=100"`)。
* `--lr_decay_steps`: 学習率が減衰するステップ数を設定します。
* `--lr_scheduler_timescale`: 逆平方根スケジューラのタイムスケール。
* `--lr_scheduler_min_lr_ratio`: 特定のスケジューラについて、初期学習率に対する最小学習率の比率を設定します。
### 1.15. LoRAの差分学習
既存の学習済みLoRAをベースモデルにマージしてから、新たな学習を開始する手法です。既存LoRAのファインチューニングや、差分を学習させたい場合に有効です。
* `--base_weights`: 学習開始前にベースモデルにマージするLoRAの重みファイルを1つ以上指定します。
* `--base_weights_multiplier`: `--base_weights`で指定したLoRAの重みの倍率。複数指定も可能です。
### 1.16. その他のオプション
* `--tokenizer_cache_dir`: オフラインでの学習に便利なように、tokenizerをキャッシュするディレクトリを指定します。
* `--scale_weight_norms`: LoRAモジュールの重みのルムをスケーリングします。重みの大きさを制御することで過学習を防ぐ助けになります。`1.0`が良い出発点です。
* `--disable_mmap_load_safetensors`: `.safetensors`ファイルのメモリマップドローディングを無効にします。WSLなどの一部環境でモデルの読み込みを高速化できます。
## 2. その他のTips
* **VRAM使用量:** SDXL LoRA学習は多くのVRAMを必要とします。24GB VRAMでも設定によってはメモリ不足になることがあります。以下の設定でVRAM使用量を削減できます。
* `--mixed_precision="bf16"` または `"fp16"` (必須級)
* `--gradient_checkpointing` (強く推奨)
@@ -422,7 +558,4 @@ SDXLは計算コストが高いため、キャッシュ機能が効果的です
不明な点や詳細については、各スクリプトの `--help` オプションや、リポジトリ内の他のドキュメント、実装コード自体を参照してください。
---
</details>

View File

@@ -0,0 +1,291 @@
# How to use Textual Inversion training scripts / Textual Inversion学習スクリプトの使い方
This document explains how to train Textual Inversion embeddings using the `train_textual_inversion.py` and `sdxl_train_textual_inversion.py` scripts included in the `sd-scripts` repository.
<details>
<summary>日本語</summary>
このドキュメントでは、`sd-scripts` リポジトリに含まれる `train_textual_inversion.py` および `sdxl_train_textual_inversion.py` を使用してTextual Inversionの埋め込みを学習する方法について解説します。
</details>
## 1. Introduction / はじめに
[Textual Inversion](https://textual-inversion.github.io/) is a technique that teaches Stable Diffusion new concepts by learning new token embeddings. Instead of fine-tuning the entire model, it only optimizes the text encoder's token embeddings, making it a lightweight approach to teaching the model specific characters, objects, or artistic styles.
**Available Scripts:**
- `train_textual_inversion.py`: For Stable Diffusion v1.x and v2.x models
- `sdxl_train_textual_inversion.py`: For Stable Diffusion XL models
**Prerequisites:**
* The `sd-scripts` repository has been cloned and the Python environment has been set up.
* The training dataset has been prepared. For dataset preparation, please refer to the [Dataset Configuration Guide](config_README-en.md).
<details>
<summary>日本語</summary>
[Textual Inversion](https://textual-inversion.github.io/) は、新しいトークンの埋め込みを学習することで、Stable Diffusionに新しい概念を教える技術です。モデル全体をファインチューニングする代わりに、テキストエンコーダのトークン埋め込みのみを最適化するため、特定のキャラクター、オブジェクト、芸術的スタイルをモデルに教えるための軽量なアプローチです。
**利用可能なスクリプト:**
- `train_textual_inversion.py`: Stable Diffusion v1.xおよびv2.xモデル用
- `sdxl_train_textual_inversion.py`: Stable Diffusion XLモデル用
**前提条件:**
* `sd-scripts` リポジトリのクローンとPython環境のセットアップが完了していること。
* 学習用データセットの準備が完了していること。データセットの準備については[データセット設定ガイド](config_README-en.md)を参照してください。
</details>
## 2. Basic Usage / 基本的な使用方法
### 2.1. For Stable Diffusion v1.x/v2.x Models / Stable Diffusion v1.x/v2.xモデル用
```bash
accelerate launch --num_cpu_threads_per_process 1 train_textual_inversion.py \
--pretrained_model_name_or_path="path/to/model.safetensors" \
--dataset_config="dataset_config.toml" \
--output_dir="output" \
--output_name="my_textual_inversion" \
--save_model_as="safetensors" \
--token_string="mychar" \
--init_word="girl" \
--num_vectors_per_token=4 \
--max_train_steps=1600 \
--learning_rate=1e-6 \
--optimizer_type="AdamW8bit" \
--mixed_precision="fp16" \
--cache_latents \
--sdpa
```
### 2.2. For SDXL Models / SDXLモデル用
```bash
accelerate launch --num_cpu_threads_per_process 1 sdxl_train_textual_inversion.py \
--pretrained_model_name_or_path="path/to/sdxl_model.safetensors" \
--dataset_config="dataset_config.toml" \
--output_dir="output" \
--output_name="my_sdxl_textual_inversion" \
--save_model_as="safetensors" \
--token_string="mychar" \
--init_word="girl" \
--num_vectors_per_token=4 \
--max_train_steps=1600 \
--learning_rate=1e-6 \
--optimizer_type="AdamW8bit" \
--mixed_precision="fp16" \
--cache_latents \
--sdpa
```
<details>
<summary>日本語</summary>
上記のコマンドは実際には1行で書く必要がありますが、見やすさのために改行していますLinuxやMacでは行末に `\` を追加することで改行できます。Windowsの場合は、改行せずに1行で書くか、`^` を行末に追加してください。
</details>
## 3. Key Command-Line Arguments / 主要なコマンドライン引数
### 3.1. Textual Inversion Specific Arguments / Textual Inversion固有の引数
#### Core Parameters / コアパラメータ
* `--token_string="mychar"` **[Required]**
* Specifies the token string used in training. This must not exist in the tokenizer's vocabulary. In your training prompts, include this token string (e.g., if token_string is "mychar", use prompts like "mychar 1girl").
* 学習時に使用されるトークン文字列を指定します。tokenizerの語彙に存在しない文字である必要があります。学習時のプロンプトには、このトークン文字列を含める必要がありますtoken_stringが"mychar"なら、"mychar 1girl"のようなプロンプトを使用)。
* `--init_word="girl"`
* Specifies the word to use for initializing the embedding vector. Choose a word that is conceptually close to what you want to teach. Must be a single token.
* 埋め込みベクトルの初期化に使用する単語を指定します。教えたい概念に近い単語を選ぶとよいでしょう。単一のトークンである必要があります。
* `--num_vectors_per_token=4`
* Specifies how many embedding vectors to use for this token. More vectors provide greater expressiveness but consume more tokens from the 77-token limit.
* このトークンに使用する埋め込みベクトルの数を指定します。多いほど表現力が増しますが、77トークン制限からより多くのトークンを消費します。
* `--weights="path/to/existing_embedding.safetensors"`
* Loads pre-trained embeddings to continue training from. Optional parameter for transfer learning.
* 既存の埋め込みを読み込んで、そこから追加で学習します。転移学習のオプションパラメータです。
#### Template Options / テンプレートオプション
* `--use_object_template`
* Ignores captions and uses predefined object templates (e.g., "a photo of a {}"). Same as the original implementation.
* キャプションを無視して、事前定義された物体用テンプレート(例:"a photo of a {}")を使用します。公式実装と同じです。
* `--use_style_template`
* Ignores captions and uses predefined style templates (e.g., "a painting in the style of {}"). Same as the original implementation.
* キャプションを無視して、事前定義されたスタイル用テンプレート(例:"a painting in the style of {}")を使用します。公式実装と同じです。
### 3.2. Model and Dataset Arguments / モデル・データセット引数
For common model and dataset arguments, please refer to [LoRA Training Guide](train_network.md#31-main-command-line-arguments--主要なコマンドライン引数). The following arguments work the same way:
* `--pretrained_model_name_or_path`
* `--dataset_config`
* `--v2`, `--v_parameterization`
* `--resolution`
* `--cache_latents`, `--vae_batch_size`
* `--enable_bucket`, `--min_bucket_reso`, `--max_bucket_reso`
<details>
<summary>日本語</summary>
一般的なモデル・データセット引数については、[LoRA学習ガイド](train_network.md#31-main-command-line-arguments--主要なコマンドライン引数)を参照してください。以下の引数は同様に動作します:
* `--pretrained_model_name_or_path`
* `--dataset_config`
* `--v2`, `--v_parameterization`
* `--resolution`
* `--cache_latents`, `--vae_batch_size`
* `--enable_bucket`, `--min_bucket_reso`, `--max_bucket_reso`
</details>
### 3.3. Training Parameters / 学習パラメータ
For training parameters, please refer to [LoRA Training Guide](train_network.md#31-main-command-line-arguments--主要なコマンドライン引数). Textual Inversion typically uses these settings:
* `--learning_rate=1e-6`: Lower learning rates are often used compared to LoRA training
* `--max_train_steps=1600`: Fewer steps are usually sufficient
* `--optimizer_type="AdamW8bit"`: Memory-efficient optimizer
* `--mixed_precision="fp16"`: Reduces memory usage
**Note:** Textual Inversion has lower memory requirements compared to full model fine-tuning, so you can often use larger batch sizes.
<details>
<summary>日本語</summary>
学習パラメータについては、[LoRA学習ガイド](train_network.md#31-main-command-line-arguments--主要なコマンドライン引数)を参照してください。Textual Inversionでは通常以下の設定を使用します
* `--learning_rate=1e-6`: LoRA学習と比べて低い学習率がよく使用されます
* `--max_train_steps=1600`: より少ないステップで十分な場合が多いです
* `--optimizer_type="AdamW8bit"`: メモリ効率的なオプティマイザ
* `--mixed_precision="fp16"`: メモリ使用量を削減
**注意:** Textual Inversionはモデル全体のファインチューニングと比べてメモリ要件が低いため、多くの場合、より大きなバッチサイズを使用できます。
</details>
## 4. Dataset Preparation / データセット準備
### 4.1. Dataset Configuration / データセット設定
Create a TOML configuration file as described in the [Dataset Configuration Guide](config_README-en.md). Here's an example for Textual Inversion:
```toml
[general]
shuffle_caption = false
caption_extension = ".txt"
keep_tokens = 1
[[datasets]]
resolution = 512 # 1024 for SDXL
batch_size = 4 # Can use larger values than LoRA training
enable_bucket = true
[[datasets.subsets]]
image_dir = "path/to/images"
caption_extension = ".txt"
num_repeats = 10
```
### 4.2. Caption Guidelines / キャプションガイドライン
**Important:** Your captions must include the token string you specified. For example:
* If `--token_string="mychar"`, captions should be like: "mychar, 1girl, blonde hair, blue eyes"
* The token string can appear anywhere in the caption, but including it is essential
You can verify that your token string is being recognized by using `--debug_dataset`, which will show token IDs. Look for tokens with IDs ≥ 49408 (these are the new custom tokens).
<details>
<summary>日本語</summary>
**重要:** キャプションには指定したトークン文字列を含める必要があります。例:
* `--token_string="mychar"` の場合、キャプションは "mychar, 1girl, blonde hair, blue eyes" のようにします
* トークン文字列はキャプション内のどこに配置しても構いませんが、含めることが必須です
`--debug_dataset` を使用してトークン文字列が認識されているかを確認できます。これによりトークンIDが表示されます。ID ≥ 49408 のトークン(これらは新しいカスタムトークン)を探してください。
</details>
## 5. Advanced Configuration / 高度な設定
### 5.1. Multiple Token Vectors / 複数トークンベクトル
When using `--num_vectors_per_token` > 1, the system creates additional token variations:
- `--token_string="mychar"` with `--num_vectors_per_token=4` creates: "mychar", "mychar1", "mychar2", "mychar3"
For generation, you can use either the base token or all tokens together.
### 5.2. Memory Optimization / メモリ最適化
* Use `--cache_latents` to cache VAE outputs and reduce VRAM usage
* Use `--gradient_checkpointing` for additional memory savings
* For SDXL, use `--cache_text_encoder_outputs` to cache text encoder outputs
* Consider using `--mixed_precision="bf16"` on newer GPUs (RTX 30 series and later)
### 5.3. Training Tips / 学習のコツ
* **Learning Rate:** Start with 1e-6 and adjust based on results. Lower rates often work better than LoRA training.
* **Steps:** 1000-2000 steps are usually sufficient, but this varies by dataset size and complexity.
* **Batch Size:** Textual Inversion can handle larger batch sizes than full fine-tuning due to lower memory requirements.
* **Templates:** Use `--use_object_template` for characters/objects, `--use_style_template` for artistic styles.
<details>
<summary>日本語</summary>
* **学習率:** 1e-6から始めて、結果に基づいて調整してください。LoRA学習よりも低い率がよく機能します。
* **ステップ数:** 通常1000-2000ステップで十分ですが、データセットのサイズと複雑さによって異なります。
* **バッチサイズ:** メモリ要件が低いため、Textual Inversionは完全なファインチューニングよりも大きなバッチサイズを処理できます。
* **テンプレート:** キャラクター/オブジェクトには `--use_object_template`、芸術的スタイルには `--use_style_template` を使用してください。
</details>
## 6. Usage After Training / 学習後の使用方法
The trained Textual Inversion embeddings can be used in:
* **Automatic1111 WebUI:** Place the `.safetensors` file in the `embeddings` folder
* **ComfyUI:** Use the embedding file with appropriate nodes
* **Other Diffusers-based applications:** Load using the embedding path
In your prompts, simply use the token string you trained (e.g., "mychar") and the model will use the learned embedding.
<details>
<summary>日本語</summary>
学習したTextual Inversionの埋め込みは以下で使用できます
* **Automatic1111 WebUI:** `.safetensors` ファイルを `embeddings` フォルダに配置
* **ComfyUI:** 適切なノードで埋め込みファイルを使用
* **その他のDiffusersベースアプリケーション:** 埋め込みパスを使用して読み込み
プロンプトでは、学習したトークン文字列(例:"mychar")を単純に使用するだけで、モデルが学習した埋め込みを使用します。
</details>
## 7. Troubleshooting / トラブルシューティング
### Common Issues / よくある問題
1. **Token string already exists in tokenizer**
* Use a unique string that doesn't exist in the model's vocabulary
* Try adding numbers or special characters (e.g., "mychar123")
2. **No improvement after training**
* Ensure your captions include the token string
* Try adjusting the learning rate (lower values like 5e-7)
* Increase the number of training steps
* Use `--cache_latents`
<details>
<summary>日本語</summary>
1. **トークン文字列がtokenizerに既に存在する**
* モデルの語彙に存在しない固有の文字列を使用してください
* 数字や特殊文字を追加してみてください(例:"mychar123"
2. **学習後に改善が見られない**
* キャプションにトークン文字列が含まれていることを確認してください
* 学習率を調整してみてください5e-7のような低い値
* 学習ステップ数を増やしてください
3. **メモリ不足エラー**
* データセット設定でバッチサイズを減らしてください
* `--gradient_checkpointing` を使用してください
* `--cache_latents` を使用してください
</details>
For additional training options and advanced configurations, please refer to the [LoRA Training Guide](train_network.md) as many parameters are shared between training methods.

261
docs/validation.md Normal file
View File

@@ -0,0 +1,261 @@
# Validation Loss
Validation loss is a crucial metric for monitoring the training process of a model. It helps you assess how well your model is generalizing to data it hasn't seen during training, which is essential for preventing overfitting. By periodically evaluating the model on a separate validation dataset, you can gain insights into its performance and make more informed decisions about when to stop training or adjust hyperparameters.
This feature provides a stable and reliable validation loss metric by ensuring the validation process is deterministic.
<details>
<summary>日本語</summary>
Validation loss検証損失は、モデルの学習過程を監視するための重要な指標です。モデルが学習中に見ていないデータに対してどの程度汎化できているかを評価するのに役立ち、過学習を防ぐために不可欠です。個別の検証データセットで定期的にモデルを評価することで、そのパフォーマンスに関する洞察を得て、学習をいつ停止するか、またはハイパーパラメータを調整するかについて、より多くの情報に基づいた決定を下すことができます。
この機能は、検証プロセスが決定論的であることを保証することにより、安定して信頼性の高い検証損失指標を提供します。
</details>
## How It Works
When validation is enabled, a portion of your dataset is set aside specifically for this purpose. The script then runs a validation step at regular intervals, calculating the loss on this validation data.
To ensure that the validation loss is a reliable indicator of model performance, the process is deterministic. This means that for every validation run, the same random seed is used for noise generation and timestep selection. This consistency ensures that any fluctuations in the validation loss are due to changes in the model's weights, not random variations in the validation process itself.
The average loss across all validation steps is then logged, providing a single, clear metric to track.
For more technical details, please refer to the original pull request: [PR #1903](https://github.com/kohya-ss/sd-scripts/pull/1903).
<details>
<summary>日本語</summary>
検証が有効になると、データセットの一部がこの目的のために特別に確保されます。スクリプトは定期的な間隔で検証ステップを実行し、この検証データに対する損失を計算します。
検証損失がモデルのパフォーマンスの信頼できる指標であることを保証するために、プロセスは決定論的です。つまり、すべての検証実行で、ノイズ生成とタイムステップ選択に同じランダムシードが使用されます。この一貫性により、検証損失の変動が、検証プロセス自体のランダムな変動ではなく、モデルの重みの変化によるものであることが保証されます。
すべての検証ステップにわたる平均損失がログに記録され、追跡するための単一の明確な指標が提供されます。
より技術的な詳細については、元のプルリクエストを参照してください: [PR #1903](https://github.com/kohya-ss/sd-scripts/pull/1903).
</details>
## How to Use
### Enabling Validation
There are two primary ways to enable validation:
1. **Using a Dataset Config File (Recommended)**: You can specify a validation set directly within your dataset `.toml` file. This method offers the most control, allowing you to designate entire directories as validation sets or split a percentage of a specific subset for validation.
To use a whole directory for validation, add a subset and set `validation_split = 1.0`.
**Example: Separate Validation Set**
```toml
[[datasets]]
# ... training subset ...
[[datasets.subsets]]
image_dir = "path/to/train_images"
# ... other settings ...
# Validation subset
[[datasets.subsets]]
image_dir = "path/to/validation_images"
validation_split = 1.0 # Use this entire subset for validation
```
To use a fraction of a subset for validation, set `validation_split` to a value between 0.0 and 1.0.
**Example: Splitting a Subset**
```toml
[[datasets]]
# ... dataset settings ...
[[datasets.subsets]]
image_dir = "path/to/images"
validation_split = 0.1 # Use 10% of this subset for validation
```
2. **Using a Command-Line Argument**: For a simpler setup, you can use the `--validation_split` argument. This will take a random percentage of your *entire* training dataset for validation. This method is ignored if `validation_split` is defined in your dataset config file.
**Example Command:**
```bash
accelerate launch train_network.py ... --validation_split 0.1
```
This command will use 10% of the total training data for validation.
<details>
<summary>日本語</summary>
### 検証を有効にする
検証を有効にする主な方法は2つあります。
1. **データセット設定ファイルを使用する(推奨)**: データセットの`.toml`ファイル内で直接検証セットを指定できます。この方法は最も制御性が高く、ディレクトリ全体を検証セットとして指定したり、特定のサブセットのパーセンテージを検証用に分割したりすることができます。
ディレクトリ全体を検証に使用するには、サブセットを追加して`validation_split = 1.0`と設定します。
**例:個別の検証セット**
```toml
[[datasets]]
# ... training subset ...
[[datasets.subsets]]
image_dir = "path/to/train_images"
# ... other settings ...
# Validation subset
[[datasets.subsets]]
image_dir = "path/to/validation_images"
validation_split = 1.0 # このサブセット全体を検証に使用します
```
サブセットの一部を検証に使用するには、`validation_split`を0.0から1.0の間の値に設定します。
**例:サブセットの分割**
```toml
[[datasets]]
# ... dataset settings ...
[[datasets.subsets]]
image_dir = "path/to/images"
validation_split = 0.1 # このサブセットの10%を検証に使用します
```
2. **コマンドライン引数を使用する**: より簡単な設定のために、`--validation_split`引数を使用できます。これにより、*全*学習データセットのランダムなパーセンテージが検証に使用されます。この方法は、データセット設定ファイルで`validation_split`が定義されている場合は無視されます。
**コマンド例:**
```bash
accelerate launch train_network.py ... --validation_split 0.1
```
このコマンドは、全学習データの10%を検証に使用します。
</details>
### Configuration Options
| Argument | TOML Option | Description |
| --------------------------- | ------------------- | -------------------------------------------------------------------------------------------------------------------------------------- |
| `--validation_split` | `validation_split` | The fraction of the dataset to use for validation. The command-line argument applies globally, while the TOML option applies per-subset. The TOML setting takes precedence. |
| `--validate_every_n_steps` | | Run validation every N steps. |
| `--validate_every_n_epochs` | | Run validation every N epochs. If not specified, validation runs once per epoch by default. |
| `--max_validation_steps` | | The maximum number of batches to use for a single validation run. If not set, the entire validation dataset is used. |
| `--validation_seed` | `validation_seed` | A specific seed for the validation dataloader shuffling. If not set in the TOML file, the main training `--seed` is used. |
<details>
<summary>日本語</summary>
### 設定オプション
| 引数 | TOMLオプション | 説明 |
| --------------------------- | ------------------- | -------------------------------------------------------------------------------------------------------------------------------------- |
| `--validation_split` | `validation_split` | 検証に使用するデータセットの割合。コマンドライン引数は全体に適用され、TOMLオプションはサブセットごとに適用されます。TOML設定が優先されます。 |
| `--validate_every_n_steps` | | Nステップごとに検証を実行します。 |
| `--validate_every_n_epochs` | | Nエポックごとに検証を実行します。指定しない場合、デフォルトでエポックごとに1回検証が実行されます。 |
| `--max_validation_steps` | | 1回の検証実行に使用するバッチの最大数。設定しない場合、検証データセット全体が使用されます。 |
| `--validation_seed` | `validation_seed` | 検証データローダーのシャッフル用の特定のシード。TOMLファイルで設定されていない場合、メインの学習`--seed`が使用されます。 |
</details>
### Viewing the Results
The validation loss is logged to your tracking tool of choice (TensorBoard or Weights & Biases). Look for the metric `loss/validation` to monitor the performance.
<details>
<summary>日本語</summary>
### 結果の表示
検証損失は、選択した追跡ツールTensorBoardまたはWeights & Biasesに記録されます。パフォーマンスを監視するには、`loss/validation`という指標を探してください。
</details>
### Practical Example
Here is a complete example of how to run a LoRA training with validation enabled:
**1. Prepare your `dataset_config.toml`:**
```toml
[general]
shuffle_caption = true
keep_tokens = 1
[[datasets]]
resolution = "1024,1024"
batch_size = 2
[[datasets.subsets]]
image_dir = 'path/to/your_images'
caption_extension = '.txt'
num_repeats = 10
[[datasets.subsets]]
image_dir = 'path/to/your_validation_images'
caption_extension = '.txt'
validation_split = 1.0 # Use this entire subset for validation
```
**2. Run the training command:**
```bash
accelerate launch sdxl_train_network.py \
--pretrained_model_name_or_path="sd_xl_base_1.0.safetensors" \
--dataset_config="dataset_config.toml" \
--output_dir="output" \
--output_name="my_lora" \
--network_module=networks.lora \
--network_dim=32 \
--network_alpha=16 \
--save_every_n_epochs=1 \
--learning_rate=1e-4 \
--optimizer_type="AdamW8bit" \
--mixed_precision="bf16" \
--logging_dir=logs
```
The validation loss will be calculated once per epoch and saved to the `logs` directory, which you can view with TensorBoard.
<details>
<summary>日本語</summary>
### 実践的な例
検証を有効にしてLoRAの学習を実行する完全な例を次に示します。
**1. `dataset_config.toml`を準備します:**
```toml
[general]
shuffle_caption = true
keep_tokens = 1
[[datasets]]
resolution = "1024,1024"
batch_size = 2
[[datasets.subsets]]
image_dir = 'path/to/your_images'
caption_extension = '.txt'
num_repeats = 10
[[datasets.subsets]]
image_dir = 'path/to/your_validation_images'
caption_extension = '.txt'
validation_split = 1.0 # このサブセット全体を検証に使用します
```
**2. 学習コマンドを実行します:**
```bash
accelerate launch sdxl_train_network.py \
--pretrained_model_name_or_path="sd_xl_base_1.0.safetensors" \
--dataset_config="dataset_config.toml" \
--output_dir="output" \
--output_name="my_lora" \
--network_module=networks.lora \
--network_dim=32 \
--network_alpha=16 \
--save_every_n_epochs=1 \
--learning_rate=1e-4 \
--optimizer_type="AdamW8bit" \
--mixed_precision="bf16" \
--logging_dir=logs
```
検証損失はエポックごとに1回計算され、`logs`ディレクトリに保存されます。これはTensorBoardで表示できます。
</details>

View File

@@ -5,9 +5,11 @@ This document is based on the information from this github page (https://github.
Using onnx for inference is recommended. Please install onnx with the following command:
```powershell
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
pip install onnx onnxruntime-gpu
```
See [the official documentation](https://onnxruntime.ai/docs/install/#python-installs) for more details.
The model weights will be automatically downloaded from Hugging Face.
# Usage
@@ -49,6 +51,8 @@ python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagge
# Options
All options can be checked with `python tag_images_by_wd14_tagger.py --help`.
## General Options
- `--onnx`: Use ONNX for inference. If not specified, TensorFlow will be used. If using TensorFlow, please install TensorFlow separately.

View File

@@ -5,9 +5,11 @@
onnx を用いた推論を推奨します。以下のコマンドで onnx をインストールしてください。
```powershell
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
pip install onnx onnxruntime-gpu
```
詳細は[公式ドキュメント](https://onnxruntime.ai/docs/install/#python-installs)をご覧ください。
モデルの重みはHugging Faceから自動的にダウンロードしてきます。
# 使い方
@@ -48,6 +50,8 @@ python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagge
# オプション
全てオプションは `python tag_images_by_wd14_tagger.py --help` で確認できます。
## 一般オプション
- `--onnx` : ONNX を使用して推論します。指定しない場合は TensorFlow を使用します。TensorFlow 使用時は別途 TensorFlow をインストールしてください。

View File

@@ -1,9 +1,11 @@
import argparse
import csv
import json
import math
import os
from pathlib import Path
from typing import Optional
import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
@@ -29,8 +31,22 @@ SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1]
TAG_JSON_FILE = "tag_mapping.json"
def preprocess_image(image: Image.Image) -> np.ndarray:
# If image has transparency, convert to RGBA. If not, convert to RGB
if image.mode in ("RGBA", "LA") or "transparency" in image.info:
image = image.convert("RGBA")
elif image.mode != "RGB":
image = image.convert("RGB")
# If image is RGBA, combine with white background
if image.mode == "RGBA":
background = Image.new("RGB", image.size, (255, 255, 255))
background.paste(image, mask=image.split()[3]) # Use alpha channel as mask
image = background
def preprocess_image(image):
image = np.array(image)
image = image[:, :, ::-1] # RGB->BGR
@@ -49,67 +65,103 @@ def preprocess_image(image):
class ImageLoadingPrepDataset(torch.utils.data.Dataset):
def __init__(self, image_paths):
self.images = image_paths
def __init__(self, image_paths: list[str], batch_size: int):
self.image_paths = image_paths
self.batch_size = batch_size
def __len__(self):
return len(self.images)
return math.ceil(len(self.image_paths) / self.batch_size)
def __getitem__(self, idx):
img_path = str(self.images[idx])
def __getitem__(self, batch_index: int) -> tuple[str, np.ndarray, tuple[int, int]]:
image_index_start = batch_index * self.batch_size
image_index_end = min((batch_index + 1) * self.batch_size, len(self.image_paths))
try:
image = Image.open(img_path).convert("RGB")
image = preprocess_image(image)
# tensor = torch.tensor(image) # これ Tensor に変換する必要ないな……(;・∀・)
except Exception as e:
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
batch_image_paths = []
images = []
image_sizes = []
for idx in range(image_index_start, image_index_end):
img_path = str(self.image_paths[idx])
return (image, img_path)
try:
image = Image.open(img_path)
image_size = image.size
image = preprocess_image(image)
batch_image_paths.append(img_path)
images.append(image)
image_sizes.append(image_size)
except Exception as e:
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
images = np.stack(images) if len(images) > 0 else np.zeros((0, IMAGE_SIZE, IMAGE_SIZE, 3))
return batch_image_paths, images, image_sizes
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
def collate_fn_no_op(batch):
"""Collate function that does nothing and returns the batch as is."""
return batch
def main(args):
# model location is model_dir + repo_id
# repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash
model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_"))
# given repo_id may be "namespace/repo_name" or "namespace/repo_name/subdir"
# so we split it to "namespace/reponame" and "subdir"
tokens = args.repo_id.split("/")
if len(tokens) > 2:
repo_id = "/".join(tokens[:2])
subdir = "/".join(tokens[2:])
model_location = os.path.join(args.model_dir, repo_id.replace("/", "_"), subdir)
onnx_model_name = "model_optimized.onnx"
default_format = False
else:
repo_id = args.repo_id
subdir = None
model_location = os.path.join(args.model_dir, repo_id.replace("/", "_"))
onnx_model_name = "model.onnx"
default_format = True
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
# depreacatedの警告が出るけどなくなったらその時
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(model_location) or args.force_download:
os.makedirs(args.model_dir, exist_ok=True)
logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
files = FILES
if args.onnx:
files = ["selected_tags.csv"]
files += FILES_ONNX
else:
for file in SUB_DIR_FILES:
if subdir is None:
# SmilingWolf structure
files = FILES
if args.onnx:
files = ["selected_tags.csv"]
files += FILES_ONNX
else:
for file in SUB_DIR_FILES:
hf_hub_download(
repo_id=args.repo_id,
filename=file,
subfolder=SUB_DIR,
local_dir=os.path.join(model_location, SUB_DIR),
force_download=True,
)
for file in files:
hf_hub_download(
repo_id=args.repo_id,
filename=file,
subfolder=SUB_DIR,
local_dir=os.path.join(model_location, SUB_DIR),
local_dir=model_location,
force_download=True,
)
else:
# another structure
files = [onnx_model_name, "tag_mapping.json"]
for file in files:
hf_hub_download(
repo_id=repo_id,
filename=file,
subfolder=subdir,
local_dir=os.path.join(args.model_dir, repo_id.replace("/", "_")), # because subdir is specified
force_download=True,
)
for file in files:
hf_hub_download(
repo_id=args.repo_id,
filename=file,
local_dir=model_location,
force_download=True,
)
else:
logger.info("using existing wd14 tagger model")
@@ -118,7 +170,7 @@ def main(args):
import onnx
import onnxruntime as ort
onnx_path = f"{model_location}/model.onnx"
onnx_path = os.path.join(model_location, onnx_model_name)
logger.info("Running wd14 tagger with onnx")
logger.info(f"loading onnx model: {onnx_path}")
@@ -150,39 +202,30 @@ def main(args):
ort_sess = ort.InferenceSession(
onnx_path,
providers=(["OpenVINOExecutionProvider"]),
provider_options=[{'device_type' : "GPU", "precision": "FP32"}],
provider_options=[{"device_type": "GPU", "precision": "FP32"}],
)
else:
ort_sess = ort.InferenceSession(
onnx_path,
providers=(
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
["CPUExecutionProvider"]
),
providers = (
["CUDAExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers()
else (
["ROCMExecutionProvider"]
if "ROCMExecutionProvider" in ort.get_available_providers()
else ["CPUExecutionProvider"]
)
)
logger.info(f"Using onnxruntime providers: {providers}")
ort_sess = ort.InferenceSession(onnx_path, providers=providers)
else:
from tensorflow.keras.models import load_model
model = load_model(f"{model_location}")
# We read the CSV file manually to avoid adding dependencies.
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
# 依存ライブラリを増やしたくないので自力で読むよ
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
reader = csv.reader(f)
line = [row for row in reader]
header = line[0] # tag_id,name,category,count
rows = line[1:]
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]
# preprocess tags in advance
if args.character_tag_expand:
for i, tag in enumerate(character_tags):
def expand_character_tags(char_tags):
for i, tag in enumerate(char_tags):
if tag.endswith(")"):
# chara_name_(series) -> chara_name, series
# chara_name_(costume)_(series) -> chara_name_(costume), series
@@ -191,35 +234,95 @@ def main(args):
if character_tag.endswith("_"):
character_tag = character_tag[:-1]
series_tag = tags[-1].replace(")", "")
character_tags[i] = character_tag + args.caption_separator + series_tag
char_tags[i] = character_tag + args.caption_separator + series_tag
if args.remove_underscore:
rating_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in rating_tags]
general_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in general_tags]
character_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in character_tags]
def remove_underscore(tags):
return [tag.replace("_", " ") if len(tag) > 3 else tag for tag in tags]
if args.tag_replacement is not None:
# escape , and ; in tag_replacement: wd14 tag names may contain , and ;
escaped_tag_replacements = args.tag_replacement.replace("\\,", "@@@@").replace("\\;", "####")
def process_tag_replacement(tags: list[str], tag_replacements_arg: str) -> list[str]:
# escape , and ; in tag_replacement: wd14 tag names may contain , and ;,
# so user must be specified them like `aa\,bb,AA\,BB;cc\;dd,CC\;DD` which means
# `aa,bb` is replaced with `AA,BB` and `cc;dd` is replaced with `CC;DD`
escaped_tag_replacements = tag_replacements_arg.replace("\\,", "@@@@").replace("\\;", "####")
tag_replacements = escaped_tag_replacements.split(";")
for tag_replacement in tag_replacements:
tags = tag_replacement.split(",") # source, target
assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
for tag_replacements_arg in tag_replacements:
tags = tag_replacements_arg.split(",") # source, target
assert (
len(tags) == 2
), f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags]
logger.info(f"replacing tag: {source} -> {target}")
if source in general_tags:
general_tags[general_tags.index(source)] = target
elif source in character_tags:
character_tags[character_tags.index(source)] = target
elif source in rating_tags:
rating_tags[rating_tags.index(source)] = target
if source in tags:
tags[tags.index(source)] = target
return tags
if default_format:
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
reader = csv.reader(f)
line = [row for row in reader]
header = line[0] # tag_id,name,category,count
rows = line[1:]
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]
if args.character_tag_expand:
expand_character_tags(character_tags)
if args.remove_underscore:
rating_tags = remove_underscore(rating_tags)
character_tags = remove_underscore(character_tags)
general_tags = remove_underscore(general_tags)
if args.tag_replacement is not None:
process_tag_replacement(rating_tags, args.tag_replacement)
process_tag_replacement(general_tags, args.tag_replacement)
process_tag_replacement(character_tags, args.tag_replacement)
else:
with open(os.path.join(model_location, TAG_JSON_FILE), "r", encoding="utf-8") as f:
tag_mapping = json.load(f)
rating_tags = []
general_tags = []
character_tags = []
tag_id_to_tag_mapping = {}
tag_id_to_category_mapping = {}
for tag_id, tag_info in tag_mapping.items():
tag = tag_info["tag"]
category = tag_info["category"]
assert category in [
"Rating",
"General",
"Character",
"Copyright",
"Meta",
"Model",
"Quality",
"Artist",
], f"unexpected category: {category}"
if args.remove_underscore:
tag = remove_underscore([tag])[0]
if args.tag_replacement is not None:
tag = process_tag_replacement([tag], args.tag_replacement)[0]
if category == "Character" and args.character_tag_expand:
tag_list = [tag]
expand_character_tags(tag_list)
tag = tag_list[0]
tag_id_to_tag_mapping[int(tag_id)] = tag
tag_id_to_category_mapping[int(tag_id)] = category
# 画像を読み込む
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
image_paths = [str(ip) for ip in image_paths]
tag_freq = {}
@@ -232,59 +335,150 @@ def main(args):
if args.always_first_tags is not None:
always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""]
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
def run_batch(path_imgs: tuple[list[str], np.ndarray, list[tuple[int, int]]]) -> Optional[dict[str, dict]]:
nonlocal args, default_format, model, ort_sess, input_name, tag_freq
imgs = path_imgs[1]
result = {}
if args.onnx:
# if len(imgs) < args.batch_size:
# imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
if not default_format:
imgs = imgs.transpose(0, 3, 1, 2) # to NCHW
imgs = imgs / 127.5 - 1.0
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
probs = probs[: len(path_imgs)]
probs = probs[: len(imgs)] # remove padding
else:
probs = model(imgs, training=False)
probs = probs.numpy()
for (image_path, _), prob in zip(path_imgs, probs):
for image_path, image_size, prob in zip(path_imgs[0], path_imgs[2], probs):
combined_tags = []
rating_tag_text = ""
character_tag_text = ""
general_tag_text = ""
other_tag_text = ""
# 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する
# First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold
for i, p in enumerate(prob[4:]):
if i < len(general_tags) and p >= args.general_threshold:
tag_name = general_tags[i]
if default_format:
# 最初の4つ以降はタグなのでconfidencethreshold以上のものを追加する
# First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold
for i, p in enumerate(prob[4:]):
if i < len(general_tags) and p >= args.general_threshold:
tag_name = general_tags[i]
if tag_name not in undesired_tags:
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
general_tag_text += caption_separator + tag_name
combined_tags.append(tag_name)
elif i >= len(general_tags) and p >= args.character_threshold:
tag_name = character_tags[i - len(general_tags)]
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += caption_separator + tag_name
if args.character_tags_first: # insert to the beginning
combined_tags.insert(0, tag_name)
else:
combined_tags.append(tag_name)
# 最初の4つはratingなのでargmaxで選ぶ
# First 4 labels are actually ratings: pick one with argmax
if args.use_rating_tags or args.use_rating_tags_as_last_tag:
ratings_probs = prob[:4]
rating_index = ratings_probs.argmax()
found_rating = rating_tags[rating_index]
if found_rating not in undesired_tags:
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
rating_tag_text = found_rating
if args.use_rating_tags:
combined_tags.insert(0, found_rating) # insert to the beginning
else:
combined_tags.append(found_rating)
else:
# apply sigmoid to probabilities
prob = 1 / (1 + np.exp(-prob))
rating_max_prob = -1
rating_tag = None
quality_max_prob = -1
quality_tag = None
character_tags = []
min_thres = min(
args.thresh,
args.general_threshold,
args.character_threshold,
args.copyright_threshold,
args.meta_threshold,
args.model_threshold,
args.artist_threshold,
)
prob_indices = np.where(prob >= min_thres)[0]
# for i, p in enumerate(prob):
for i in prob_indices:
if i not in tag_id_to_tag_mapping:
continue
p = prob[i]
tag_name = tag_id_to_tag_mapping[i]
category = tag_id_to_category_mapping[i]
if tag_name in undesired_tags:
continue
if category == "Rating":
if p > rating_max_prob:
rating_max_prob = p
rating_tag = tag_name
rating_tag_text = tag_name
continue
elif category == "Quality":
if p > quality_max_prob:
quality_max_prob = p
quality_tag = tag_name
if args.use_quality_tags or args.use_quality_tags_as_last_tag:
other_tag_text += caption_separator + tag_name
continue
if category == "General" and p >= args.general_threshold:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
general_tag_text += caption_separator + tag_name
combined_tags.append(tag_name)
elif i >= len(general_tags) and p >= args.character_threshold:
tag_name = character_tags[i - len(general_tags)]
if tag_name not in undesired_tags:
combined_tags.append((tag_name, p))
elif category == "Character" and p >= args.character_threshold:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += caption_separator + tag_name
if args.character_tags_first: # insert to the beginning
combined_tags.insert(0, tag_name)
if args.character_tags_first: # we separate character tags
character_tags.append((tag_name, p))
else:
combined_tags.append(tag_name)
combined_tags.append((tag_name, p))
elif (
(category == "Copyright" and p >= args.copyright_threshold)
or (category == "Meta" and p >= args.meta_threshold)
or (category == "Model" and p >= args.model_threshold)
or (category == "Artist" and p >= args.artist_threshold)
):
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
other_tag_text += f"{caption_separator}{tag_name} ({category})"
combined_tags.append((tag_name, p))
# 最初の4つはratingなのでargmaxで選ぶ
# First 4 labels are actually ratings: pick one with argmax
if args.use_rating_tags or args.use_rating_tags_as_last_tag:
ratings_probs = prob[:4]
rating_index = ratings_probs.argmax()
found_rating = rating_tags[rating_index]
# sort by probability
combined_tags.sort(key=lambda x: x[1], reverse=True)
if character_tags:
character_tags.sort(key=lambda x: x[1], reverse=True)
combined_tags = character_tags + combined_tags
combined_tags = [t[0] for t in combined_tags] # remove probability
if found_rating not in undesired_tags:
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
rating_tag_text = found_rating
if args.use_rating_tags:
combined_tags.insert(0, found_rating) # insert to the beginning
else:
combined_tags.append(found_rating)
if quality_tag is not None:
if args.use_quality_tags_as_last_tag:
combined_tags.append(quality_tag)
elif args.use_quality_tags:
combined_tags.insert(0, quality_tag)
if rating_tag is not None:
if args.use_rating_tags_as_last_tag:
combined_tags.append(rating_tag)
elif args.use_rating_tags:
combined_tags.insert(0, rating_tag)
# 一番最初に置くタグを指定する
# Always put some tags at the beginning
@@ -299,6 +493,8 @@ def main(args):
general_tag_text = general_tag_text[len(caption_separator) :]
if len(character_tag_text) > 0:
character_tag_text = character_tag_text[len(caption_separator) :]
if len(other_tag_text) > 0:
other_tag_text = other_tag_text[len(caption_separator) :]
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
@@ -320,55 +516,79 @@ def main(args):
# Create new tag_text
tag_text = caption_separator.join(existing_tags + new_tags)
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")
if not args.output_path:
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
else:
entry = {"tags": tag_text, "image_size": list(image_size)}
result[image_path] = entry
if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")
if other_tag_text:
logger.info(f"\tOther tags: {other_tag_text}")
return result
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = ImageLoadingPrepDataset(image_paths)
dataset = ImageLoadingPrepDataset(image_paths, args.batch_size)
data = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
batch_size=1,
shuffle=False,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
collate_fn=collate_fn_no_op,
drop_last=False,
)
else:
data = [[(None, ip)] for ip in image_paths]
# data = [[(ip, None, None)] for ip in image_paths]
data = [[]]
for ip in image_paths:
if len(data[-1]) >= args.batch_size:
data.append([])
data[-1].append((ip, None, None))
b_imgs = []
results = {}
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
if data_entry is None or len(data_entry) == 0:
continue
image, image_path = data
if image is None:
try:
image = Image.open(image_path)
if image.mode != "RGB":
image = image.convert("RGB")
image = preprocess_image(image)
except Exception as e:
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
if data_entry[0][1] is None:
# No preloaded image, need to load
images = []
image_sizes = []
for image_path, _, _ in data_entry:
image = Image.open(image_path)
image_size = image.size
image = preprocess_image(image)
images.append(image)
image_sizes.append(image_size)
b_imgs = ([ip for ip, _, _ in data_entry], np.stack(images), image_sizes)
else:
b_imgs = data_entry[0]
if len(b_imgs) >= args.batch_size:
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
run_batch(b_imgs)
b_imgs.clear()
r = run_batch(b_imgs)
if args.output_path and r is not None:
results.update(r)
if len(b_imgs) > 0:
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
run_batch(b_imgs)
if args.output_path:
if args.output_path.endswith(".jsonl"):
# optional JSONL metadata
with open(args.output_path, "wt", encoding="utf-8") as f:
for image_path, entry in results.items():
f.write(
json.dumps({"image_path": image_path, "caption": entry["tags"], "image_size": entry["image_size"]}) + "\n"
)
else:
# standard JSON metadata
with open(args.output_path, "wt", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=4)
logger.info(f"captions saved to {args.output_path}")
if args.frequency_tags:
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
@@ -381,9 +601,7 @@ def main(args):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
)
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument(
"--repo_id",
type=str,
@@ -401,15 +619,19 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
)
parser.add_argument(
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
type=int,
default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化",
)
parser.add_argument(
"--output_path",
type=str,
default=None,
help="path for output captions (json format). if this is set, captions will be saved to this file / 出力キャプションのパスjson形式。このオプションが設定されている場合、キャプションはこのファイルに保存されます",
)
parser.add_argument(
"--caption_extention",
type=str,
@@ -432,7 +654,36 @@ def setup_parser() -> argparse.ArgumentParser:
"--character_threshold",
type=float,
default=None,
help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ",
help="threshold of confidence to add a tag for character category, same as --thres if omitted. set above 1 to disable character tags"
" / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとcharacterタグを無効化できる",
)
parser.add_argument(
"--meta_threshold",
type=float,
default=None,
help="threshold of confidence to add a tag for meta category, same as --thresh if omitted. set above 1 to disable meta tags"
" / metaカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとmetaタグを無効化できる",
)
parser.add_argument(
"--model_threshold",
type=float,
default=None,
help="threshold of confidence to add a tag for model category, same as --thresh if omitted. set above 1 to disable model tags"
" / modelカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとmodelタグを無効化できる",
)
parser.add_argument(
"--copyright_threshold",
type=float,
default=None,
help="threshold of confidence to add a tag for copyright category, same as --thresh if omitted. set above 1 to disable copyright tags"
" / copyrightカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとcopyrightタグを無効化できる",
)
parser.add_argument(
"--artist_threshold",
type=float,
default=None,
help="threshold of confidence to add a tag for artist category, same as --thresh if omitted. set above 1 to disable artist tags"
" / artistカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとartistタグを無効化できる",
)
parser.add_argument(
"--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する"
@@ -442,9 +693,7 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
)
parser.add_argument(
"--debug", action="store_true", help="debug mode"
)
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument(
"--undesired_tags",
type=str,
@@ -454,20 +703,34 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する"
)
parser.add_argument(
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
)
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument(
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
)
parser.add_argument(
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
"--use_rating_tags",
action="store_true",
help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
)
parser.add_argument(
"--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
"--use_rating_tags_as_last_tag",
action="store_true",
help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
)
parser.add_argument(
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
"--use_quality_tags",
action="store_true",
help="Adds quality tags as the first tag / クオリティタグを最初のタグとして追加する",
)
parser.add_argument(
"--use_quality_tags_as_last_tag",
action="store_true",
help="Adds quality tags as the last tag / クオリティタグを最後のタグとして追加する",
)
parser.add_argument(
"--character_tags_first",
action="store_true",
help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
)
parser.add_argument(
"--always_first_tags",
@@ -512,5 +775,13 @@ if __name__ == "__main__":
args.general_threshold = args.thresh
if args.character_threshold is None:
args.character_threshold = args.thresh
if args.meta_threshold is None:
args.meta_threshold = args.thresh
if args.model_threshold is None:
args.model_threshold = args.thresh
if args.copyright_threshold is None:
args.copyright_threshold = args.thresh
if args.artist_threshold is None:
args.artist_threshold = args.thresh
main(args)

View File

@@ -456,13 +456,13 @@ if __name__ == "__main__":
# load clip_l (skip for chroma model)
if args.model_type == "flux":
logger.info(f"Loading clip_l from {args.clip_l}...")
clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device, disable_mmap=True)
clip_l.eval()
else:
clip_l = None
logger.info(f"Loading t5xxl from {args.t5xxl}...")
t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device, disable_mmap=True)
t5xxl.eval()
# if is_fp8(clip_l_dtype):
@@ -471,7 +471,9 @@ if __name__ == "__main__":
# t5xxl = accelerator.prepare(t5xxl)
# DiT
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type)
is_schnell, model = flux_utils.load_flow_model(
args.ckpt_path, None, loading_device, disable_mmap=True, model_type=args.model_type
)
model.eval()
logger.info(f"Casting model to {flux_dtype}")
model.to(flux_dtype) # make sure model is dtype

View File

@@ -1,5 +1,6 @@
import itertools
import json
from types import SimpleNamespace
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob
import importlib
@@ -20,7 +21,8 @@ import diffusers
import numpy as np
import torch
from library.device_utils import init_ipex, clean_memory, get_preferred_device
from library.device_utils import init_ipex
from library.strategy_sd import SdTokenizeStrategy
init_ipex()
@@ -60,6 +62,7 @@ from library.original_unet import UNet2DConditionModel, InferUNet2DConditionMode
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from library.sdxl_original_control_net import SdxlControlNet
from library.original_unet import FlashAttentionFunction
from library.custom_train_functions import pyramid_noise_like
from networks.control_net_lllite import ControlNetLLLite
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
from library.utils import setup_logging, add_logging_arguments
@@ -434,6 +437,7 @@ class PipelineLike:
img2img_noise=None,
clip_guide_images=None,
emb_normalize_mode: str = "original",
force_scheduler_zero_steps_offset: bool = False,
**kwargs,
):
# TODO support secondary prompt
@@ -707,7 +711,10 @@ class PipelineLike:
raise ValueError("The mask and init_image should be the same size!")
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
if force_scheduler_zero_steps_offset:
offset = 0
else:
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
@@ -859,7 +866,7 @@ class PipelineLike:
)
input_resi_add = input_resi_add_mean
mid_add = torch.mean(torch.stack(mid_add_list), dim=0)
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings, input_resi_add, mid_add)
elif self.is_sdxl:
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
@@ -1362,97 +1369,177 @@ def preprocess_mask(mask):
RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}")
def handle_dynamic_prompt_variants(prompt, repeat_count):
def handle_dynamic_prompt_variants(prompt, repeat_count, seed_random, seeds=None):
founds = list(RE_DYNAMIC_PROMPT.finditer(prompt))
if not founds:
return [prompt]
return [prompt], seeds
# make each replacement for each variant
enumerating = False
replacers = []
for found in founds:
# if "e$$" is found, enumerate all variants
found_enumerating = found.group(2) is not None
enumerating = enumerating or found_enumerating
# Prepare seeds list
if seeds is None:
seeds = []
while len(seeds) < repeat_count:
seeds.append(seed_random.randint(0, 2**32 - 1))
separator = ", " if found.group(6) is None else found.group(6)
variants = found.group(7).split("|")
# Escape braces
prompt = prompt.replace(r"\{", "").replace(r"\}", "")
# parse count range
count_range = found.group(4)
if count_range is None:
count_range = [1, 1]
else:
count_range = count_range.split("-")
if len(count_range) == 1:
count_range = [int(count_range[0]), int(count_range[0])]
elif len(count_range) == 2:
count_range = [int(count_range[0]), int(count_range[1])]
# Process nested dynamic prompts recursively
prompts = [prompt] * repeat_count
has_dynamic = True
while has_dynamic:
has_dynamic = False
new_prompts = []
for i, prompt in enumerate(prompts):
seed = seeds[i] if i < len(seeds) else seeds[0] # if enumerating, use the first seed
# find innermost dynamic prompts
# find outer dynamic prompt and temporarily replace them with placeholders
deepest_nest_level = 0
nest_level = 0
for c in prompt:
if c == "{":
nest_level += 1
deepest_nest_level = max(deepest_nest_level, nest_level)
elif c == "}":
nest_level -= 1
if deepest_nest_level == 0:
new_prompts.append(prompt)
continue # no more dynamic prompts
# find positions of innermost dynamic prompts
positions = []
nest_level = 0
start_pos = -1
for i, c in enumerate(prompt):
if c == "{":
nest_level += 1
if nest_level == deepest_nest_level:
start_pos = i
elif c == "}":
if nest_level == deepest_nest_level:
end_pos = i + 1
positions.append((start_pos, end_pos))
nest_level -= 1
# extract innermost dynamic prompts
innermost_founds = []
for start, end in positions:
segment = prompt[start:end]
m = RE_DYNAMIC_PROMPT.match(segment)
if m:
innermost_founds.append((m, start, end))
if not innermost_founds:
new_prompts.append(prompt)
continue
has_dynamic = True
# make each replacement for each variant
enumerating = False
replacers = []
for found, start, end in innermost_founds:
# if "e$$" is found, enumerate all variants
found_enumerating = found.group(2) is not None
enumerating = enumerating or found_enumerating
separator = ", " if found.group(6) is None else found.group(6)
variants = found.group(7).split("|")
# parse count range
count_range = found.group(4)
if count_range is None:
count_range = [1, 1]
else:
count_range = count_range.split("-")
if len(count_range) == 1:
count_range = [int(count_range[0]), int(count_range[0])]
elif len(count_range) == 2:
count_range = [int(count_range[0]), int(count_range[1])]
else:
logger.warning(f"invalid count range: {count_range}")
count_range = [1, 1]
if count_range[0] > count_range[1]:
count_range = [count_range[1], count_range[0]]
if count_range[0] < 0:
count_range[0] = 0
if count_range[1] > len(variants):
count_range[1] = len(variants)
if found_enumerating:
# make function to enumerate all combinations
def make_replacer_enum(vari, cr, sep):
def replacer(rnd=random):
values = []
for count in range(cr[0], cr[1] + 1):
for comb in itertools.combinations(vari, count):
values.append(sep.join(comb))
return values
return replacer
replacers.append(make_replacer_enum(variants, count_range, separator))
else:
# make function to choose random combinations
def make_replacer_single(vari, cr, sep):
def replacer(rnd=random):
count = rnd.randint(cr[0], cr[1])
comb = rnd.sample(vari, count)
return [sep.join(comb)]
return replacer
replacers.append(make_replacer_single(variants, count_range, separator))
# make each prompt
rnd = random.Random(seed)
if not enumerating:
# if not enumerating, repeat the prompt, replace each variant randomly
# reverse the lists to replace from end to start, keep positions correct
innermost_founds.reverse()
replacers.reverse()
current = prompt
for (found, start, end), replacer in zip(innermost_founds, replacers):
current = current[:start] + replacer(rnd)[0] + current[end:]
new_prompts.append(current)
else:
logger.warning(f"invalid count range: {count_range}")
count_range = [1, 1]
if count_range[0] > count_range[1]:
count_range = [count_range[1], count_range[0]]
if count_range[0] < 0:
count_range[0] = 0
if count_range[1] > len(variants):
count_range[1] = len(variants)
# if enumerating, iterate all combinations for previous prompts, all seeds are same
processing_prompts = [prompt]
for found, replacer in zip(founds, replacers):
if found.group(2) is not None:
# make all combinations for existing prompts
repleced_prompts = []
for current in processing_prompts:
replacements = replacer(rnd)
for replacement in replacements:
repleced_prompts.append(
current.replace(found.group(0), replacement, 1)
) # This does not work if found is duplicated
processing_prompts = repleced_prompts
if found_enumerating:
# make function to enumerate all combinations
def make_replacer_enum(vari, cr, sep):
def replacer():
values = []
for count in range(cr[0], cr[1] + 1):
for comb in itertools.combinations(vari, count):
values.append(sep.join(comb))
return values
for found, replacer in zip(founds, replacers):
# make random selection for existing prompts
if found.group(2) is None:
for i in range(len(processing_prompts)):
processing_prompts[i] = processing_prompts[i].replace(found.group(0), replacer(rnd)[0], 1)
return replacer
new_prompts.extend(processing_prompts)
replacers.append(make_replacer_enum(variants, count_range, separator))
else:
# make function to choose random combinations
def make_replacer_single(vari, cr, sep):
def replacer():
count = random.randint(cr[0], cr[1])
comb = random.sample(vari, count)
return [sep.join(comb)]
prompts = new_prompts
return replacer
# Restore escaped braces
for i in range(len(prompts)):
prompts[i] = prompts[i].replace("", "{").replace("", "}")
if enumerating:
# adjust seeds list
new_seeds = []
for _ in range(len(prompts)):
new_seeds.append(seeds[0]) # use the first seed for all
seeds = new_seeds
replacers.append(make_replacer_single(variants, count_range, separator))
# make each prompt
if not enumerating:
# if not enumerating, repeat the prompt, replace each variant randomly
prompts = []
for _ in range(repeat_count):
current = prompt
for found, replacer in zip(founds, replacers):
current = current.replace(found.group(0), replacer()[0], 1)
prompts.append(current)
else:
# if enumerating, iterate all combinations for previous prompts
prompts = [prompt]
for found, replacer in zip(founds, replacers):
if found.group(2) is not None:
# make all combinations for existing prompts
new_prompts = []
for current in prompts:
replecements = replacer()
for replecement in replecements:
new_prompts.append(current.replace(found.group(0), replecement, 1))
prompts = new_prompts
for found, replacer in zip(founds, replacers):
# make random selection for existing prompts
if found.group(2) is None:
for i in range(len(prompts)):
prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1)
return prompts
return prompts, seeds
# endregion
@@ -1612,7 +1699,8 @@ def main(args):
tokenizers = [tokenizer1, tokenizer2]
else:
if use_stable_diffusion_format:
tokenizer = train_util.load_tokenizer(args)
tokenize_strategy = SdTokenizeStrategy(args.v2, max_length=None, tokenizer_cache_dir=args.tokenizer_cache_dir)
tokenizer = tokenize_strategy.tokenizer
tokenizers = [tokenizer]
# schedulerを用意する
@@ -1719,6 +1807,9 @@ def main(args):
if scheduler_module is not None:
scheduler_module.torch = TorchRandReplacer(noise_manager)
if args.zero_terminal_snr:
sched_init_args["rescale_betas_zero_snr"] = True
scheduler = scheduler_cls(
num_train_timesteps=SCHEDULER_TIMESTEPS,
beta_start=SCHEDULER_LINEAR_START,
@@ -1727,6 +1818,9 @@ def main(args):
**sched_init_args,
)
# if args.zero_terminal_snr:
# custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(scheduler)
# ↓以下は結局PipeでFalseに設定されるので意味がなかった
# # clip_sample=Trueにする
# if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
@@ -1868,7 +1962,7 @@ def main(args):
if not is_sdxl:
for i, model in enumerate(args.control_net_models):
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
weight = 1.0 if not args.control_net_multipliers or len(args.control_net_multipliers) <= i else args.control_net_multipliers[i]
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
@@ -2355,7 +2449,9 @@ def main(args):
if images_1st.dtype == torch.bfloat16:
images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
images_1st = torch.nn.functional.interpolate(
images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear"
images_1st,
(batch[0].ext.height // 8, batch[0].ext.width // 8),
mode="bicubic",
) # , antialias=True)
images_1st = images_1st.to(org_dtype)
@@ -2464,6 +2560,20 @@ def main(args):
torch.manual_seed(seed)
start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype)
# pyramid noise
if args.pyramid_noise_prob is not None and random.random() < args.pyramid_noise_prob:
min_discount, max_discount = args.pyramid_noise_discount_range
discount = torch.rand(1, device=device, dtype=dtype) * (max_discount - min_discount) + min_discount
logger.info(f"apply pyramid noise to start code: {start_code[i].shape}, discount: {discount.item()}")
start_code[i] = pyramid_noise_like(start_code[i].unsqueeze(0), device=device, discount=discount).squeeze(0)
# noise offset
if args.noise_offset_prob is not None and random.random() < args.noise_offset_prob:
min_offset, max_offset = args.noise_offset_range
noise_offset = torch.randn(1, device=device, dtype=dtype) * (max_offset - min_offset) + min_offset
logger.info(f"apply noise offset to start code: {start_code[i].shape}, offset: {noise_offset.item()}")
start_code[i] += noise_offset
# make each noises
for j in range(steps * scheduler_num_noises_per_step):
noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype)
@@ -2532,6 +2642,7 @@ def main(args):
clip_prompts=clip_prompts,
clip_guide_images=guide_images,
emb_normalize_mode=args.emb_normalize_mode,
force_scheduler_zero_steps_offset=args.force_scheduler_zero_steps_offset,
)
if highres_1st and not args.highres_fix_save_1st: # return images or latents
return images
@@ -2624,7 +2735,16 @@ def main(args):
# sd-dynamic-prompts like variants:
# count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration)
raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt)
seeds = None
m = re.search(r" --d ([\d+,]+)", raw_prompt, re.IGNORECASE)
if m:
seeds = [int(d) for d in m[0][5:].split(",")]
logger.info(f"seeds: {seeds}")
raw_prompt = raw_prompt[: m.start()] + raw_prompt[m.end() :]
raw_prompts, prompt_seeds = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt, seed_random, seeds)
if prompt_seeds is not None:
seeds = prompt_seeds
# repeat prompt
for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
@@ -2644,8 +2764,8 @@ def main(args):
scale = args.scale
negative_scale = args.negative_scale
steps = args.steps
seed = None
seeds = None
# seed = None
# seeds = None
strength = 0.8 if args.strength is None else args.strength
negative_prompt = ""
clip_prompt = None
@@ -2727,11 +2847,11 @@ def main(args):
logger.info(f"steps: {steps}")
continue
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
if m: # seed
seeds = [int(d) for d in m.group(1).split(",")]
logger.info(f"seeds: {seeds}")
continue
# m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
# if m: # seed
# seeds = [int(d) for d in m.group(1).split(",")]
# logger.info(f"seeds: {seeds}")
# continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
@@ -3012,6 +3132,27 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする"
)
parser.add_argument(
"--zero_terminal_snr",
action="store_true",
help="fix noise scheduler betas to enforce zero terminal SNR / noise schedulerのbetasを修正して、zero terminal SNRを強制する",
)
parser.add_argument(
"--pyramid_noise_prob", type=float, default=None, help="probability for pyramid noise / ピラミッドノイズの確率"
)
parser.add_argument(
"--pyramid_noise_discount_range",
type=float,
nargs=2,
default=None,
help="discount range for pyramid noise / ピラミッドノイズの割引範囲",
)
parser.add_argument(
"--noise_offset_prob", type=float, default=None, help="probability for noise offset / ノイズオフセットの確率"
)
parser.add_argument(
"--noise_offset_range", type=float, nargs=2, default=None, help="range for noise offset / ノイズオフセットの範囲"
)
parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト")
parser.add_argument(
@@ -3250,6 +3391,12 @@ def setup_parser() -> argparse.ArgumentParser:
choices=["original", "none", "abs"],
help="embedding normalization mode / embeddingの正規化モード",
)
parser.add_argument(
"--force_scheduler_zero_steps_offset",
action="store_true",
help="force scheduler steps offset to zero"
+ " / スケジューラのステップオフセットをスケジューラ設定の `steps_offset` の値に関わらず強制的にゼロにする",
)
parser.add_argument(
"--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像"
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,717 @@
import argparse
import copy
import gc
from typing import Any, Optional, Union, cast
import os
import time
from types import SimpleNamespace
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from accelerate import Accelerator, PartialState
from library import flux_utils, hunyuan_image_models, hunyuan_image_vae, strategy_base, train_util
from library.device_utils import clean_memory_on_device, init_ipex
init_ipex()
import train_network
from library import (
flux_train_utils,
hunyuan_image_models,
hunyuan_image_text_encoder,
hunyuan_image_utils,
hunyuan_image_vae,
sd3_train_utils,
strategy_base,
strategy_hunyuan_image,
train_util,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# region sampling
# TODO commonize with flux_utils
def sample_images(
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
dit: hunyuan_image_models.HYImageDiffusionTransformer,
vae,
text_encoders,
sample_prompts_te_outputs,
prompt_replacement=None,
):
if steps == 0:
if not args.sample_at_first:
return
else:
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps は無視する
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return
logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
# unwrap unet and text_encoder(s)
dit = accelerator.unwrap_model(dit)
dit = cast(hunyuan_image_models.HYImageDiffusionTransformer, dit)
dit.switch_block_swap_for_inference()
if text_encoders is not None:
text_encoders = [(accelerator.unwrap_model(te) if te is not None else None) for te in text_encoders]
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
prompts = train_util.load_prompts(args.sample_prompts)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
# save random state to restore later
rng_state = torch.get_rng_state()
cuda_rng_state = None
try:
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
except Exception:
pass
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
sample_image_inference(
accelerator,
args,
dit,
text_encoders,
vae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
per_process_prompts = [] # list of lists
for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
accelerator,
args,
dit,
text_encoders,
vae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
)
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
dit.switch_block_swap_for_training()
clean_memory_on_device(accelerator.device)
def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
dit: hunyuan_image_models.HYImageDiffusionTransformer,
text_encoders: Optional[list[nn.Module]],
vae: hunyuan_image_vae.HunyuanVAE2D,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
):
assert isinstance(prompt_dict, dict)
negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 20)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
cfg_scale = prompt_dict.get("scale", 3.5)
seed = prompt_dict.get("seed")
prompt: str = prompt_dict.get("prompt", "")
flow_shift: float = prompt_dict.get("flow_shift", 5.0)
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
else:
# True random sample image generation
torch.seed()
torch.cuda.seed()
if negative_prompt is None:
negative_prompt = ""
height = max(64, height - height % 16) # round to divisible by 16
width = max(64, width - width % 16) # round to divisible by 16
logger.info(f"prompt: {prompt}")
if cfg_scale != 1.0:
logger.info(f"negative_prompt: {negative_prompt}")
elif negative_prompt != "":
logger.info(f"negative prompt is ignored because scale is 1.0")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
if cfg_scale != 1.0:
logger.info(f"CFG scale: {cfg_scale}")
logger.info(f"flow_shift: {flow_shift}")
# logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
# encode prompts
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
def encode_prompt(prpt):
text_encoder_conds = []
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
text_encoder_conds = sample_prompts_te_outputs[prpt]
# print(f"Using cached text encoder outputs for prompt: {prpt}")
if text_encoders is not None:
# print(f"Encoding prompt: {prpt}")
tokens_and_masks = tokenize_strategy.tokenize(prpt)
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
return text_encoder_conds
vl_embed, vl_mask, byt5_embed, byt5_mask, ocr_mask = encode_prompt(prompt)
arg_c = {
"embed": vl_embed,
"mask": vl_mask,
"embed_byt5": byt5_embed,
"mask_byt5": byt5_mask,
"ocr_mask": ocr_mask,
"prompt": prompt,
}
# encode negative prompts
if cfg_scale != 1.0:
neg_vl_embed, neg_vl_mask, neg_byt5_embed, neg_byt5_mask, neg_ocr_mask = encode_prompt(negative_prompt)
arg_c_null = {
"embed": neg_vl_embed,
"mask": neg_vl_mask,
"embed_byt5": neg_byt5_embed,
"mask_byt5": neg_byt5_mask,
"ocr_mask": neg_ocr_mask,
"prompt": negative_prompt,
}
else:
arg_c_null = None
gen_args = SimpleNamespace(
image_size=(height, width),
infer_steps=sample_steps,
flow_shift=flow_shift,
guidance_scale=cfg_scale,
fp8=args.fp8_scaled,
apg_start_step_ocr=38,
apg_start_step_general=5,
guidance_rescale=0.0,
guidance_rescale_apg=0.0,
)
from hunyuan_image_minimal_inference import generate_body # import here to avoid circular import
dit_is_training = dit.training
dit.eval()
x = generate_body(gen_args, dit, arg_c, arg_c_null, accelerator.device, seed)
if dit_is_training:
dit.train()
clean_memory_on_device(accelerator.device)
# latent to image
org_vae_device = vae.device # will be on cpu
vae.to(accelerator.device) # distributed_state.device is same as accelerator.device
with torch.no_grad():
x = x / vae.scaling_factor
x = vae.decode(x.to(vae.device, dtype=vae.dtype))
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
x = x.clamp(-1, 1)
x = x.permute(0, 2, 3, 1)
image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
# but adding 'enum' to the filename should be enough
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i: int = prompt_dict["enum"]
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# send images to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
# endregion
class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
self.is_swapping_blocks: bool = False
self.rotary_pos_emb_cache = {}
def assert_extra_args(
self,
args,
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)
if args.mixed_precision == "fp16":
logger.warning(
"mixed_precision bf16 is recommended for HunyuanImage-2.1 / HunyuanImage-2.1ではmixed_precision bf16が推奨されます"
)
if (args.fp8_base or args.fp8_base_unet) and not args.fp8_scaled:
logger.warning(
"fp8_base and fp8_base_unet are not supported. Use fp8_scaled instead / fp8_baseとfp8_base_unetはサポートされていません。代わりにfp8_scaledを使用してください"
)
if args.fp8_scaled and (args.fp8_base or args.fp8_base_unet):
logger.info(
"fp8_scaled is used, so fp8_base and fp8_base_unet are ignored / fp8_scaledが使われているので、fp8_baseとfp8_base_unetは無視されます"
)
args.fp8_base = False
args.fp8_base_unet = False
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
)
args.cache_text_encoder_outputs = True
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
train_dataset_group.verify_bucket_reso_steps(32)
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32)
def load_target_model(self, args, weight_dtype, accelerator):
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
vl_dtype = torch.float8_e4m3fn if args.fp8_vl else torch.bfloat16
vl_device = "cpu" # loading to cpu and move to gpu later in cache_text_encoder_outputs_if_needed
_, text_encoder_vlm = hunyuan_image_text_encoder.load_qwen2_5_vl(
args.text_encoder, dtype=vl_dtype, device=vl_device, disable_mmap=args.disable_mmap_load_safetensors
)
_, text_encoder_byt5 = hunyuan_image_text_encoder.load_byt5(
args.byt5, dtype=torch.float16, device=vl_device, disable_mmap=args.disable_mmap_load_safetensors
)
vae = hunyuan_image_vae.load_vae(
args.vae, "cpu", disable_mmap=args.disable_mmap_load_safetensors, chunk_size=args.vae_chunk_size
)
vae.to(dtype=torch.float16) # VAE is always fp16
vae.eval()
model_version = hunyuan_image_utils.MODEL_VERSION_2_1
return model_version, [text_encoder_vlm, text_encoder_byt5], vae, None # unet will be loaded later
def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, list[nn.Module]]:
if args.cache_text_encoder_outputs:
logger.info("Replace text encoders with dummy models to save memory")
# This doesn't free memory, so we move text encoders to meta device in cache_text_encoder_outputs_if_needed
text_encoders = [flux_utils.dummy_clip_l() for _ in text_encoders]
clean_memory_on_device(accelerator.device)
gc.collect()
loading_dtype = None if args.fp8_scaled else weight_dtype
loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
attn_mode = "torch"
if args.xformers:
attn_mode = "xformers"
if args.attn_mode is not None:
attn_mode = args.attn_mode
logger.info(f"Loading DiT model with attn_mode: {attn_mode}, split_attn: {args.split_attn}, fp8_scaled: {args.fp8_scaled}")
model = hunyuan_image_models.load_hunyuan_image_model(
accelerator.device,
args.pretrained_model_name_or_path,
attn_mode,
args.split_attn,
loading_device,
loading_dtype,
args.fp8_scaled,
)
if self.is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
model.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True)
return model, text_encoders
def get_tokenize_strategy(self, args):
return strategy_hunyuan_image.HunyuanImageTokenizeStrategy(args.tokenizer_cache_dir)
def get_tokenizers(self, tokenize_strategy: strategy_hunyuan_image.HunyuanImageTokenizeStrategy):
return [tokenize_strategy.vlm_tokenizer, tokenize_strategy.byt5_tokenizer]
def get_latents_caching_strategy(self, args):
return strategy_hunyuan_image.HunyuanImageLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
def get_text_encoding_strategy(self, args):
return strategy_hunyuan_image.HunyuanImageTextEncodingStrategy()
def post_process_network(self, args, accelerator, network, text_encoders, unet):
pass
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
if args.cache_text_encoder_outputs:
return None # no text encoders are needed for encoding because both are cached
else:
return text_encoders
def get_text_encoders_train_flags(self, args, text_encoders):
# HunyuanImage-2.1 does not support training VLM or byT5
return [False, False]
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
return strategy_hunyuan_image.HunyuanImageTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
)
else:
return None
def cache_text_encoder_outputs_if_needed(
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
):
vlm_device = "cpu" if args.text_encoder_cpu else accelerator.device
if args.cache_text_encoder_outputs:
if not args.lowram:
# メモリ消費を減らす
logger.info("move vae to cpu to save memory")
org_vae_device = vae.device
vae.to("cpu")
clean_memory_on_device(accelerator.device)
logger.info(f"move text encoders to {vlm_device} to encode and cache text encoder outputs")
text_encoders[0].to(vlm_device)
text_encoders[1].to(vlm_device)
# VLM (bf16) and byT5 (fp16) are used for encoding, so we cannot use autocast here
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
# cache sample prompts
if args.sample_prompts is not None:
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
tokenize_strategy: strategy_hunyuan_image.HunyuanImageTokenizeStrategy = (
strategy_base.TokenizeStrategy.get_strategy()
)
text_encoding_strategy: strategy_hunyuan_image.HunyuanImageTextEncodingStrategy = (
strategy_base.TextEncodingStrategy.get_strategy()
)
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
if p not in sample_prompts_te_outputs:
logger.info(f"cache Text Encoder outputs for prompt: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy, text_encoders, tokens_and_masks
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
accelerator.wait_for_everyone()
# text encoders are not needed for training, so we move to meta device
logger.info("move text encoders to meta device to save memory")
text_encoders = [te.to("meta") for te in text_encoders]
clean_memory_on_device(accelerator.device)
if not args.lowram:
logger.info("move vae back to original device")
vae.to(org_vae_device)
else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
text_encoders[0].to(vlm_device)
text_encoders[1].to(vlm_device)
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
text_encoders = text_encoder # for compatibility
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
sample_images(accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs)
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
return noise_scheduler
def encode_images_to_latents(self, args, vae: hunyuan_image_vae.HunyuanVAE2D, images):
return vae.encode(images).sample()
def shift_scale_latents(self, args, latents):
# for encoding, we need to scale the latents
return latents * hunyuan_image_vae.LATENT_SCALING_FACTOR
def get_noise_pred_and_target(
self,
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet: hunyuan_image_models.HYImageDiffusionTransformer,
network,
weight_dtype,
train_unet,
is_train=True,
):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
# get noisy model input and timesteps
noisy_model_input, _, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
)
# bfloat16 is too low precision for 0-1000 TODO fix get_noisy_model_input_and_timesteps
timesteps = (sigmas[:, 0, 0, 0] * 1000).to(torch.int64)
# print(
# f"timestep: {timesteps}, noisy_model_input shape: {noisy_model_input.shape}, mean: {noisy_model_input.mean()}, std: {noisy_model_input.std()}"
# )
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)
for t in text_encoder_conds:
if t is not None and t.dtype.is_floating_point:
t.requires_grad_(True)
# Predict the noise residual
# ocr_mask is for inference only, so it is not used here
vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask = text_encoder_conds
# print(f"embed shape: {vlm_embed.shape}, mean: {vlm_embed.mean()}, std: {vlm_embed.std()}")
# print(f"embed_byt5 shape: {byt5_embed.shape}, mean: {byt5_embed.mean()}, std: {byt5_embed.std()}")
# print(f"latents shape: {latents.shape}, mean: {latents.mean()}, std: {latents.std()}")
# print(f"mask shape: {vlm_mask.shape}, sum: {vlm_mask.sum()}")
# print(f"mask_byt5 shape: {byt5_mask.shape}, sum: {byt5_mask.sum()}")
with torch.set_grad_enabled(is_train), accelerator.autocast():
model_pred = unet(
noisy_model_input, timesteps, vlm_embed, vlm_mask, byt5_embed, byt5_mask # , self.rotary_pos_emb_cache
)
# apply model prediction type
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
# flow matching loss
target = noise - latents
# differential output preservation is not used for HunyuanImage-2.1 currently
return model_pred, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
def get_sai_model_spec(self, args):
return train_util.get_sai_model_spec_dataclass(None, args, False, True, False, hunyuan_image="2.1").to_metadata_dict()
def update_metadata(self, metadata, args):
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
metadata["ss_mode_scale"] = args.mode_scale
metadata["ss_timestep_sampling"] = args.timestep_sampling
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
metadata["ss_model_prediction_type"] = args.model_prediction_type
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
# do not support text encoder training for HunyuanImage-2.1
pass
def cast_text_encoder(self, args):
return False # VLM is bf16, byT5 is fp16, so do not cast to other dtype
def cast_vae(self, args):
return False # VAE is fp16, so do not cast to other dtype
def cast_unet(self, args):
return not args.fp8_scaled # if fp8_scaled is used, do not cast to other dtype
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
# fp8 text encoder for HunyuanImage-2.1 is not supported currently
pass
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
# prepare for next forward: because backward pass is not called, we need to prepare it here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
if not self.is_swapping_blocks:
return super().prepare_unet_with_accelerator(args, accelerator, unet)
# if we doesn't swap blocks, we can move the model to device
model: hunyuan_image_models.HYImageDiffusionTransformer = unet
model = accelerator.prepare(model, device_placement=[not self.is_swapping_blocks])
accelerator.unwrap_model(model).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
accelerator.unwrap_model(model).prepare_block_swap_before_forward()
return model
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
parser.add_argument(
"--text_encoder",
type=str,
help="path to Qwen2.5-VL (*.sft or *.safetensors), should be bfloat16 / Qwen2.5-VLのパス*.sftまたは*.safetensors、bfloat16が前提",
)
parser.add_argument(
"--byt5",
type=str,
help="path to byt5 (*.sft or *.safetensors), should be float16 / byt5のパス*.sftまたは*.safetensors、float16が前提",
)
parser.add_argument(
"--timestep_sampling",
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
default="sigma",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
" / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
)
parser.add_argument(
"--sigmoid_scale",
type=float,
default=1.0,
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率timestep-samplingが"sigmoid"の場合のみ有効)。',
)
parser.add_argument(
"--model_prediction_type",
choices=["raw", "additive", "sigma_scaled"],
default="raw",
help="How to interpret and process the model prediction: "
"raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling). Default is raw unlike FLUX.1."
" / モデル予測の解釈と処理方法:"
"rawそのまま使用、additiveイズ入力に加算、sigma_scaledシグマスケーリングを適用。デフォルトはFLUX.1とは異なりrawです。",
)
parser.add_argument(
"--discrete_flow_shift",
type=float,
default=5.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 5.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは5.0。",
)
parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
parser.add_argument("--fp8_vl", action="store_true", help="Use fp8 for VLM text encoder / VLMテキストエンコーダにfp8を使用する")
parser.add_argument(
"--text_encoder_cpu", action="store_true", help="Inference on CPU for Text Encoders / テキストエンコーダをCPUで推論する"
)
parser.add_argument(
"--vae_chunk_size",
type=int,
default=None, # default is None (no chunking)
help="Chunk size for VAE decoding to reduce memory usage. Default is None (no chunking). 16 is recommended if enabled"
" / メモリ使用量を減らすためのVAEデコードのチャンクサイズ。デフォルトはNoneチャンクなし。有効にする場合は16程度を推奨。",
)
parser.add_argument(
"--attn_mode",
choices=["torch", "xformers", "flash", "sageattn", "sdpa"], # "sdpa" is for backward compatibility
default=None,
help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. sageattn does not support training (inference only). This option overrides --xformers or --sdpa."
" / 使用するAttentionの実装。デフォルトはNonetorchです。xformersは--split_attnの指定が必要です。sageattnはトレーニングをサポートしていません推論のみ。このオプションは--xformersまたは--sdpaを上書きします。",
)
parser.add_argument(
"--split_attn",
action="store_true",
help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
if args.attn_mode == "sdpa":
args.attn_mode = "torch" # backward compatibility
trainer = HunyuanImageNetworkTrainer()
trainer.train(args)

260
library/attention.py Normal file
View File

@@ -0,0 +1,260 @@
# Unified attention function supporting various implementations
from dataclasses import dataclass
import torch
from typing import Optional, Union
try:
import flash_attn
from flash_attn.flash_attn_interface import _flash_attn_forward
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from flash_attn.flash_attn_interface import flash_attn_func
except ImportError:
flash_attn = None
flash_attn_varlen_func = None
_flash_attn_forward = None
flash_attn_func = None
try:
from sageattention import sageattn_varlen, sageattn
except ImportError:
sageattn_varlen = None
sageattn = None
try:
import xformers.ops as xops
except ImportError:
xops = None
@dataclass
class AttentionParams:
attn_mode: Optional[str] = None
split_attn: bool = False
img_len: Optional[int] = None
attention_mask: Optional[torch.Tensor] = None
seqlens: Optional[torch.Tensor] = None
cu_seqlens: Optional[torch.Tensor] = None
max_seqlen: Optional[int] = None
@staticmethod
def create_attention_params(attn_mode: Optional[str], split_attn: bool) -> "AttentionParams":
return AttentionParams(attn_mode, split_attn)
@staticmethod
def create_attention_params_from_mask(
attn_mode: Optional[str], split_attn: bool, img_len: Optional[int], attention_mask: Optional[torch.Tensor]
) -> "AttentionParams":
if attention_mask is None:
# No attention mask provided: assume all tokens are valid
return AttentionParams(attn_mode, split_attn, None, None, None, None, None)
else:
# Note: attention_mask is only for text tokens, not including image tokens
seqlens = attention_mask.sum(dim=1).to(torch.int32) + img_len # [B]
max_seqlen = attention_mask.shape[1] + img_len
if split_attn:
# cu_seqlens is not needed for split attention
return AttentionParams(attn_mode, split_attn, img_len, attention_mask, seqlens, None, max_seqlen)
# Convert attention mask to cumulative sequence lengths for flash attention
batch_size = attention_mask.shape[0]
cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=attention_mask.device)
for i in range(batch_size):
cu_seqlens[2 * i + 1] = i * max_seqlen + seqlens[i] # end of valid tokens for query
cu_seqlens[2 * i + 2] = (i + 1) * max_seqlen # end of all tokens for query
# Expand attention mask to include image tokens
attention_mask = torch.nn.functional.pad(attention_mask, (img_len, 0), value=1) # [B, img_len + L]
if attn_mode == "xformers":
seqlens_list = seqlens.cpu().tolist()
attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
seqlens_list, seqlens_list, device=attention_mask.device
)
elif attn_mode == "torch":
attention_mask = attention_mask[:, None, None, :].to(torch.bool) # [B, 1, 1, img_len + L]
return AttentionParams(attn_mode, split_attn, img_len, attention_mask, seqlens, cu_seqlens, max_seqlen)
def attention(
qkv_or_q: Union[torch.Tensor, list],
k: Optional[torch.Tensor] = None,
v: Optional[torch.Tensor] = None,
attn_params: Optional[AttentionParams] = None,
drop_rate: float = 0.0,
) -> torch.Tensor:
"""
Compute scaled dot-product attention with variable sequence lengths.
Handles batches with different sequence lengths by splitting and
processing each sequence individually.
Args:
qkv_or_q: Query tensor [B, L, H, D]. or list of such tensors.
k: Key tensor [B, L, H, D].
v: Value tensor [B, L, H, D].
attn_param: Attention parameters including mask and sequence lengths.
drop_rate: Attention dropout rate.
Returns:
Attention output tensor [B, L, H*D].
"""
if isinstance(qkv_or_q, list):
q, k, v = qkv_or_q
q: torch.Tensor = q
qkv_or_q.clear()
del qkv_or_q
else:
q: torch.Tensor = qkv_or_q
del qkv_or_q
assert k is not None and v is not None, "k and v must be provided if qkv_or_q is a tensor"
if attn_params is None:
attn_params = AttentionParams.create_attention_params("torch", False)
# If split attn is False, attention mask is provided and all sequence lengths are same, we can trim the sequence
seqlen_trimmed = False
if not attn_params.split_attn and attn_params.attention_mask is not None and attn_params.seqlens is not None:
if torch.all(attn_params.seqlens == attn_params.seqlens[0]):
seqlen = attn_params.seqlens[0].item()
q = q[:, :seqlen]
k = k[:, :seqlen]
v = v[:, :seqlen]
max_seqlen = attn_params.max_seqlen
attn_params = AttentionParams.create_attention_params(attn_params.attn_mode, False) # do not in-place modify
attn_params.max_seqlen = max_seqlen # keep max_seqlen for padding
seqlen_trimmed = True
# Determine tensor layout based on attention implementation
if attn_params.attn_mode == "torch" or (
attn_params.attn_mode == "sageattn" and (attn_params.split_attn or attn_params.cu_seqlens is None)
):
transpose_fn = lambda x: x.transpose(1, 2) # [B, H, L, D] for SDPA and sageattn with fixed length
# pad on sequence length dimension
pad_fn = lambda x, pad_to: torch.nn.functional.pad(x, (0, 0, 0, pad_to - x.shape[-2]), value=0)
else:
transpose_fn = lambda x: x # [B, L, H, D] for other implementations
# pad on sequence length dimension
pad_fn = lambda x, pad_to: torch.nn.functional.pad(x, (0, 0, 0, 0, 0, pad_to - x.shape[-3]), value=0)
# Process each batch element with its valid sequence lengths
if attn_params.split_attn:
if attn_params.seqlens is None:
# If no seqlens provided, assume all tokens are valid
attn_params = AttentionParams.create_attention_params(attn_params.attn_mode, True) # do not in-place modify
attn_params.seqlens = torch.tensor([q.shape[1]] * q.shape[0], device=q.device)
attn_params.max_seqlen = q.shape[1]
q = [transpose_fn(q[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(q))]
k = [transpose_fn(k[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(k))]
v = [transpose_fn(v[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(v))]
else:
q = transpose_fn(q)
k = transpose_fn(k)
v = transpose_fn(v)
if attn_params.attn_mode == "torch":
if attn_params.split_attn:
x = []
for i in range(len(q)):
x_i = torch.nn.functional.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate)
q[i] = None
k[i] = None
v[i] = None
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, H, L, D
x = torch.cat(x, dim=0)
del q, k, v
else:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_params.attention_mask, dropout_p=drop_rate)
del q, k, v
elif attn_params.attn_mode == "xformers":
if attn_params.split_attn:
x = []
for i in range(len(q)):
x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate)
q[i] = None
k[i] = None
v[i] = None
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, L, H, D
x = torch.cat(x, dim=0)
del q, k, v
else:
x = xops.memory_efficient_attention(q, k, v, attn_bias=attn_params.attention_mask, p=drop_rate)
del q, k, v
elif attn_params.attn_mode == "sageattn":
if attn_params.split_attn:
x = []
for i in range(len(q)):
# HND seems to cause an error
x_i = sageattn(q[i], k[i], v[i]) # B, H, L, D. No dropout support
q[i] = None
k[i] = None
v[i] = None
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, H, L, D
x = torch.cat(x, dim=0)
del q, k, v
elif attn_params.cu_seqlens is None: # all tokens are valid
x = sageattn(q, k, v) # B, L, H, D. No dropout support
del q, k, v
else:
# Reshape to [(bxs), a, d]
batch_size, seqlen = q.shape[0], q.shape[1]
q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) # [B*L, H, D]
k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) # [B*L, H, D]
v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) # [B*L, H, D]
# Assume cu_seqlens_q == cu_seqlens_kv and max_seqlen_q == max_seqlen_kv. No dropout support
x = sageattn_varlen(
q, k, v, attn_params.cu_seqlens, attn_params.cu_seqlens, attn_params.max_seqlen, attn_params.max_seqlen
)
del q, k, v
# Reshape x with shape [(bxs), a, d] to [b, s, a, d]
x = x.view(batch_size, seqlen, x.shape[-2], x.shape[-1]) # B, L, H, D
elif attn_params.attn_mode == "flash":
if attn_params.split_attn:
x = []
for i in range(len(q)):
# HND seems to cause an error
x_i = flash_attn_func(q[i], k[i], v[i], drop_rate) # B, L, H, D
q[i] = None
k[i] = None
v[i] = None
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, L, H, D
x = torch.cat(x, dim=0)
del q, k, v
elif attn_params.cu_seqlens is None: # all tokens are valid
x = flash_attn_func(q, k, v, drop_rate) # B, L, H, D
del q, k, v
else:
# Reshape to [(bxs), a, d]
batch_size, seqlen = q.shape[0], q.shape[1]
q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) # [B*L, H, D]
k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) # [B*L, H, D]
v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) # [B*L, H, D]
# Assume cu_seqlens_q == cu_seqlens_kv and max_seqlen_q == max_seqlen_kv
x = flash_attn_varlen_func(
q, k, v, attn_params.cu_seqlens, attn_params.cu_seqlens, attn_params.max_seqlen, attn_params.max_seqlen, drop_rate
)
del q, k, v
# Reshape x with shape [(bxs), a, d] to [b, s, a, d]
x = x.view(batch_size, seqlen, x.shape[-2], x.shape[-1]) # B, L, H, D
else:
# Currently only PyTorch SDPA and xformers are implemented
raise ValueError(f"Unsupported attention mode: {attn_params.attn_mode}")
x = transpose_fn(x) # [B, L, H, D]
x = x.reshape(x.shape[0], x.shape[1], -1) # [B, L, H*D]
if seqlen_trimmed:
x = torch.nn.functional.pad(x, (0, 0, 0, attn_params.max_seqlen - x.shape[1]), value=0) # pad back to max_seqlen
return x

View File

@@ -1,13 +1,28 @@
from concurrent.futures import ThreadPoolExecutor
import gc
import time
from typing import Optional, Union, Callable, Tuple
from typing import Any, Optional, Union, Callable, Tuple
import torch
import torch.nn as nn
from library.device_utils import clean_memory_on_device
# Keep these functions here for portability, and private to avoid confusion with the ones in device_utils.py
def _clean_memory_on_device(device: torch.device):
r"""
Clean memory on the specified device, will be called from training scripts.
"""
gc.collect()
# device may "cuda" or "cuda:0", so we need to check the type of device
if device.type == "cuda":
torch.cuda.empty_cache()
if device.type == "xpu":
torch.xpu.empty_cache()
if device.type == "mps":
torch.mps.empty_cache()
def synchronize_device(device: torch.device):
def _synchronize_device(device: torch.device):
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "xpu":
@@ -71,19 +86,18 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
# device to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
synchronize_device(device)
_synchronize_device(device)
# cpu to device
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
synchronize_device(device)
_synchronize_device(device)
def weighs_to_device(layer: nn.Module, device: torch.device):
@@ -122,7 +136,7 @@ class Offloader:
self.swap_weight_devices(block_to_cpu, block_to_cuda)
if self.debug:
print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter() - start_time:.2f}s")
return bidx_to_cpu, bidx_to_cuda # , event
block_to_cpu = blocks[block_idx_to_cpu]
@@ -146,33 +160,51 @@ class Offloader:
assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
if self.debug:
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
print(f"Waited for block {block_idx}: {time.perf_counter() - start_time:.2f}s")
# Gradient tensors
_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]
class ModelOffloader(Offloader):
"""
supports forward offloading
"""
def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False):
def __init__(
self,
blocks: Union[list[nn.Module], nn.ModuleList],
blocks_to_swap: int,
device: torch.device,
supports_backward: bool = True,
debug: bool = False,
):
super().__init__(len(blocks), blocks_to_swap, device, debug)
# register backward hooks
self.remove_handles = []
for i, block in enumerate(blocks):
hook = self.create_backward_hook(blocks, i)
if hook is not None:
handle = block.register_full_backward_hook(hook)
self.remove_handles.append(handle)
self.supports_backward = supports_backward
self.forward_only = not supports_backward # forward only offloading: can be changed to True for inference
if self.supports_backward:
# register backward hooks
self.remove_handles = []
for i, block in enumerate(blocks):
hook = self.create_backward_hook(blocks, i)
if hook is not None:
handle = block.register_full_backward_hook(hook)
self.remove_handles.append(handle)
def set_forward_only(self, forward_only: bool):
self.forward_only = forward_only
def __del__(self):
for handle in self.remove_handles:
handle.remove()
if self.supports_backward:
for handle in self.remove_handles:
handle.remove()
def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
def create_backward_hook(
self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int
) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
# -1 for 0-based index
num_blocks_propagated = self.num_blocks - block_index - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
@@ -203,18 +235,18 @@ class ModelOffloader(Offloader):
return
if self.debug:
print("Prepare block devices before forward")
print(f"Prepare block devices before forward")
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
b.to(self.device)
weighs_to_device(b, self.device) # make sure weights are on device
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
b.to(self.device) # move block to device first
b.to(self.device) # move block to device first. this makes sure that buffers (non weights) are on the device
weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu
synchronize_device(self.device)
clean_memory_on_device(self.device)
_synchronize_device(self.device)
_clean_memory_on_device(self.device)
def wait_for_block(self, block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
@@ -222,10 +254,85 @@ class ModelOffloader(Offloader):
self._wait_blocks_move(block_idx)
def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int):
# check if blocks_to_swap is enabled
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if block_idx >= self.blocks_to_swap:
# if backward is enabled, we do not swap blocks in forward pass more than blocks_to_swap, because it should be on GPU
if not self.forward_only and block_idx >= self.blocks_to_swap:
return
block_idx_to_cpu = block_idx
block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
# this works for forward-only offloading. move upstream blocks to cuda
block_idx_to_cuda = block_idx_to_cuda % self.num_blocks
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
# endregion
# region cpu offload utils
def to_device(x: Any, device: torch.device) -> Any:
if isinstance(x, torch.Tensor):
return x.to(device)
elif isinstance(x, list):
return [to_device(elem, device) for elem in x]
elif isinstance(x, tuple):
return tuple(to_device(elem, device) for elem in x)
elif isinstance(x, dict):
return {k: to_device(v, device) for k, v in x.items()}
else:
return x
def to_cpu(x: Any) -> Any:
"""
Recursively moves torch.Tensor objects (and containers thereof) to CPU.
Args:
x: A torch.Tensor, or a (possibly nested) list, tuple, or dict containing tensors.
Returns:
The same structure as x, with all torch.Tensor objects moved to CPU.
Non-tensor objects are returned unchanged.
"""
if isinstance(x, torch.Tensor):
return x.cpu()
elif isinstance(x, list):
return [to_cpu(elem) for elem in x]
elif isinstance(x, tuple):
return tuple(to_cpu(elem) for elem in x)
elif isinstance(x, dict):
return {k: to_cpu(v) for k, v in x.items()}
else:
return x
def create_cpu_offloading_wrapper(func: Callable, device: torch.device) -> Callable:
"""
Create a wrapper function that offloads inputs to CPU before calling the original function
and moves outputs back to the specified device.
Args:
func: The original function to wrap.
device: The device to move outputs back to.
Returns:
A wrapped function that offloads inputs to CPU and moves outputs back to the specified device.
"""
def wrapper(orig_func: Callable) -> Callable:
def custom_forward(*inputs):
nonlocal device, orig_func
cuda_inputs = to_device(inputs, device)
outputs = orig_func(*cuda_inputs)
return to_cpu(outputs)
return custom_forward
return wrapper(func)
# endregion

View File

@@ -1,7 +1,10 @@
import functools
import gc
from typing import Optional, Union
import torch
try:
# intel gpu support for pytorch older than 2.5
# ipex is not needed after pytorch 2.5
@@ -36,12 +39,15 @@ def clean_memory():
torch.mps.empty_cache()
def clean_memory_on_device(device: torch.device):
def clean_memory_on_device(device: Optional[Union[str, torch.device]]):
r"""
Clean memory on the specified device, will be called from training scripts.
"""
gc.collect()
if device is None:
return
if isinstance(device, str):
device = torch.device(device)
# device may "cuda" or "cuda:0", so we need to check the type of device
if device.type == "cuda":
torch.cuda.empty_cache()
@@ -51,6 +57,19 @@ def clean_memory_on_device(device: torch.device):
torch.mps.empty_cache()
def synchronize_device(device: Optional[Union[str, torch.device]]):
if device is None:
return
if isinstance(device, str):
device = torch.device(device)
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "xpu":
torch.xpu.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
@functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device:
r"""

View File

@@ -16,10 +16,11 @@ from safetensors.torch import save_file
from library import flux_models, flux_utils, strategy_base, train_util
from library.device_utils import init_ipex, clean_memory_on_device
from library.safetensors_utils import mem_eff_save_file
init_ipex()
from .utils import setup_logging, mem_eff_save_file
from .utils import setup_logging
setup_logging()
import logging

View File

@@ -18,7 +18,7 @@ import logging
logger = logging.getLogger(__name__)
from library import flux_models
from library.utils import load_safetensors
from library.safetensors_utils import load_safetensors
MODEL_VERSION_FLUX_V1 = "flux1"
MODEL_NAME_DEV = "dev"
@@ -124,7 +124,7 @@ def load_flow_model(
logger.info(f"Loading state dict from {ckpt_path}")
sd = {}
for ckpt_path in ckpt_paths:
sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))
sd.update(load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype))
# convert Diffusers to BFL
if is_diffusers:

View File

@@ -0,0 +1,469 @@
import os
from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
from tqdm import tqdm
from library.device_utils import clean_memory_on_device
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def calculate_fp8_maxval(exp_bits=4, mantissa_bits=3, sign_bits=1):
"""
Calculate the maximum representable value in FP8 format.
Default is E4M3 format (4-bit exponent, 3-bit mantissa, 1-bit sign). Only supports E4M3 and E5M2 with sign bit.
Args:
exp_bits (int): Number of exponent bits
mantissa_bits (int): Number of mantissa bits
sign_bits (int): Number of sign bits (0 or 1)
Returns:
float: Maximum value representable in FP8 format
"""
assert exp_bits + mantissa_bits + sign_bits == 8, "Total bits must be 8"
if exp_bits == 4 and mantissa_bits == 3 and sign_bits == 1:
return torch.finfo(torch.float8_e4m3fn).max
elif exp_bits == 5 and mantissa_bits == 2 and sign_bits == 1:
return torch.finfo(torch.float8_e5m2).max
else:
raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits} with sign_bits={sign_bits}")
# The following is a manual calculation method (wrong implementation for E5M2), kept for reference.
"""
# Calculate exponent bias
bias = 2 ** (exp_bits - 1) - 1
# Calculate maximum mantissa value
mantissa_max = 1.0
for i in range(mantissa_bits - 1):
mantissa_max += 2 ** -(i + 1)
# Calculate maximum value
max_value = mantissa_max * (2 ** (2**exp_bits - 1 - bias))
return max_value
"""
def quantize_fp8(tensor, scale, fp8_dtype, max_value, min_value):
"""
Quantize a tensor to FP8 format using PyTorch's native FP8 dtype support.
Args:
tensor (torch.Tensor): Tensor to quantize
scale (float or torch.Tensor): Scale factor
fp8_dtype (torch.dtype): Target FP8 dtype (torch.float8_e4m3fn or torch.float8_e5m2)
max_value (float): Maximum representable value in FP8
min_value (float): Minimum representable value in FP8
Returns:
torch.Tensor: Quantized tensor in FP8 format
"""
tensor = tensor.to(torch.float32) # ensure tensor is in float32 for division
# Create scaled tensor
tensor = torch.div(tensor, scale).nan_to_num_(0.0) # handle NaN values, equivalent to nonzero_mask in previous function
# Clamp tensor to range
tensor = tensor.clamp_(min=min_value, max=max_value)
# Convert to FP8 dtype
tensor = tensor.to(fp8_dtype)
return tensor
def optimize_state_dict_with_fp8(
state_dict: dict,
calc_device: Union[str, torch.device],
target_layer_keys: Optional[list[str]] = None,
exclude_layer_keys: Optional[list[str]] = None,
exp_bits: int = 4,
mantissa_bits: int = 3,
move_to_device: bool = False,
quantization_mode: str = "block",
block_size: Optional[int] = 64,
):
"""
Optimize Linear layer weights in a model's state dict to FP8 format. The state dict is modified in-place.
This function is a static version of load_safetensors_with_fp8_optimization without loading from files.
Args:
state_dict (dict): State dict to optimize, replaced in-place
calc_device (str): Device to quantize tensors on
target_layer_keys (list, optional): Layer key patterns to target (None for all Linear layers)
exclude_layer_keys (list, optional): Layer key patterns to exclude
exp_bits (int): Number of exponent bits
mantissa_bits (int): Number of mantissa bits
move_to_device (bool): Move optimized tensors to the calculating device
Returns:
dict: FP8 optimized state dict
"""
if exp_bits == 4 and mantissa_bits == 3:
fp8_dtype = torch.float8_e4m3fn
elif exp_bits == 5 and mantissa_bits == 2:
fp8_dtype = torch.float8_e5m2
else:
raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}")
# Calculate FP8 max value
max_value = calculate_fp8_maxval(exp_bits, mantissa_bits)
min_value = -max_value # this function supports only signed FP8
# Create optimized state dict
optimized_count = 0
# Enumerate tarket keys
target_state_dict_keys = []
for key in state_dict.keys():
# Check if it's a weight key and matches target patterns
is_target = (target_layer_keys is None or any(pattern in key for pattern in target_layer_keys)) and key.endswith(".weight")
is_excluded = exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys)
is_target = is_target and not is_excluded
if is_target and isinstance(state_dict[key], torch.Tensor):
target_state_dict_keys.append(key)
# Process each key
for key in tqdm(target_state_dict_keys):
value = state_dict[key]
# Save original device and dtype
original_device = value.device
original_dtype = value.dtype
# Move to calculation device
if calc_device is not None:
value = value.to(calc_device)
quantized_weight, scale_tensor = quantize_weight(key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size)
# Add to state dict using original key for weight and new key for scale
fp8_key = key # Maintain original key
scale_key = key.replace(".weight", ".scale_weight")
if not move_to_device:
quantized_weight = quantized_weight.to(original_device)
# keep scale shape: [1] or [out,1] or [out, num_blocks, 1]. We can determine the quantization mode from the shape of scale_weight in the patched model.
scale_tensor = scale_tensor.to(dtype=original_dtype, device=quantized_weight.device)
state_dict[fp8_key] = quantized_weight
state_dict[scale_key] = scale_tensor
optimized_count += 1
if calc_device is not None: # optimized_count % 10 == 0 and
# free memory on calculation device
clean_memory_on_device(calc_device)
logger.info(f"Number of optimized Linear layers: {optimized_count}")
return state_dict
def quantize_weight(
key: str,
tensor: torch.Tensor,
fp8_dtype: torch.dtype,
max_value: float,
min_value: float,
quantization_mode: str = "block",
block_size: int = 64,
):
original_shape = tensor.shape
# Determine quantization mode
if quantization_mode == "block":
if tensor.ndim != 2:
quantization_mode = "tensor" # fallback to per-tensor
else:
out_features, in_features = tensor.shape
if in_features % block_size != 0:
quantization_mode = "channel" # fallback to per-channel
logger.warning(
f"Layer {key} with shape {tensor.shape} is not divisible by block_size {block_size}, fallback to per-channel quantization."
)
else:
num_blocks = in_features // block_size
tensor = tensor.contiguous().view(out_features, num_blocks, block_size) # [out, num_blocks, block_size]
elif quantization_mode == "channel":
if tensor.ndim != 2:
quantization_mode = "tensor" # fallback to per-tensor
# Calculate scale factor (per-tensor or per-output-channel with percentile or max)
# value shape is expected to be [out_features, in_features] for Linear weights
if quantization_mode == "channel" or quantization_mode == "block":
# row-wise percentile to avoid being dominated by outliers
# result shape: [out_features, 1] or [out_features, num_blocks, 1]
scale_dim = 1 if quantization_mode == "channel" else 2
abs_w = torch.abs(tensor)
# shape: [out_features, 1] or [out_features, num_blocks, 1]
row_max = torch.max(abs_w, dim=scale_dim, keepdim=True).values
scale = row_max / max_value
else:
# per-tensor
tensor_max = torch.max(torch.abs(tensor).view(-1))
scale = tensor_max / max_value
# numerical safety
scale = torch.clamp(scale, min=1e-8)
scale = scale.to(torch.float32) # ensure scale is in float32 for division
# Quantize weight to FP8 (scale can be scalar or [out,1], broadcasting works)
quantized_weight = quantize_fp8(tensor, scale, fp8_dtype, max_value, min_value)
# If block-wise, restore original shape
if quantization_mode == "block":
quantized_weight = quantized_weight.view(original_shape) # restore to original shape [out, in]
return quantized_weight, scale
def load_safetensors_with_fp8_optimization(
model_files: List[str],
calc_device: Union[str, torch.device],
target_layer_keys=None,
exclude_layer_keys=None,
exp_bits=4,
mantissa_bits=3,
move_to_device=False,
weight_hook=None,
quantization_mode: str = "block",
block_size: Optional[int] = 64,
) -> dict:
"""
Load weight tensors from safetensors files and merge LoRA weights into the state dict with explicit FP8 optimization.
Args:
model_files (list[str]): List of model files to load
calc_device (str or torch.device): Device to quantize tensors on
target_layer_keys (list, optional): Layer key patterns to target for optimization (None for all Linear layers)
exclude_layer_keys (list, optional): Layer key patterns to exclude from optimization
exp_bits (int): Number of exponent bits
mantissa_bits (int): Number of mantissa bits
move_to_device (bool): Move optimized tensors to the calculating device
weight_hook (callable, optional): Function to apply to each weight tensor before optimization
quantization_mode (str): Quantization mode, "tensor", "channel", or "block"
block_size (int, optional): Block size for block-wise quantization (used if quantization_mode is "block")
Returns:
dict: FP8 optimized state dict
"""
if exp_bits == 4 and mantissa_bits == 3:
fp8_dtype = torch.float8_e4m3fn
elif exp_bits == 5 and mantissa_bits == 2:
fp8_dtype = torch.float8_e5m2
else:
raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}")
# Calculate FP8 max value
max_value = calculate_fp8_maxval(exp_bits, mantissa_bits)
min_value = -max_value # this function supports only signed FP8
# Define function to determine if a key is a target key. target means fp8 optimization, not for weight hook.
def is_target_key(key):
# Check if weight key matches target patterns and does not match exclude patterns
is_target = (target_layer_keys is None or any(pattern in key for pattern in target_layer_keys)) and key.endswith(".weight")
is_excluded = exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys)
return is_target and not is_excluded
# Create optimized state dict
optimized_count = 0
# Process each file
state_dict = {}
for model_file in model_files:
with MemoryEfficientSafeOpen(model_file) as f:
keys = f.keys()
for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"):
value = f.get_tensor(key)
# Save original device
original_device = value.device # usually cpu
if weight_hook is not None:
# Apply weight hook if provided
value = weight_hook(key, value, keep_on_calc_device=(calc_device is not None))
if not is_target_key(key):
target_device = calc_device if (calc_device is not None and move_to_device) else original_device
value = value.to(target_device)
state_dict[key] = value
continue
# Move to calculation device
if calc_device is not None:
value = value.to(calc_device)
original_dtype = value.dtype
quantized_weight, scale_tensor = quantize_weight(
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
)
# Add to state dict using original key for weight and new key for scale
fp8_key = key # Maintain original key
scale_key = key.replace(".weight", ".scale_weight")
assert fp8_key != scale_key, "FP8 key and scale key must be different"
if not move_to_device:
quantized_weight = quantized_weight.to(original_device)
# keep scale shape: [1] or [out,1] or [out, num_blocks, 1]. We can determine the quantization mode from the shape of scale_weight in the patched model.
scale_tensor = scale_tensor.to(dtype=original_dtype, device=quantized_weight.device)
state_dict[fp8_key] = quantized_weight
state_dict[scale_key] = scale_tensor
optimized_count += 1
if calc_device is not None and optimized_count % 10 == 0:
# free memory on calculation device
clean_memory_on_device(calc_device)
logger.info(f"Number of optimized Linear layers: {optimized_count}")
return state_dict
def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=None):
"""
Patched forward method for Linear layers with FP8 weights.
Args:
self: Linear layer instance
x (torch.Tensor): Input tensor
use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series)
max_value (float): Maximum value for FP8 quantization. If None, no quantization is applied for input tensor.
Returns:
torch.Tensor: Result of linear transformation
"""
if use_scaled_mm:
# **not tested**
# _scaled_mm only works for per-tensor scale for now (per-channel scale does not work in certain cases)
if self.scale_weight.ndim != 1:
raise ValueError("scaled_mm only supports per-tensor scale_weight for now.")
input_dtype = x.dtype
original_weight_dtype = self.scale_weight.dtype
target_dtype = self.weight.dtype
# assert x.ndim == 3, "Input tensor must be 3D (batch_size, seq_len, hidden_dim)"
if max_value is None:
# no input quantization
scale_x = torch.tensor(1.0, dtype=torch.float32, device=x.device)
else:
# calculate scale factor for input tensor
scale_x = (torch.max(torch.abs(x.flatten())) / max_value).to(torch.float32)
# quantize input tensor to FP8: this seems to consume a lot of memory
fp8_max_value = torch.finfo(target_dtype).max
fp8_min_value = torch.finfo(target_dtype).min
x = quantize_fp8(x, scale_x, target_dtype, fp8_max_value, fp8_min_value)
original_shape = x.shape
x = x.reshape(-1, x.shape[-1]).to(target_dtype)
weight = self.weight.t()
scale_weight = self.scale_weight.to(torch.float32)
if self.bias is not None:
# float32 is not supported with bias in scaled_mm
o = torch._scaled_mm(x, weight, out_dtype=original_weight_dtype, bias=self.bias, scale_a=scale_x, scale_b=scale_weight)
else:
o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight)
o = o.reshape(original_shape[0], original_shape[1], -1) if x.ndim == 3 else o.reshape(original_shape[0], -1)
return o.to(input_dtype)
else:
# Dequantize the weight
original_dtype = self.scale_weight.dtype
if self.scale_weight.ndim < 3:
# per-tensor or per-channel quantization, we can broadcast
dequantized_weight = self.weight.to(original_dtype) * self.scale_weight
else:
# block-wise quantization, need to reshape weight to match scale shape for broadcasting
out_features, num_blocks, _ = self.scale_weight.shape
dequantized_weight = self.weight.to(original_dtype).contiguous().view(out_features, num_blocks, -1)
dequantized_weight = dequantized_weight * self.scale_weight
dequantized_weight = dequantized_weight.view(self.weight.shape)
# Perform linear transformation
if self.bias is not None:
output = F.linear(x, dequantized_weight, self.bias)
else:
output = F.linear(x, dequantized_weight)
return output
def apply_fp8_monkey_patch(model, optimized_state_dict, use_scaled_mm=False):
"""
Apply monkey patching to a model using FP8 optimized state dict.
Args:
model (nn.Module): Model instance to patch
optimized_state_dict (dict): FP8 optimized state dict
use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series)
Returns:
nn.Module: The patched model (same instance, modified in-place)
"""
# # Calculate FP8 float8_e5m2 max value
# max_value = calculate_fp8_maxval(5, 2)
max_value = None # do not quantize input tensor
# Find all scale keys to identify FP8-optimized layers
scale_keys = [k for k in optimized_state_dict.keys() if k.endswith(".scale_weight")]
# Enumerate patched layers
patched_module_paths = set()
scale_shape_info = {}
for scale_key in scale_keys:
# Extract module path from scale key (remove .scale_weight)
module_path = scale_key.rsplit(".scale_weight", 1)[0]
patched_module_paths.add(module_path)
# Store scale shape information
scale_shape_info[module_path] = optimized_state_dict[scale_key].shape
patched_count = 0
# Apply monkey patch to each layer with FP8 weights
for name, module in model.named_modules():
# Check if this module has a corresponding scale_weight
has_scale = name in patched_module_paths
# Apply patch if it's a Linear layer with FP8 scale
if isinstance(module, nn.Linear) and has_scale:
# register the scale_weight as a buffer to load the state_dict
# module.register_buffer("scale_weight", torch.tensor(1.0, dtype=module.weight.dtype))
scale_shape = scale_shape_info[name]
module.register_buffer("scale_weight", torch.ones(scale_shape, dtype=module.weight.dtype))
# Create a new forward method with the patched version.
def new_forward(self, x):
return fp8_linear_forward_patch(self, x, use_scaled_mm, max_value)
# Bind method to module
module.forward = new_forward.__get__(module, type(module))
patched_count += 1
logger.info(f"Number of monkey-patched Linear layers: {patched_count}")
return model

View File

@@ -0,0 +1,489 @@
# Original work: https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
# Re-implemented for license compliance for sd-scripts.
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from accelerate import init_empty_weights
from library import custom_offloading_utils
from library.attention import AttentionParams
from library.fp8_optimization_utils import apply_fp8_monkey_patch
from library.lora_utils import load_safetensors_with_lora_and_fp8
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library.hunyuan_image_modules import (
SingleTokenRefiner,
ByT5Mapper,
PatchEmbed2D,
TimestepEmbedder,
MMDoubleStreamBlock,
MMSingleStreamBlock,
FinalLayer,
)
from library.hunyuan_image_utils import get_nd_rotary_pos_embed
FP8_OPTIMIZATION_TARGET_KEYS = ["double_blocks", "single_blocks"]
# FP8_OPTIMIZATION_EXCLUDE_KEYS = ["norm", "_mod", "_emb"] # , "modulation"
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["norm", "_emb"] # , "modulation", "_mod"
# full exclude 24.2GB
# norm and _emb 19.7GB
# fp8 cast 19.7GB
# region DiT Model
class HYImageDiffusionTransformer(nn.Module):
"""
HunyuanImage-2.1 Diffusion Transformer.
A multimodal transformer for image generation with text conditioning,
featuring separate double-stream and single-stream processing blocks.
Args:
attn_mode: Attention implementation mode ("torch" or "sageattn").
"""
def __init__(self, attn_mode: str = "torch", split_attn: bool = False):
super().__init__()
# Fixed architecture parameters for HunyuanImage-2.1
self.patch_size = [1, 1] # 1x1 patch size (no spatial downsampling)
self.in_channels = 64 # Input latent channels
self.out_channels = 64 # Output latent channels
self.unpatchify_channels = self.out_channels
self.guidance_embed = False # Guidance embedding disabled
self.rope_dim_list = [64, 64] # RoPE dimensions for 2D positional encoding
self.rope_theta = 256 # RoPE frequency scaling
self.use_attention_mask = True
self.text_projection = "single_refiner"
self.hidden_size = 3584 # Model dimension
self.heads_num = 28 # Number of attention heads
# Architecture configuration
mm_double_blocks_depth = 20 # Double-stream transformer blocks
mm_single_blocks_depth = 40 # Single-stream transformer blocks
mlp_width_ratio = 4 # MLP expansion ratio
text_states_dim = 3584 # Text encoder output dimension
guidance_embed = False # No guidance embedding
# Layer configuration
mlp_act_type: str = "gelu_tanh" # MLP activation function
qkv_bias: bool = True # Use bias in QKV projections
qk_norm: bool = True # Apply QK normalization
qk_norm_type: str = "rms" # RMS normalization type
self.attn_mode = attn_mode
self.split_attn = split_attn
# ByT5 character-level text encoder mapping
self.byt5_in = ByT5Mapper(in_dim=1472, out_dim=2048, hidden_dim=2048, out_dim1=self.hidden_size, use_residual=False)
# Image latent patch embedding
self.img_in = PatchEmbed2D(self.patch_size, self.in_channels, self.hidden_size)
# Text token refinement with cross-attention
self.txt_in = SingleTokenRefiner(text_states_dim, self.hidden_size, self.heads_num, depth=2)
# Timestep embedding for diffusion process
self.time_in = TimestepEmbedder(self.hidden_size, nn.SiLU)
# MeanFlow not supported in this implementation
self.time_r_in = None
# Guidance embedding (disabled for non-distilled model)
self.guidance_in = TimestepEmbedder(self.hidden_size, nn.SiLU) if guidance_embed else None
# Double-stream blocks: separate image and text processing
self.double_blocks = nn.ModuleList(
[
MMDoubleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
)
for _ in range(mm_double_blocks_depth)
]
)
# Single-stream blocks: joint processing of concatenated features
self.single_blocks = nn.ModuleList(
[
MMSingleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
)
for _ in range(mm_single_blocks_depth)
]
)
self.final_layer = FinalLayer(self.hidden_size, self.patch_size, self.out_channels, nn.SiLU)
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
self.blocks_to_swap = None
self.offloader_double = None
self.offloader_single = None
self.num_double_blocks = len(self.double_blocks)
self.num_single_blocks = len(self.single_blocks)
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
self.gradient_checkpointing = True
self.cpu_offload_checkpointing = cpu_offload
for block in self.double_blocks + self.single_blocks:
block.enable_gradient_checkpointing(cpu_offload=cpu_offload)
print(f"HunyuanImage-2.1: Gradient checkpointing enabled. CPU offload: {cpu_offload}")
def disable_gradient_checkpointing(self):
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
for block in self.double_blocks + self.single_blocks:
block.disable_gradient_checkpointing()
print("HunyuanImage-2.1: Gradient checkpointing disabled.")
def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool = False):
self.blocks_to_swap = num_blocks
double_blocks_to_swap = num_blocks // 2
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, (
f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. "
f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
)
self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, double_blocks_to_swap, device, supports_backward=supports_backward
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, single_blocks_to_swap, device, supports_backward=supports_backward
)
# , debug=True
print(
f"HunyuanImage-2.1: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
)
def switch_block_swap_for_inference(self):
if self.blocks_to_swap:
self.offloader_double.set_forward_only(True)
self.offloader_single.set_forward_only(True)
self.prepare_block_swap_before_forward()
print(f"HunyuanImage-2.1: Block swap set to forward only.")
def switch_block_swap_for_training(self):
if self.blocks_to_swap:
self.offloader_double.set_forward_only(False)
self.offloader_single.set_forward_only(False)
self.prepare_block_swap_before_forward()
print(f"HunyuanImage-2.1: Block swap set to forward and backward.")
def move_to_device_except_swap_blocks(self, device: torch.device):
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
if self.blocks_to_swap:
save_double_blocks = self.double_blocks
save_single_blocks = self.single_blocks
self.double_blocks = nn.ModuleList()
self.single_blocks = nn.ModuleList()
self.to(device)
if self.blocks_to_swap:
self.double_blocks = save_double_blocks
self.single_blocks = save_single_blocks
def prepare_block_swap_before_forward(self):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
def get_rotary_pos_embed(self, rope_sizes):
"""
Generate 2D rotary position embeddings for image tokens.
Args:
rope_sizes: Tuple of (height, width) for spatial dimensions.
Returns:
Tuple of (freqs_cos, freqs_sin) tensors for rotary position encoding.
"""
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(self.rope_dim_list, rope_sizes, theta=self.rope_theta)
return freqs_cos, freqs_sin
def reorder_txt_token(
self, byt5_txt: torch.Tensor, txt: torch.Tensor, byt5_text_mask: torch.Tensor, text_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, list[int]]:
"""
Combine and reorder ByT5 character-level and word-level text embeddings.
Concatenates valid tokens from both encoders and creates appropriate masks.
Args:
byt5_txt: ByT5 character-level embeddings [B, L1, D].
txt: Word-level text embeddings [B, L2, D].
byt5_text_mask: Valid token mask for ByT5 [B, L1].
text_mask: Valid token mask for word tokens [B, L2].
Returns:
Tuple of (reordered_embeddings, combined_mask, sequence_lengths).
"""
# Process each batch element separately to handle variable sequence lengths
reorder_txt = []
reorder_mask = []
txt_lens = []
for i in range(text_mask.shape[0]):
byt5_text_mask_i = byt5_text_mask[i].bool()
text_mask_i = text_mask[i].bool()
byt5_text_length = byt5_text_mask_i.sum()
text_length = text_mask_i.sum()
assert byt5_text_length == byt5_text_mask_i[:byt5_text_length].sum()
assert text_length == text_mask_i[:text_length].sum()
byt5_txt_i = byt5_txt[i]
txt_i = txt[i]
reorder_txt_i = torch.cat(
[byt5_txt_i[:byt5_text_length], txt_i[:text_length], byt5_txt_i[byt5_text_length:], txt_i[text_length:]], dim=0
)
reorder_mask_i = torch.zeros(
byt5_text_mask_i.shape[0] + text_mask_i.shape[0], dtype=torch.bool, device=byt5_text_mask_i.device
)
reorder_mask_i[: byt5_text_length + text_length] = True
reorder_txt.append(reorder_txt_i)
reorder_mask.append(reorder_mask_i)
txt_lens.append(byt5_text_length + text_length)
reorder_txt = torch.stack(reorder_txt)
reorder_mask = torch.stack(reorder_mask).to(dtype=torch.int64)
return reorder_txt, reorder_mask, txt_lens
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
text_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
byt5_text_states: Optional[torch.Tensor] = None,
byt5_text_mask: Optional[torch.Tensor] = None,
rotary_pos_emb_cache: Optional[Dict[Tuple[int, int], Tuple[torch.Tensor, torch.Tensor]]] = None,
) -> torch.Tensor:
"""
Forward pass through the HunyuanImage diffusion transformer.
Args:
hidden_states: Input image latents [B, C, H, W].
timestep: Diffusion timestep [B].
text_states: Word-level text embeddings [B, L, D].
encoder_attention_mask: Text attention mask [B, L].
byt5_text_states: ByT5 character-level embeddings [B, L_byt5, D_byt5].
byt5_text_mask: ByT5 attention mask [B, L_byt5].
Returns:
Tuple of (denoised_image, spatial_shape).
"""
img = x = hidden_states
text_mask = encoder_attention_mask
t = timestep
txt = text_states
# Calculate spatial dimensions for rotary position embeddings
_, _, oh, ow = x.shape
th, tw = oh, ow # Height and width (patch_size=[1,1] means no spatial downsampling)
if rotary_pos_emb_cache is not None:
if (th, tw) in rotary_pos_emb_cache:
freqs_cis = rotary_pos_emb_cache[(th, tw)]
freqs_cis = (freqs_cis[0].to(img.device), freqs_cis[1].to(img.device))
else:
freqs_cis = self.get_rotary_pos_embed((th, tw))
rotary_pos_emb_cache[(th, tw)] = (freqs_cis[0].cpu(), freqs_cis[1].cpu())
else:
freqs_cis = self.get_rotary_pos_embed((th, tw))
# Reshape image latents to sequence format: [B, C, H, W] -> [B, H*W, C]
img = self.img_in(img)
# Generate timestep conditioning vector
vec = self.time_in(t)
# MeanFlow and guidance embedding not used in this configuration
# Process text tokens through refinement layers
txt_attn_params = AttentionParams.create_attention_params_from_mask(self.attn_mode, self.split_attn, 0, text_mask)
txt = self.txt_in(txt, t, txt_attn_params)
# Integrate character-level ByT5 features with word-level tokens
# Use variable length sequences with sequence lengths
byt5_txt = self.byt5_in(byt5_text_states)
txt, text_mask, txt_lens = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask)
# Trim sequences to maximum length in the batch
img_seq_len = img.shape[1]
max_txt_len = max(txt_lens)
txt = txt[:, :max_txt_len, :]
text_mask = text_mask[:, :max_txt_len]
attn_params = AttentionParams.create_attention_params_from_mask(self.attn_mode, self.split_attn, img_seq_len, text_mask)
input_device = img.device
# Process through double-stream blocks (separate image/text attention)
for index, block in enumerate(self.double_blocks):
if self.blocks_to_swap:
self.offloader_double.wait_for_block(index)
img, txt = block(img, txt, vec, freqs_cis, attn_params)
if self.blocks_to_swap:
self.offloader_double.submit_move_blocks(self.double_blocks, index)
# Concatenate image and text tokens for joint processing
x = torch.cat((img, txt), 1)
# Process through single-stream blocks (joint attention)
for index, block in enumerate(self.single_blocks):
if self.blocks_to_swap:
self.offloader_single.wait_for_block(index)
x = block(x, vec, freqs_cis, attn_params)
if self.blocks_to_swap:
self.offloader_single.submit_move_blocks(self.single_blocks, index)
x = x.to(input_device)
vec = vec.to(input_device)
img = x[:, :img_seq_len, ...]
del x
# Apply final projection to output space
img = self.final_layer(img, vec)
del vec
# Reshape from sequence to spatial format: [B, L, C] -> [B, C, H, W]
img = self.unpatchify_2d(img, th, tw)
return img
def unpatchify_2d(self, x, h, w):
"""
Convert sequence format back to spatial image format.
Args:
x: Input tensor [B, H*W, C].
h: Height dimension.
w: Width dimension.
Returns:
Spatial tensor [B, C, H, W].
"""
c = self.unpatchify_channels
x = x.reshape(shape=(x.shape[0], h, w, c))
imgs = x.permute(0, 3, 1, 2)
return imgs
# endregion
# region Model Utils
def create_model(attn_mode: str, split_attn: bool, dtype: Optional[torch.dtype]) -> HYImageDiffusionTransformer:
with init_empty_weights():
model = HYImageDiffusionTransformer(attn_mode=attn_mode, split_attn=split_attn)
if dtype is not None:
model.to(dtype)
return model
def load_hunyuan_image_model(
device: Union[str, torch.device],
dit_path: str,
attn_mode: str,
split_attn: bool,
loading_device: Union[str, torch.device],
dit_weight_dtype: Optional[torch.dtype],
fp8_scaled: bool = False,
lora_weights_list: Optional[Dict[str, torch.Tensor]] = None,
lora_multipliers: Optional[list[float]] = None,
) -> HYImageDiffusionTransformer:
"""
Load a HunyuanImage model from the specified checkpoint.
Args:
device (Union[str, torch.device]): Device for optimization or merging
dit_path (str): Path to the DiT model checkpoint.
attn_mode (str): Attention mode to use, e.g., "torch", "flash", etc.
split_attn (bool): Whether to use split attention.
loading_device (Union[str, torch.device]): Device to load the model weights on.
dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights.
If None, it will be loaded as is (same as the state_dict) or scaled for fp8. if not None, model weights will be casted to this dtype.
fp8_scaled (bool): Whether to use fp8 scaling for the model weights.
lora_weights_list (Optional[Dict[str, torch.Tensor]]): LoRA weights to apply, if any.
lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
"""
# dit_weight_dtype is None for fp8_scaled
assert (not fp8_scaled and dit_weight_dtype is not None) or (fp8_scaled and dit_weight_dtype is None)
device = torch.device(device)
loading_device = torch.device(loading_device)
model = create_model(attn_mode, split_attn, dit_weight_dtype)
# load model weights with dynamic fp8 optimization and LoRA merging if needed
logger.info(f"Loading DiT model from {dit_path}, device={loading_device}")
sd = load_safetensors_with_lora_and_fp8(
model_files=dit_path,
lora_weights_list=lora_weights_list,
lora_multipliers=lora_multipliers,
fp8_optimization=fp8_scaled,
calc_device=device,
move_to_device=(loading_device == device),
dit_weight_dtype=dit_weight_dtype,
target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
)
if fp8_scaled:
apply_fp8_monkey_patch(model, sd, use_scaled_mm=False)
if loading_device.type != "cpu":
# make sure all the model weights are on the loading_device
logger.info(f"Moving weights to {loading_device}")
for key in sd.keys():
sd[key] = sd[key].to(loading_device)
info = model.load_state_dict(sd, strict=True, assign=True)
logger.info(f"Loaded DiT model from {dit_path}, info={info}")
return model
# endregion

View File

@@ -0,0 +1,863 @@
# Original work: https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
# Re-implemented for license compliance for sd-scripts.
from typing import Tuple, Callable
import torch
import torch.nn as nn
from einops import rearrange
from library import custom_offloading_utils
from library.attention import AttentionParams, attention
from library.hunyuan_image_utils import timestep_embedding, apply_rotary_emb, _to_tuple, apply_gate, modulate
from library.attention import attention
# region Modules
class ByT5Mapper(nn.Module):
"""
Maps ByT5 character-level encoder outputs to transformer hidden space.
Applies layer normalization, two MLP layers with GELU activation,
and optional residual connection.
Args:
in_dim: Input dimension from ByT5 encoder (1472 for ByT5-large).
out_dim: Intermediate dimension after first projection.
hidden_dim: Hidden dimension for MLP layer.
out_dim1: Final output dimension matching transformer hidden size.
use_residual: Whether to add residual connection (requires in_dim == out_dim).
"""
def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_residual=True):
super().__init__()
if use_residual:
assert in_dim == out_dim
self.layernorm = nn.LayerNorm(in_dim)
self.fc1 = nn.Linear(in_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, out_dim)
self.fc3 = nn.Linear(out_dim, out_dim1)
self.use_residual = use_residual
self.act_fn = nn.GELU()
def forward(self, x):
"""
Transform ByT5 embeddings to transformer space.
Args:
x: Input ByT5 embeddings [..., in_dim].
Returns:
Transformed embeddings [..., out_dim1].
"""
residual = x if self.use_residual else None
x = self.layernorm(x)
x = self.fc1(x)
x = self.act_fn(x)
x = self.fc2(x)
x = self.act_fn(x)
x = self.fc3(x)
if self.use_residual:
x = x + residual
return x
class PatchEmbed2D(nn.Module):
"""
2D patch embedding layer for converting image latents to transformer tokens.
Uses 2D convolution to project image patches to embedding space.
For HunyuanImage-2.1, patch_size=[1,1] means no spatial downsampling.
Args:
patch_size: Spatial size of patches (int or tuple).
in_chans: Number of input channels.
embed_dim: Output embedding dimension.
"""
def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.patch_size = tuple(patch_size)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=True)
self.norm = nn.Identity() # No normalization layer used
def forward(self, x):
x = self.proj(x)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
"""
Embeds scalar diffusion timesteps into vector representations.
Uses sinusoidal encoding followed by a two-layer MLP.
Args:
hidden_size: Output embedding dimension.
act_layer: Activation function class (e.g., nn.SiLU).
frequency_embedding_size: Dimension of sinusoidal encoding.
max_period: Maximum period for sinusoidal frequencies.
out_size: Output dimension (defaults to hidden_size).
"""
def __init__(self, hidden_size, act_layer, frequency_embedding_size=256, max_period=10000, out_size=None):
super().__init__()
self.frequency_embedding_size = frequency_embedding_size
self.max_period = max_period
if out_size is None:
out_size = hidden_size
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True), act_layer(), nn.Linear(hidden_size, out_size, bias=True)
)
def forward(self, t):
t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
return self.mlp(t_freq)
class TextProjection(nn.Module):
"""
Projects text embeddings through a two-layer MLP.
Used for context-aware representation computation in token refinement.
Args:
in_channels: Input feature dimension.
hidden_size: Hidden and output dimension.
act_layer: Activation function class.
"""
def __init__(self, in_channels, hidden_size, act_layer):
super().__init__()
self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True)
self.act_1 = act_layer()
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class MLP(nn.Module):
"""
Multi-layer perceptron with configurable activation and normalization.
Standard two-layer MLP with optional dropout and intermediate normalization.
Args:
in_channels: Input feature dimension.
hidden_channels: Hidden layer dimension (defaults to in_channels).
out_features: Output dimension (defaults to in_channels).
act_layer: Activation function class.
norm_layer: Optional normalization layer class.
bias: Whether to use bias (can be bool or tuple for each layer).
drop: Dropout rate (can be float or tuple for each layer).
use_conv: Whether to use convolution instead of linear (not supported).
"""
def __init__(
self,
in_channels,
hidden_channels=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.0,
use_conv=False,
):
super().__init__()
assert not use_conv, "Convolutional MLP not supported in this implementation."
out_features = out_features or in_channels
hidden_channels = hidden_channels or in_channels
bias = _to_tuple(bias, 2)
drop_probs = _to_tuple(drop, 2)
self.fc1 = nn.Linear(in_channels, hidden_channels, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_channels) if norm_layer is not None else nn.Identity()
self.fc2 = nn.Linear(hidden_channels, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class IndividualTokenRefinerBlock(nn.Module):
"""
Single transformer block for individual token refinement.
Applies self-attention and MLP with adaptive layer normalization (AdaLN)
conditioned on timestep and context information.
Args:
hidden_size: Model dimension.
heads_num: Number of attention heads.
mlp_width_ratio: MLP expansion ratio.
mlp_drop_rate: MLP dropout rate.
act_type: Activation function (only "silu" supported).
qk_norm: QK normalization flag (must be False).
qk_norm_type: QK normalization type (only "layer" supported).
qkv_bias: Use bias in QKV projections.
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
act_type: str = "silu",
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
):
super().__init__()
assert qk_norm_type == "layer", "Only layer normalization supported for QK norm."
assert act_type == "silu", "Only SiLU activation supported."
assert not qk_norm, "QK normalization must be disabled."
self.heads_num = heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
self.self_attn_q_norm = nn.Identity()
self.self_attn_k_norm = nn.Identity()
self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
self.mlp = MLP(in_channels=hidden_size, hidden_channels=mlp_hidden_dim, act_layer=nn.SiLU, drop=mlp_drop_rate)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True),
)
def forward(self, x: torch.Tensor, c: torch.Tensor, attn_params: AttentionParams) -> torch.Tensor:
"""
Apply self-attention and MLP with adaptive conditioning.
Args:
x: Input token embeddings [B, L, C].
c: Combined conditioning vector [B, C].
attn_params: Attention parameters including sequence lengths.
Returns:
Refined token embeddings [B, L, C].
"""
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn_qkv(norm_x)
del norm_x
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
del qkv
q = self.self_attn_q_norm(q).to(v)
k = self.self_attn_k_norm(k).to(v)
qkv = [q, k, v]
del q, k, v
attn = attention(qkv, attn_params=attn_params)
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
return x
class IndividualTokenRefiner(nn.Module):
"""
Stack of token refinement blocks with self-attention.
Processes tokens individually with adaptive layer normalization.
Args:
hidden_size: Model dimension.
heads_num: Number of attention heads.
depth: Number of refinement blocks.
mlp_width_ratio: MLP expansion ratio.
mlp_drop_rate: MLP dropout rate.
act_type: Activation function type.
qk_norm: QK normalization flag.
qk_norm_type: QK normalization type.
qkv_bias: Use bias in QKV projections.
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
depth: int,
mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
act_type: str = "silu",
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
):
super().__init__()
self.blocks = nn.ModuleList(
[
IndividualTokenRefinerBlock(
hidden_size=hidden_size,
heads_num=heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_drop_rate=mlp_drop_rate,
act_type=act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
)
for _ in range(depth)
]
)
def forward(self, x: torch.Tensor, c: torch.LongTensor, attn_params: AttentionParams) -> torch.Tensor:
"""
Apply sequential token refinement.
Args:
x: Input token embeddings [B, L, C].
c: Combined conditioning vector [B, C].
attn_params: Attention parameters including sequence lengths.
Returns:
Refined token embeddings [B, L, C].
"""
for block in self.blocks:
x = block(x, c, attn_params)
return x
class SingleTokenRefiner(nn.Module):
"""
Text embedding refinement with timestep and context conditioning.
Projects input text embeddings and applies self-attention refinement
conditioned on diffusion timestep and aggregate text context.
Args:
in_channels: Input text embedding dimension.
hidden_size: Transformer hidden dimension.
heads_num: Number of attention heads.
depth: Number of refinement blocks.
"""
def __init__(self, in_channels: int, hidden_size: int, heads_num: int, depth: int):
# Fixed architecture parameters for HunyuanImage-2.1
mlp_drop_rate: float = 0.0 # No MLP dropout
act_type: str = "silu" # SiLU activation
mlp_width_ratio: float = 4.0 # 4x MLP expansion
qk_norm: bool = False # No QK normalization
qk_norm_type: str = "layer" # Layer norm type (unused)
qkv_bias: bool = True # Use QKV bias
super().__init__()
self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True)
act_layer = nn.SiLU
self.t_embedder = TimestepEmbedder(hidden_size, act_layer)
self.c_embedder = TextProjection(in_channels, hidden_size, act_layer)
self.individual_token_refiner = IndividualTokenRefiner(
hidden_size=hidden_size,
heads_num=heads_num,
depth=depth,
mlp_width_ratio=mlp_width_ratio,
mlp_drop_rate=mlp_drop_rate,
act_type=act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
)
def forward(self, x: torch.Tensor, t: torch.LongTensor, attn_params: AttentionParams) -> torch.Tensor:
"""
Refine text embeddings with timestep conditioning.
Args:
x: Input text embeddings [B, L, in_channels].
t: Diffusion timestep [B].
attn_params: Attention parameters including sequence lengths.
Returns:
Refined embeddings [B, L, hidden_size].
"""
timestep_aware_representations = self.t_embedder(t)
# Compute context-aware representations by averaging valid tokens
txt_lens = attn_params.seqlens # img_len is not used for SingleTokenRefiner
context_aware_representations = torch.stack([x[i, : txt_lens[i]].mean(dim=0) for i in range(x.shape[0])], dim=0) # [B, C]
context_aware_representations = self.c_embedder(context_aware_representations)
c = timestep_aware_representations + context_aware_representations
del timestep_aware_representations, context_aware_representations
x = self.input_embedder(x)
x = self.individual_token_refiner(x, c, attn_params)
return x
class FinalLayer(nn.Module):
"""
Final output projection layer with adaptive layer normalization.
Projects transformer hidden states to output patch space with
timestep-conditioned modulation.
Args:
hidden_size: Input hidden dimension.
patch_size: Spatial patch size for output reshaping.
out_channels: Number of output channels.
act_layer: Activation function class.
"""
def __init__(self, hidden_size, patch_size, out_channels, act_layer):
super().__init__()
# Layer normalization without learnable parameters
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
out_size = (patch_size[0] * patch_size[1]) * out_channels
self.linear = nn.Linear(hidden_size, out_size, bias=True)
# Adaptive layer normalization modulation
self.adaLN_modulation = nn.Sequential(
act_layer(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True),
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift=shift, scale=scale)
del shift, scale, c
x = self.linear(x)
return x
class RMSNorm(nn.Module):
"""
Root Mean Square Layer Normalization.
Normalizes input using RMS and applies learnable scaling.
More efficient than LayerNorm as it doesn't compute mean.
Args:
dim: Input feature dimension.
eps: Small value for numerical stability.
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
"""
Apply RMS normalization.
Args:
x: Input tensor.
Returns:
RMS normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def reset_parameters(self):
self.weight.fill_(1)
def forward(self, x):
"""
Apply RMSNorm with learnable scaling.
Args:
x: Input tensor.
Returns:
Normalized and scaled tensor.
"""
output = self._norm(x.float()).type_as(x)
del x
# output = output * self.weight
# fp8 support
output = output * self.weight.to(output.dtype)
return output
# kept for reference, not used in current implementation
# class LinearWarpforSingle(nn.Module):
# """
# Linear layer wrapper for concatenating and projecting two inputs.
# Used in single-stream blocks to combine attention output with MLP features.
# Args:
# in_dim: Input dimension (sum of both input feature dimensions).
# out_dim: Output dimension.
# bias: Whether to use bias in linear projection.
# """
# def __init__(self, in_dim: int, out_dim: int, bias=False):
# super().__init__()
# self.fc = nn.Linear(in_dim, out_dim, bias=bias)
# def forward(self, x, y):
# """Concatenate inputs along feature dimension and project."""
# x = torch.cat([x.contiguous(), y.contiguous()], dim=2).contiguous()
# return self.fc(x)
class ModulateDiT(nn.Module):
"""
Timestep conditioning modulation layer.
Projects timestep embeddings to multiple modulation parameters
for adaptive layer normalization.
Args:
hidden_size: Input conditioning dimension.
factor: Number of modulation parameters to generate.
act_layer: Activation function class.
"""
def __init__(self, hidden_size: int, factor: int, act_layer: Callable):
super().__init__()
self.act = act_layer()
self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(self.act(x))
class MMDoubleStreamBlock(nn.Module):
"""
Multimodal double-stream transformer block.
Processes image and text tokens separately with cross-modal attention.
Each stream has its own normalization and MLP layers but shares
attention computation for cross-modal interaction.
Args:
hidden_size: Model dimension.
heads_num: Number of attention heads.
mlp_width_ratio: MLP expansion ratio.
mlp_act_type: MLP activation function (only "gelu_tanh" supported).
qk_norm: QK normalization flag (must be True).
qk_norm_type: QK normalization type (only "rms" supported).
qkv_bias: Use bias in QKV projections.
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float,
mlp_act_type: str = "gelu_tanh",
qk_norm: bool = True,
qk_norm_type: str = "rms",
qkv_bias: bool = False,
):
super().__init__()
assert mlp_act_type == "gelu_tanh", "Only GELU-tanh activation supported."
assert qk_norm_type == "rms", "Only RMS normalization supported."
assert qk_norm, "QK normalization must be enabled."
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
# Image stream processing components
self.img_mod = ModulateDiT(hidden_size, factor=6, act_layer=nn.SiLU)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
self.img_attn_q_norm = RMSNorm(head_dim, eps=1e-6)
self.img_attn_k_norm = RMSNorm(head_dim, eps=1e-6)
self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = MLP(hidden_size, mlp_hidden_dim, act_layer=lambda: nn.GELU(approximate="tanh"), bias=True)
# Text stream processing components
self.txt_mod = ModulateDiT(hidden_size, factor=6, act_layer=nn.SiLU)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
self.txt_attn_q_norm = RMSNorm(head_dim, eps=1e-6)
self.txt_attn_k_norm = RMSNorm(head_dim, eps=1e-6)
self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = MLP(hidden_size, mlp_hidden_dim, act_layer=lambda: nn.GELU(approximate="tanh"), bias=True)
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
self.gradient_checkpointing = True
self.cpu_offload_checkpointing = cpu_offload
def disable_gradient_checkpointing(self):
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
def _forward(
self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, attn_params: AttentionParams = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# Extract modulation parameters for image and text streams
(img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk(
6, dim=-1
)
(txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.txt_mod(vec).chunk(
6, dim=-1
)
# Process image stream for attention
img_modulated = self.img_norm1(img)
img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
del img_mod1_shift, img_mod1_scale
img_qkv = self.img_attn_qkv(img_modulated)
del img_modulated
img_q, img_k, img_v = img_qkv.chunk(3, dim=-1)
del img_qkv
img_q = rearrange(img_q, "B L (H D) -> B L H D", H=self.heads_num)
img_k = rearrange(img_k, "B L (H D) -> B L H D", H=self.heads_num)
img_v = rearrange(img_v, "B L (H D) -> B L H D", H=self.heads_num)
# Apply QK-Norm if enabled
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
# Apply rotary position embeddings to image tokens
if freqs_cis is not None:
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
del freqs_cis
# Process text stream for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
del txt_modulated
txt_q, txt_k, txt_v = txt_qkv.chunk(3, dim=-1)
del txt_qkv
txt_q = rearrange(txt_q, "B L (H D) -> B L H D", H=self.heads_num)
txt_k = rearrange(txt_k, "B L (H D) -> B L H D", H=self.heads_num)
txt_v = rearrange(txt_v, "B L (H D) -> B L H D", H=self.heads_num)
# Apply QK-Norm if enabled
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
# Concatenate image and text tokens for joint attention
img_seq_len = img.shape[1]
q = torch.cat([img_q, txt_q], dim=1)
del img_q, txt_q
k = torch.cat([img_k, txt_k], dim=1)
del img_k, txt_k
v = torch.cat([img_v, txt_v], dim=1)
del img_v, txt_v
qkv = [q, k, v]
del q, k, v
attn = attention(qkv, attn_params=attn_params)
del qkv
# Split attention outputs back to separate streams
img_attn, txt_attn = (attn[:, :img_seq_len].contiguous(), attn[:, img_seq_len:].contiguous())
del attn
# Apply attention projection and residual connection for image stream
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
del img_attn, img_mod1_gate
# Apply MLP and residual connection for image stream
img = img + apply_gate(
self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
gate=img_mod2_gate,
)
del img_mod2_shift, img_mod2_scale, img_mod2_gate
# Apply attention projection and residual connection for text stream
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
del txt_attn, txt_mod1_gate
# Apply MLP and residual connection for text stream
txt = txt + apply_gate(
self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
gate=txt_mod2_gate,
)
del txt_mod2_shift, txt_mod2_scale, txt_mod2_gate
return img, txt
def forward(
self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, attn_params: AttentionParams = None
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.gradient_checkpointing and self.training:
forward_fn = self._forward
if self.cpu_offload_checkpointing:
forward_fn = custom_offloading_utils.cpu_offload_wrapper(forward_fn, self.img_attn_qkv.weight.device)
return torch.utils.checkpoint.checkpoint(forward_fn, img, txt, vec, freqs_cis, attn_params, use_reentrant=False)
else:
return self._forward(img, txt, vec, freqs_cis, attn_params)
class MMSingleStreamBlock(nn.Module):
"""
Multimodal single-stream transformer block.
Processes concatenated image and text tokens jointly with shared attention.
Uses parallel linear layers for efficiency and applies RoPE only to image tokens.
Args:
hidden_size: Model dimension.
heads_num: Number of attention heads.
mlp_width_ratio: MLP expansion ratio.
mlp_act_type: MLP activation function (only "gelu_tanh" supported).
qk_norm: QK normalization flag (must be True).
qk_norm_type: QK normalization type (only "rms" supported).
qk_scale: Attention scaling factor (computed automatically if None).
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = "gelu_tanh",
qk_norm: bool = True,
qk_norm_type: str = "rms",
qk_scale: float = None,
):
super().__init__()
assert mlp_act_type == "gelu_tanh", "Only GELU-tanh activation supported."
assert qk_norm_type == "rms", "Only RMS normalization supported."
assert qk_norm, "QK normalization must be enabled."
self.hidden_size = hidden_size
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.mlp_hidden_dim = mlp_hidden_dim
self.scale = qk_scale or head_dim**-0.5
# Parallel linear projections for efficiency
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim)
# Combined output projection
# self.linear2 = LinearWarpforSingle(hidden_size + mlp_hidden_dim, hidden_size, bias=True) # for reference
self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, bias=True)
# QK normalization layers
self.q_norm = RMSNorm(head_dim, eps=1e-6)
self.k_norm = RMSNorm(head_dim, eps=1e-6)
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=nn.SiLU)
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
self.gradient_checkpointing = True
self.cpu_offload_checkpointing = cpu_offload
def disable_gradient_checkpointing(self):
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
def _forward(
self,
x: torch.Tensor,
vec: torch.Tensor,
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
attn_params: AttentionParams = None,
) -> torch.Tensor:
# Extract modulation parameters
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
# Compute Q, K, V, and MLP input
qkv_mlp = self.linear1(x_mod)
del x_mod
q, k, v, mlp = qkv_mlp.split([self.hidden_size, self.hidden_size, self.hidden_size, self.mlp_hidden_dim], dim=-1)
del qkv_mlp
q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
k = rearrange(k, "B L (H D) -> B L H D", H=self.heads_num)
v = rearrange(v, "B L (H D) -> B L H D", H=self.heads_num)
# Apply QK-Norm if enabled
q = self.q_norm(q).to(v)
k = self.k_norm(k).to(v)
# Separate image and text tokens
img_q, txt_q = q[:, : attn_params.img_len, :, :], q[:, attn_params.img_len :, :, :]
del q
img_k, txt_k = k[:, : attn_params.img_len, :, :], k[:, attn_params.img_len :, :, :]
del k
# Apply rotary position embeddings only to image tokens
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
del freqs_cis
# Recombine and compute joint attention
q = torch.cat([img_q, txt_q], dim=1)
del img_q, txt_q
k = torch.cat([img_k, txt_k], dim=1)
del img_k, txt_k
# v = torch.cat([img_v, txt_v], dim=1)
# del img_v, txt_v
qkv = [q, k, v]
del q, k, v
attn = attention(qkv, attn_params=attn_params)
del qkv
# Combine attention and MLP outputs, apply gating
# output = self.linear2(attn, self.mlp_act(mlp))
mlp = self.mlp_act(mlp)
output = torch.cat([attn, mlp], dim=2).contiguous()
del attn, mlp
output = self.linear2(output)
return x + apply_gate(output, gate=mod_gate)
def forward(
self,
x: torch.Tensor,
vec: torch.Tensor,
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
attn_params: AttentionParams = None,
) -> torch.Tensor:
if self.gradient_checkpointing and self.training:
forward_fn = self._forward
if self.cpu_offload_checkpointing:
forward_fn = custom_offloading_utils.create_cpu_offloading_wrapper(forward_fn, self.linear1.weight.device)
return torch.utils.checkpoint.checkpoint(forward_fn, x, vec, freqs_cis, attn_params, use_reentrant=False)
else:
return self._forward(x, vec, freqs_cis, attn_params)
# endregion

View File

@@ -0,0 +1,661 @@
import json
import re
from typing import Tuple, Optional, Union
import torch
from transformers import (
AutoTokenizer,
Qwen2_5_VLConfig,
Qwen2_5_VLForConditionalGeneration,
Qwen2Tokenizer,
T5ForConditionalGeneration,
T5Config,
T5Tokenizer,
)
from transformers.models.t5.modeling_t5 import T5Stack
from accelerate import init_empty_weights
from library.safetensors_utils import load_safetensors
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
BYT5_TOKENIZER_PATH = "google/byt5-small"
QWEN_2_5_VL_IMAGE_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
# Copy from Glyph-SDXL-V2
COLOR_IDX_JSON = """{"white": 0, "black": 1, "darkslategray": 2, "dimgray": 3, "darkolivegreen": 4, "midnightblue": 5, "saddlebrown": 6, "sienna": 7, "whitesmoke": 8, "darkslateblue": 9,
"indianred": 10, "linen": 11, "maroon": 12, "khaki": 13, "sandybrown": 14, "gray": 15, "gainsboro": 16, "teal": 17, "peru": 18, "gold": 19,
"snow": 20, "firebrick": 21, "crimson": 22, "chocolate": 23, "tomato": 24, "brown": 25, "goldenrod": 26, "antiquewhite": 27, "rosybrown": 28, "steelblue": 29,
"floralwhite": 30, "seashell": 31, "darkgreen": 32, "oldlace": 33, "darkkhaki": 34, "burlywood": 35, "red": 36, "darkgray": 37, "orange": 38, "royalblue": 39,
"seagreen": 40, "lightgray": 41, "tan": 42, "coral": 43, "beige": 44, "palevioletred": 45, "wheat": 46, "lavender": 47, "darkcyan": 48, "slateblue": 49,
"slategray": 50, "orangered": 51, "silver": 52, "olivedrab": 53, "forestgreen": 54, "darkgoldenrod": 55, "ivory": 56, "darkorange": 57, "yellow": 58, "hotpink": 59,
"ghostwhite": 60, "lightcoral": 61, "indigo": 62, "bisque": 63, "darkred": 64, "darksalmon": 65, "lightslategray": 66, "dodgerblue": 67, "lightpink": 68, "mistyrose": 69,
"mediumvioletred": 70, "cadetblue": 71, "deeppink": 72, "salmon": 73, "palegoldenrod": 74, "blanchedalmond": 75, "lightseagreen": 76, "cornflowerblue": 77, "yellowgreen": 78, "greenyellow": 79,
"navajowhite": 80, "papayawhip": 81, "mediumslateblue": 82, "purple": 83, "blueviolet": 84, "pink": 85, "cornsilk": 86, "lightsalmon": 87, "mediumpurple": 88, "moccasin": 89,
"turquoise": 90, "mediumseagreen": 91, "lavenderblush": 92, "mediumblue": 93, "darkseagreen": 94, "mediumturquoise": 95, "paleturquoise": 96, "skyblue": 97, "lemonchiffon": 98, "olive": 99,
"peachpuff": 100, "lightyellow": 101, "lightsteelblue": 102, "mediumorchid": 103, "plum": 104, "darkturquoise": 105, "aliceblue": 106, "mediumaquamarine": 107, "orchid": 108, "powderblue": 109,
"blue": 110, "darkorchid": 111, "violet": 112, "lightskyblue": 113, "lightcyan": 114, "lightgoldenrodyellow": 115, "navy": 116, "thistle": 117, "honeydew": 118, "mintcream": 119,
"lightblue": 120, "darkblue": 121, "darkmagenta": 122, "deepskyblue": 123, "magenta": 124, "limegreen": 125, "darkviolet": 126, "cyan": 127, "palegreen": 128, "aquamarine": 129,
"lawngreen": 130, "lightgreen": 131, "azure": 132, "chartreuse": 133, "green": 134, "mediumspringgreen": 135, "lime": 136, "springgreen": 137}"""
MULTILINGUAL_10_LANG_IDX_JSON = """{"en-Montserrat-Regular": 0, "en-Poppins-Italic": 1, "en-GlacialIndifference-Regular": 2, "en-OpenSans-ExtraBoldItalic": 3, "en-Montserrat-Bold": 4, "en-Now-Regular": 5, "en-Garet-Regular": 6, "en-LeagueSpartan-Bold": 7, "en-DMSans-Regular": 8, "en-OpenSauceOne-Regular": 9,
"en-OpenSans-ExtraBold": 10, "en-KGPrimaryPenmanship": 11, "en-Anton-Regular": 12, "en-Aileron-BlackItalic": 13, "en-Quicksand-Light": 14, "en-Roboto-BoldItalic": 15, "en-TheSeasons-It": 16, "en-Kollektif": 17, "en-Inter-BoldItalic": 18, "en-Poppins-Medium": 19,
"en-Poppins-Light": 20, "en-RoxboroughCF-RegularItalic": 21, "en-PlayfairDisplay-SemiBold": 22, "en-Agrandir-Italic": 23, "en-Lato-Regular": 24, "en-MoreSugarRegular": 25, "en-CanvaSans-RegularItalic": 26, "en-PublicSans-Italic": 27, "en-CodePro-NormalLC": 28, "en-Belleza-Regular": 29,
"en-JosefinSans-Bold": 30, "en-HKGrotesk-Bold": 31, "en-Telegraf-Medium": 32, "en-BrittanySignatureRegular": 33, "en-Raleway-ExtraBoldItalic": 34, "en-Mont-RegularItalic": 35, "en-Arimo-BoldItalic": 36, "en-Lora-Italic": 37, "en-ArchivoBlack-Regular": 38, "en-Poppins": 39,
"en-Barlow-Black": 40, "en-CormorantGaramond-Bold": 41, "en-LibreBaskerville-Regular": 42, "en-CanvaSchoolFontRegular": 43, "en-BebasNeueBold": 44, "en-LazydogRegular": 45, "en-FredokaOne-Regular": 46, "en-Horizon-Bold": 47, "en-Nourd-Regular": 48, "en-Hatton-Regular": 49,
"en-Nunito-ExtraBoldItalic": 50, "en-CerebriSans-Regular": 51, "en-Montserrat-Light": 52, "en-TenorSans": 53, "en-Norwester-Regular": 54, "en-ClearSans-Bold": 55, "en-Cardo-Regular": 56, "en-Alice-Regular": 57, "en-Oswald-Regular": 58, "en-Gaegu-Bold": 59,
"en-Muli-Black": 60, "en-TAN-PEARL-Regular": 61, "en-CooperHewitt-Book": 62, "en-Agrandir-Grand": 63, "en-BlackMango-Thin": 64, "en-DMSerifDisplay-Regular": 65, "en-Antonio-Bold": 66, "en-Sniglet-Regular": 67, "en-BeVietnam-Regular": 68, "en-NunitoSans10pt-BlackItalic": 69,
"en-AbhayaLibre-ExtraBold": 70, "en-Rubik-Regular": 71, "en-PPNeueMachina-Regular": 72, "en-TAN - MON CHERI-Regular": 73, "en-Jua-Regular": 74, "en-Playlist-Script": 75, "en-SourceSansPro-BoldItalic": 76, "en-MoonTime-Regular": 77, "en-Eczar-ExtraBold": 78, "en-Gatwick-Regular": 79,
"en-MonumentExtended-Regular": 80, "en-BarlowSemiCondensed-Regular": 81, "en-BarlowCondensed-Regular": 82, "en-Alegreya-Regular": 83, "en-DreamAvenue": 84, "en-RobotoCondensed-Italic": 85, "en-BobbyJones-Regular": 86, "en-Garet-ExtraBold": 87, "en-YesevaOne-Regular": 88, "en-Dosis-ExtraBold": 89,
"en-LeagueGothic-Regular": 90, "en-OpenSans-Italic": 91, "en-TANAEGEAN-Regular": 92, "en-Maharlika-Regular": 93, "en-MarykateRegular": 94, "en-Cinzel-Regular": 95, "en-Agrandir-Wide": 96, "en-Chewy-Regular": 97, "en-BodoniFLF-BoldItalic": 98, "en-Nunito-BlackItalic": 99,
"en-LilitaOne": 100, "en-HandyCasualCondensed-Regular": 101, "en-Ovo": 102, "en-Livvic-Regular": 103, "en-Agrandir-Narrow": 104, "en-CrimsonPro-Italic": 105, "en-AnonymousPro-Bold": 106, "en-NF-OneLittleFont-Bold": 107, "en-RedHatDisplay-BoldItalic": 108, "en-CodecPro-Regular": 109,
"en-HalimunRegular": 110, "en-LibreFranklin-Black": 111, "en-TeXGyreTermes-BoldItalic": 112, "en-Shrikhand-Regular": 113, "en-TTNormsPro-Italic": 114, "en-Gagalin-Regular": 115, "en-OpenSans-Bold": 116, "en-GreatVibes-Regular": 117, "en-Breathing": 118, "en-HeroLight-Regular": 119,
"en-KGPrimaryDots": 120, "en-Quicksand-Bold": 121, "en-Brice-ExtraLightSemiExpanded": 122, "en-Lato-BoldItalic": 123, "en-Fraunces9pt-Italic": 124, "en-AbrilFatface-Regular": 125, "en-BerkshireSwash-Regular": 126, "en-Atma-Bold": 127, "en-HolidayRegular": 128, "en-BebasNeueCyrillic": 129,
"en-IntroRust-Base": 130, "en-Gistesy": 131, "en-BDScript-Regular": 132, "en-ApricotsRegular": 133, "en-Prompt-Black": 134, "en-TAN MERINGUE": 135, "en-Sukar Regular": 136, "en-GentySans-Regular": 137, "en-NeueEinstellung-Normal": 138, "en-Garet-Bold": 139,
"en-FiraSans-Black": 140, "en-BantayogLight": 141, "en-NotoSerifDisplay-Black": 142, "en-TTChocolates-Regular": 143, "en-Ubuntu-Regular": 144, "en-Assistant-Bold": 145, "en-ABeeZee-Regular": 146, "en-LexendDeca-Regular": 147, "en-KingredSerif": 148, "en-Radley-Regular": 149,
"en-BrownSugar": 150, "en-MigraItalic-ExtraboldItalic": 151, "en-ChildosArabic-Regular": 152, "en-PeaceSans": 153, "en-LondrinaSolid-Black": 154, "en-SpaceMono-BoldItalic": 155, "en-RobotoMono-Light": 156, "en-CourierPrime-Regular": 157, "en-Alata-Regular": 158, "en-Amsterdam-One": 159,
"en-IreneFlorentina-Regular": 160, "en-CatchyMager": 161, "en-Alta_regular": 162, "en-ArticulatCF-Regular": 163, "en-Raleway-Regular": 164, "en-BrasikaDisplay": 165, "en-TANAngleton-Italic": 166, "en-NotoSerifDisplay-ExtraCondensedItalic": 167, "en-Bryndan Write": 168, "en-TTCommonsPro-It": 169,
"en-AlexBrush-Regular": 170, "en-Antic-Regular": 171, "en-TTHoves-Bold": 172, "en-DroidSerif": 173, "en-AblationRegular": 174, "en-Marcellus-Regular": 175, "en-Sanchez-Italic": 176, "en-JosefinSans": 177, "en-Afrah-Regular": 178, "en-PinyonScript": 179,
"en-TTInterphases-BoldItalic": 180, "en-Yellowtail-Regular": 181, "en-Gliker-Regular": 182, "en-BobbyJonesSoft-Regular": 183, "en-IBMPlexSans": 184, "en-Amsterdam-Three": 185, "en-Amsterdam-FourSlant": 186, "en-TTFors-Regular": 187, "en-Quattrocento": 188, "en-Sifonn-Basic": 189,
"en-AlegreyaSans-Black": 190, "en-Daydream": 191, "en-AristotelicaProTx-Rg": 192, "en-NotoSerif": 193, "en-EBGaramond-Italic": 194, "en-HammersmithOne-Regular": 195, "en-RobotoSlab-Regular": 196, "en-DO-Sans-Regular": 197, "en-KGPrimaryDotsLined": 198, "en-Blinker-Regular": 199,
"en-TAN NIMBUS": 200, "en-Blueberry-Regular": 201, "en-Rosario-Regular": 202, "en-Forum": 203, "en-MistrullyRegular": 204, "en-SourceSerifPro-Regular": 205, "en-Bugaki-Regular": 206, "en-CMUSerif-Roman": 207, "en-GulfsDisplay-NormalItalic": 208, "en-PTSans-Bold": 209,
"en-Sensei-Medium": 210, "en-SquadaOne-Regular": 211, "en-Arapey-Italic": 212, "en-Parisienne-Regular": 213, "en-Aleo-Italic": 214, "en-QuicheDisplay-Italic": 215, "en-RocaOne-It": 216, "en-Funtastic-Regular": 217, "en-PTSerif-BoldItalic": 218, "en-Muller-RegularItalic": 219,
"en-ArgentCF-Regular": 220, "en-Brightwall-Italic": 221, "en-Knewave-Regular": 222, "en-TYSerif-D": 223, "en-Agrandir-Tight": 224, "en-AlfaSlabOne-Regular": 225, "en-TANTangkiwood-Display": 226, "en-Kief-Montaser-Regular": 227, "en-Gotham-Book": 228, "en-JuliusSansOne-Regular": 229,
"en-CocoGothic-Italic": 230, "en-SairaCondensed-Regular": 231, "en-DellaRespira-Regular": 232, "en-Questrial-Regular": 233, "en-BukhariScript-Regular": 234, "en-HelveticaWorld-Bold": 235, "en-TANKINDRED-Display": 236, "en-CinzelDecorative-Regular": 237, "en-Vidaloka-Regular": 238, "en-AlegreyaSansSC-Black": 239,
"en-FeelingPassionate-Regular": 240, "en-QuincyCF-Regular": 241, "en-FiraCode-Regular": 242, "en-Genty-Regular": 243, "en-Nickainley-Normal": 244, "en-RubikOne-Regular": 245, "en-Gidole-Regular": 246, "en-Borsok": 247, "en-Gordita-RegularItalic": 248, "en-Scripter-Regular": 249,
"en-Buffalo-Regular": 250, "en-KleinText-Regular": 251, "en-Creepster-Regular": 252, "en-Arvo-Bold": 253, "en-GabrielSans-NormalItalic": 254, "en-Heebo-Black": 255, "en-LexendExa-Regular": 256, "en-BrixtonSansTC-Regular": 257, "en-GildaDisplay-Regular": 258, "en-ChunkFive-Roman": 259,
"en-Amaranth-BoldItalic": 260, "en-BubbleboddyNeue-Regular": 261, "en-MavenPro-Bold": 262, "en-TTDrugs-Italic": 263, "en-CyGrotesk-KeyRegular": 264, "en-VarelaRound-Regular": 265, "en-Ruda-Black": 266, "en-SafiraMarch": 267, "en-BloggerSans": 268, "en-TANHEADLINE-Regular": 269,
"en-SloopScriptPro-Regular": 270, "en-NeueMontreal-Regular": 271, "en-Schoolbell-Regular": 272, "en-SigherRegular": 273, "en-InriaSerif-Regular": 274, "en-JetBrainsMono-Regular": 275, "en-MADEEvolveSans": 276, "en-Dekko": 277, "en-Handyman-Regular": 278, "en-Aileron-BoldItalic": 279,
"en-Bright-Italic": 280, "en-Solway-Regular": 281, "en-Higuen-Regular": 282, "en-WedgesItalic": 283, "en-TANASHFORD-BOLD": 284, "en-IBMPlexMono": 285, "en-RacingSansOne-Regular": 286, "en-RegularBrush": 287, "en-OpenSans-LightItalic": 288, "en-SpecialElite-Regular": 289,
"en-FuturaLTPro-Medium": 290, "en-MaragsaDisplay": 291, "en-BigShouldersDisplay-Regular": 292, "en-BDSans-Regular": 293, "en-RasputinRegular": 294, "en-Yvesyvesdrawing-BoldItalic": 295, "en-Bitter-Regular": 296, "en-LuckiestGuy-Regular": 297, "en-CanvaSchoolFontDotted": 298, "en-TTFirsNeue-Italic": 299,
"en-Sunday-Regular": 300, "en-HKGothic-MediumItalic": 301, "en-CaveatBrush-Regular": 302, "en-HeliosExt": 303, "en-ArchitectsDaughter-Regular": 304, "en-Angelina": 305, "en-Calistoga-Regular": 306, "en-ArchivoNarrow-Regular": 307, "en-ObjectSans-MediumSlanted": 308, "en-AyrLucidityCondensed-Regular": 309,
"en-Nexa-RegularItalic": 310, "en-Lustria-Regular": 311, "en-Amsterdam-TwoSlant": 312, "en-Virtual-Regular": 313, "en-Brusher-Regular": 314, "en-NF-Lepetitcochon-Regular": 315, "en-TANTWINKLE": 316, "en-LeJour-Serif": 317, "en-Prata-Regular": 318, "en-PPWoodland-Regular": 319,
"en-PlayfairDisplay-BoldItalic": 320, "en-AmaticSC-Regular": 321, "en-Cabin-Regular": 322, "en-Manjari-Bold": 323, "en-MrDafoe-Regular": 324, "en-TTRamillas-Italic": 325, "en-Luckybones-Bold": 326, "en-DarkerGrotesque-Light": 327, "en-BellabooRegular": 328, "en-CormorantSC-Bold": 329,
"en-GochiHand-Regular": 330, "en-Atteron": 331, "en-RocaTwo-Lt": 332, "en-ZCOOLXiaoWei-Regular": 333, "en-TANSONGBIRD": 334, "en-HeadingNow-74Regular": 335, "en-Luthier-BoldItalic": 336, "en-Oregano-Regular": 337, "en-AyrTropikaIsland-Int": 338, "en-Mali-Regular": 339,
"en-DidactGothic-Regular": 340, "en-Lovelace-Regular": 341, "en-BakerieSmooth-Regular": 342, "en-CarterOne": 343, "en-HussarBd": 344, "en-OldStandard-Italic": 345, "en-TAN-ASTORIA-Display": 346, "en-rugratssans-Regular": 347, "en-BMHANNA": 348, "en-BetterSaturday": 349,
"en-AdigianaToybox": 350, "en-Sailors": 351, "en-PlayfairDisplaySC-Italic": 352, "en-Etna-Regular": 353, "en-Revive80Signature": 354, "en-CAGenerated": 355, "en-Poppins-Regular": 356, "en-Jonathan-Regular": 357, "en-Pacifico-Regular": 358, "en-Saira-Black": 359,
"en-Loubag-Regular": 360, "en-Decalotype-Black": 361, "en-Mansalva-Regular": 362, "en-Allura-Regular": 363, "en-ProximaNova-Bold": 364, "en-TANMIGNON-DISPLAY": 365, "en-ArsenicaAntiqua-Regular": 366, "en-BreulGroteskA-RegularItalic": 367, "en-HKModular-Bold": 368, "en-TANNightingale-Regular": 369,
"en-AristotelicaProCndTxt-Rg": 370, "en-Aprila-Regular": 371, "en-Tomorrow-Regular": 372, "en-AngellaWhite": 373, "en-KaushanScript-Regular": 374, "en-NotoSans": 375, "en-LeJour-Script": 376, "en-BrixtonTC-Regular": 377, "en-OleoScript-Regular": 378, "en-Cakerolli-Regular": 379,
"en-Lobster-Regular": 380, "en-FrunchySerif-Regular": 381, "en-PorcelainRegular": 382, "en-AlojaExtended": 383, "en-SergioTrendy-Italic": 384, "en-LovelaceText-Bold": 385, "en-Anaktoria": 386, "en-JimmyScript-Light": 387, "en-IBMPlexSerif": 388, "en-Marta": 389,
"en-Mango-Regular": 390, "en-Overpass-Italic": 391, "en-Hagrid-Regular": 392, "en-ElikaGorica": 393, "en-Amiko-Regular": 394, "en-EFCOBrookshire-Regular": 395, "en-Caladea-Regular": 396, "en-MoonlightBold": 397, "en-Staatliches-Regular": 398, "en-Helios-Bold": 399,
"en-Satisfy-Regular": 400, "en-NexaScript-Regular": 401, "en-Trocchi-Regular": 402, "en-March": 403, "en-IbarraRealNova-Regular": 404, "en-Nectarine-Regular": 405, "en-Overpass-Light": 406, "en-TruetypewriterPolyglOTT": 407, "en-Bangers-Regular": 408, "en-Lazord-BoldExpandedItalic": 409,
"en-Chloe-Regular": 410, "en-BaskervilleDisplayPT-Regular": 411, "en-Bright-Regular": 412, "en-Vollkorn-Regular": 413, "en-Harmattan": 414, "en-SortsMillGoudy-Regular": 415, "en-Biryani-Bold": 416, "en-SugoProDisplay-Italic": 417, "en-Lazord-BoldItalic": 418, "en-Alike-Regular": 419,
"en-PermanentMarker-Regular": 420, "en-Sacramento-Regular": 421, "en-HKGroteskPro-Italic": 422, "en-Aleo-BoldItalic": 423, "en-Noot": 424, "en-TANGARLAND-Regular": 425, "en-Twister": 426, "en-Arsenal-Italic": 427, "en-Bogart-Italic": 428, "en-BethEllen-Regular": 429,
"en-Caveat-Regular": 430, "en-BalsamiqSans-Bold": 431, "en-BreeSerif-Regular": 432, "en-CodecPro-ExtraBold": 433, "en-Pierson-Light": 434, "en-CyGrotesk-WideRegular": 435, "en-Lumios-Marker": 436, "en-Comfortaa-Bold": 437, "en-TraceFontRegular": 438, "en-RTL-AdamScript-Regular": 439,
"en-EastmanGrotesque-Italic": 440, "en-Kalam-Bold": 441, "en-ChauPhilomeneOne-Regular": 442, "en-Coiny-Regular": 443, "en-Lovera": 444, "en-Gellatio": 445, "en-TitilliumWeb-Bold": 446, "en-OilvareBase-Italic": 447, "en-Catamaran-Black": 448, "en-Anteb-Italic": 449,
"en-SueEllenFrancisco": 450, "en-SweetApricot": 451, "en-BrightSunshine": 452, "en-IM_FELL_Double_Pica_Italic": 453, "en-Granaina-limpia": 454, "en-TANPARFAIT": 455, "en-AcherusGrotesque-Regular": 456, "en-AwesomeLathusca-Italic": 457, "en-Signika-Bold": 458, "en-Andasia": 459,
"en-DO-AllCaps-Slanted": 460, "en-Zenaida-Regular": 461, "en-Fahkwang-Regular": 462, "en-Play-Regular": 463, "en-BERNIERRegular-Regular": 464, "en-PlumaThin-Regular": 465, "en-SportsWorld": 466, "en-Garet-Black": 467, "en-CarolloPlayscript-BlackItalic": 468, "en-Cheque-Regular": 469,
"en-SEGO": 470, "en-BobbyJones-Condensed": 471, "en-NexaSlab-RegularItalic": 472, "en-DancingScript-Regular": 473, "en-PaalalabasDisplayWideBETA": 474, "en-Magnolia-Script": 475, "en-OpunMai-400It": 476, "en-MadelynFill-Regular": 477, "en-ZingRust-Base": 478, "en-FingerPaint-Regular": 479,
"en-BostonAngel-Light": 480, "en-Gliker-RegularExpanded": 481, "en-Ahsing": 482, "en-Engagement-Regular": 483, "en-EyesomeScript": 484, "en-LibraSerifModern-Regular": 485, "en-London-Regular": 486, "en-AtkinsonHyperlegible-Regular": 487, "en-StadioNow-TextItalic": 488, "en-Aniyah": 489,
"en-ITCAvantGardePro-Bold": 490, "en-Comica-Regular": 491, "en-Coustard-Regular": 492, "en-Brice-BoldCondensed": 493, "en-TANNEWYORK-Bold": 494, "en-TANBUSTER-Bold": 495, "en-Alatsi-Regular": 496, "en-TYSerif-Book": 497, "en-Jingleberry": 498, "en-Rajdhani-Bold": 499,
"en-LobsterTwo-BoldItalic": 500, "en-BestLight-Medium": 501, "en-Hitchcut-Regular": 502, "en-GermaniaOne-Regular": 503, "en-Emitha-Script": 504, "en-LemonTuesday": 505, "en-Cubao_Free_Regular": 506, "en-MonterchiSerif-Regular": 507, "en-AllertaStencil-Regular": 508, "en-RTL-Sondos-Regular": 509,
"en-HomemadeApple-Regular": 510, "en-CosmicOcto-Medium": 511, "cn-HelloFont-FangHuaTi": 0, "cn-HelloFont-ID-DianFangSong-Bold": 1, "cn-HelloFont-ID-DianFangSong": 2, "cn-HelloFont-ID-DianHei-CEJ": 3, "cn-HelloFont-ID-DianHei-DEJ": 4, "cn-HelloFont-ID-DianHei-EEJ": 5, "cn-HelloFont-ID-DianHei-FEJ": 6, "cn-HelloFont-ID-DianHei-GEJ": 7, "cn-HelloFont-ID-DianKai-Bold": 8, "cn-HelloFont-ID-DianKai": 9,
"cn-HelloFont-WenYiHei": 10, "cn-Hellofont-ID-ChenYanXingKai": 11, "cn-Hellofont-ID-DaZiBao": 12, "cn-Hellofont-ID-DaoCaoRen": 13, "cn-Hellofont-ID-JianSong": 14, "cn-Hellofont-ID-JiangHuZhaoPaiHei": 15, "cn-Hellofont-ID-KeSong": 16, "cn-Hellofont-ID-LeYuanTi": 17, "cn-Hellofont-ID-Pinocchio": 18, "cn-Hellofont-ID-QiMiaoTi": 19,
"cn-Hellofont-ID-QingHuaKai": 20, "cn-Hellofont-ID-QingHuaXingKai": 21, "cn-Hellofont-ID-ShanShuiXingKai": 22, "cn-Hellofont-ID-ShouXieQiShu": 23, "cn-Hellofont-ID-ShouXieTongZhenTi": 24, "cn-Hellofont-ID-TengLingTi": 25, "cn-Hellofont-ID-XiaoLiShu": 26, "cn-Hellofont-ID-XuanZhenSong": 27, "cn-Hellofont-ID-ZhongLingXingKai": 28, "cn-HellofontIDJiaoTangTi": 29,
"cn-HellofontIDJiuZhuTi": 30, "cn-HuXiaoBao-SaoBao": 31, "cn-HuXiaoBo-NanShen": 32, "cn-HuXiaoBo-ZhenShuai": 33, "cn-SourceHanSansSC-Bold": 34, "cn-SourceHanSansSC-ExtraLight": 35, "cn-SourceHanSansSC-Heavy": 36, "cn-SourceHanSansSC-Light": 37, "cn-SourceHanSansSC-Medium": 38, "cn-SourceHanSansSC-Normal": 39,
"cn-SourceHanSansSC-Regular": 40, "cn-SourceHanSerifSC-Bold": 41, "cn-SourceHanSerifSC-ExtraLight": 42, "cn-SourceHanSerifSC-Heavy": 43, "cn-SourceHanSerifSC-Light": 44, "cn-SourceHanSerifSC-Medium": 45, "cn-SourceHanSerifSC-Regular": 46, "cn-SourceHanSerifSC-SemiBold": 47, "cn-xiaowei": 48, "cn-AaJianHaoTi": 49,
"cn-AlibabaPuHuiTi-Bold": 50, "cn-AlibabaPuHuiTi-Heavy": 51, "cn-AlibabaPuHuiTi-Light": 52, "cn-AlibabaPuHuiTi-Medium": 53, "cn-AlibabaPuHuiTi-Regular": 54, "cn-CanvaAcidBoldSC": 55, "cn-CanvaBreezeCN": 56, "cn-CanvaBumperCropSC": 57, "cn-CanvaCakeShopCN": 58, "cn-CanvaEndeavorBlackSC": 59,
"cn-CanvaJoyHeiCN": 60, "cn-CanvaLiCN": 61, "cn-CanvaOrientalBrushCN": 62, "cn-CanvaPoster": 63, "cn-CanvaQinfuCalligraphyCN": 64, "cn-CanvaSweetHeartCN": 65, "cn-CanvaSwordLikeDreamCN": 66, "cn-CanvaTangyuanHandwritingCN": 67, "cn-CanvaWanderWorldCN": 68, "cn-CanvaWenCN": 69,
"cn-DianZiChunYi": 70, "cn-GenSekiGothicTW-H": 71, "cn-GenWanMinTW-L": 72, "cn-GenYoMinTW-B": 73, "cn-GenYoMinTW-EL": 74, "cn-GenYoMinTW-H": 75, "cn-GenYoMinTW-M": 76, "cn-GenYoMinTW-R": 77, "cn-GenYoMinTW-SB": 78, "cn-HYQiHei-AZEJ": 79,
"cn-HYQiHei-EES": 80, "cn-HanaMinA": 81, "cn-HappyZcool-2016": 82, "cn-HelloFont ZJ KeKouKeAiTi": 83, "cn-HelloFont-ID-BoBoTi": 84, "cn-HelloFont-ID-FuGuHei-25": 85, "cn-HelloFont-ID-FuGuHei-35": 86, "cn-HelloFont-ID-FuGuHei-45": 87, "cn-HelloFont-ID-FuGuHei-55": 88, "cn-HelloFont-ID-FuGuHei-65": 89,
"cn-HelloFont-ID-FuGuHei-75": 90, "cn-HelloFont-ID-FuGuHei-85": 91, "cn-HelloFont-ID-HeiKa": 92, "cn-HelloFont-ID-HeiTang": 93, "cn-HelloFont-ID-JianSong-95": 94, "cn-HelloFont-ID-JueJiangHei-50": 95, "cn-HelloFont-ID-JueJiangHei-55": 96, "cn-HelloFont-ID-JueJiangHei-60": 97, "cn-HelloFont-ID-JueJiangHei-65": 98, "cn-HelloFont-ID-JueJiangHei-70": 99,
"cn-HelloFont-ID-JueJiangHei-75": 100, "cn-HelloFont-ID-JueJiangHei-80": 101, "cn-HelloFont-ID-KuHeiTi": 102, "cn-HelloFont-ID-LingDongTi": 103, "cn-HelloFont-ID-LingLiTi": 104, "cn-HelloFont-ID-MuFengTi": 105, "cn-HelloFont-ID-NaiNaiJiangTi": 106, "cn-HelloFont-ID-PangDu": 107, "cn-HelloFont-ID-ReLieTi": 108, "cn-HelloFont-ID-RouRun": 109,
"cn-HelloFont-ID-SaShuangShouXieTi": 110, "cn-HelloFont-ID-WangZheFengFan": 111, "cn-HelloFont-ID-YouQiTi": 112, "cn-Hellofont-ID-XiaLeTi": 113, "cn-Hellofont-ID-XianXiaTi": 114, "cn-HuXiaoBoKuHei": 115, "cn-IDDanMoXingKai": 116, "cn-IDJueJiangHei": 117, "cn-IDMeiLingTi": 118, "cn-IDQQSugar": 119,
"cn-LiuJianMaoCao-Regular": 120, "cn-LongCang-Regular": 121, "cn-MaShanZheng-Regular": 122, "cn-PangMenZhengDao-3": 123, "cn-PangMenZhengDao-Cu": 124, "cn-PangMenZhengDao": 125, "cn-SentyCaramel": 126, "cn-SourceHanSerifSC": 127, "cn-WenCang-Regular": 128, "cn-WenQuanYiMicroHei": 129,
"cn-XianErTi": 130, "cn-YRDZSTJF": 131, "cn-YS-HelloFont-BangBangTi": 132, "cn-ZCOOLKuaiLe-Regular": 133, "cn-ZCOOLQingKeHuangYou-Regular": 134, "cn-ZCOOLXiaoWei-Regular": 135, "cn-ZCOOL_KuHei": 136, "cn-ZhiMangXing-Regular": 137, "cn-baotuxiaobaiti": 138, "cn-jiangxizhuokai-Regular": 139,
"cn-zcool-gdh": 140, "cn-zcoolqingkehuangyouti-Regular": 141, "cn-zcoolwenyiti": 142, "jp-04KanjyukuGothic": 0, "jp-07LightNovelPOP": 1, "jp-07NikumaruFont": 2, "jp-07YasashisaAntique": 3, "jp-07YasashisaGothic": 4, "jp-BokutachinoGothic2Bold": 5, "jp-BokutachinoGothic2Regular": 6, "jp-CHI_SpeedyRight_full_211128-Regular": 7, "jp-CHI_SpeedyRight_italic_full_211127-Regular": 8, "jp-CP-Font": 9,
"jp-Canva_CezanneProN-B": 10, "jp-Canva_CezanneProN-M": 11, "jp-Canva_ChiaroStd-B": 12, "jp-Canva_CometStd-B": 13, "jp-Canva_DotMincho16Std-M": 14, "jp-Canva_GrecoStd-B": 15, "jp-Canva_GrecoStd-M": 16, "jp-Canva_LyraStd-DB": 17, "jp-Canva_MatisseHatsuhiPro-B": 18, "jp-Canva_MatisseHatsuhiPro-M": 19,
"jp-Canva_ModeMinAStd-B": 20, "jp-Canva_NewCezanneProN-B": 21, "jp-Canva_NewCezanneProN-M": 22, "jp-Canva_PearlStd-L": 23, "jp-Canva_RaglanStd-UB": 24, "jp-Canva_RailwayStd-B": 25, "jp-Canva_ReggaeStd-B": 26, "jp-Canva_RocknRollStd-DB": 27, "jp-Canva_RodinCattleyaPro-B": 28, "jp-Canva_RodinCattleyaPro-M": 29,
"jp-Canva_RodinCattleyaPro-UB": 30, "jp-Canva_RodinHimawariPro-B": 31, "jp-Canva_RodinHimawariPro-M": 32, "jp-Canva_RodinMariaPro-B": 33, "jp-Canva_RodinMariaPro-DB": 34, "jp-Canva_RodinProN-M": 35, "jp-Canva_ShadowTLStd-B": 36, "jp-Canva_StickStd-B": 37, "jp-Canva_TsukuAOldMinPr6N-B": 38, "jp-Canva_TsukuAOldMinPr6N-R": 39,
"jp-Canva_UtrilloPro-DB": 40, "jp-Canva_UtrilloPro-M": 41, "jp-Canva_YurukaStd-UB": 42, "jp-FGUIGEN": 43, "jp-GlowSansJ-Condensed-Heavy": 44, "jp-GlowSansJ-Condensed-Light": 45, "jp-GlowSansJ-Normal-Bold": 46, "jp-GlowSansJ-Normal-Light": 47, "jp-HannariMincho": 48, "jp-HarenosoraMincho": 49,
"jp-Jiyucho": 50, "jp-Kaiso-Makina-B": 51, "jp-Kaisotai-Next-UP-B": 52, "jp-KokoroMinchoutai": 53, "jp-Mamelon-3-Hi-Regular": 54, "jp-MotoyaAnemoneStd-W1": 55, "jp-MotoyaAnemoneStd-W5": 56, "jp-MotoyaAnticPro-W3": 57, "jp-MotoyaCedarStd-W3": 58, "jp-MotoyaCedarStd-W5": 59,
"jp-MotoyaGochikaStd-W4": 60, "jp-MotoyaGochikaStd-W8": 61, "jp-MotoyaGothicMiyabiStd-W6": 62, "jp-MotoyaGothicStd-W3": 63, "jp-MotoyaGothicStd-W5": 64, "jp-MotoyaKoinStd-W3": 65, "jp-MotoyaKyotaiStd-W2": 66, "jp-MotoyaKyotaiStd-W4": 67, "jp-MotoyaMaruStd-W3": 68, "jp-MotoyaMaruStd-W5": 69,
"jp-MotoyaMinchoMiyabiStd-W4": 70, "jp-MotoyaMinchoMiyabiStd-W6": 71, "jp-MotoyaMinchoModernStd-W4": 72, "jp-MotoyaMinchoModernStd-W6": 73, "jp-MotoyaMinchoStd-W3": 74, "jp-MotoyaMinchoStd-W5": 75, "jp-MotoyaReisyoStd-W2": 76, "jp-MotoyaReisyoStd-W6": 77, "jp-MotoyaTohitsuStd-W4": 78, "jp-MotoyaTohitsuStd-W6": 79,
"jp-MtySousyokuEmBcJis-W6": 80, "jp-MtySousyokuLiBcJis-W6": 81, "jp-Mushin": 82, "jp-NotoSansJP-Bold": 83, "jp-NotoSansJP-Regular": 84, "jp-NudMotoyaAporoStd-W3": 85, "jp-NudMotoyaAporoStd-W5": 86, "jp-NudMotoyaCedarStd-W3": 87, "jp-NudMotoyaCedarStd-W5": 88, "jp-NudMotoyaMaruStd-W3": 89,
"jp-NudMotoyaMaruStd-W5": 90, "jp-NudMotoyaMinchoStd-W5": 91, "jp-Ounen-mouhitsu": 92, "jp-Ronde-B-Square": 93, "jp-SMotoyaGyosyoStd-W5": 94, "jp-SMotoyaSinkaiStd-W3": 95, "jp-SMotoyaSinkaiStd-W5": 96, "jp-SourceHanSansJP-Bold": 97, "jp-SourceHanSansJP-Regular": 98, "jp-SourceHanSerifJP-Bold": 99,
"jp-SourceHanSerifJP-Regular": 100, "jp-TazuganeGothicStdN-Bold": 101, "jp-TazuganeGothicStdN-Regular": 102, "jp-TelopMinProN-B": 103, "jp-Togalite-Bold": 104, "jp-Togalite-Regular": 105, "jp-TsukuMinPr6N-E": 106, "jp-TsukuMinPr6N-M": 107, "jp-mikachan_o": 108, "jp-nagayama_kai": 109,
"jp-07LogoTypeGothic7": 110, "jp-07TetsubinGothic": 111, "jp-851CHIKARA-DZUYOKU-KANA-A": 112, "jp-ARMinchoJIS-Light": 113, "jp-ARMinchoJIS-Ultra": 114, "jp-ARPCrystalMinchoJIS-Medium": 115, "jp-ARPCrystalRGothicJIS-Medium": 116, "jp-ARShounanShinpitsuGyosyoJIS-Medium": 117, "jp-AozoraMincho-bold": 118, "jp-AozoraMinchoRegular": 119,
"jp-ArialUnicodeMS-Bold": 120, "jp-ArialUnicodeMS": 121, "jp-CanvaBreezeJP": 122, "jp-CanvaLiCN": 123, "jp-CanvaLiJP": 124, "jp-CanvaOrientalBrushCN": 125, "jp-CanvaQinfuCalligraphyJP": 126, "jp-CanvaSweetHeartJP": 127, "jp-CanvaWenJP": 128, "jp-Corporate-Logo-Bold": 129,
"jp-DelaGothicOne-Regular": 130, "jp-GN-Kin-iro_SansSerif": 131, "jp-GN-Koharuiro_Sunray": 132, "jp-GenEiGothicM-B": 133, "jp-GenEiGothicM-R": 134, "jp-GenJyuuGothic-Bold": 135, "jp-GenRyuMinTW-B": 136, "jp-GenRyuMinTW-R": 137, "jp-GenSekiGothicTW-B": 138, "jp-GenSekiGothicTW-R": 139,
"jp-GenSenRoundedTW-B": 140, "jp-GenSenRoundedTW-R": 141, "jp-GenShinGothic-Bold": 142, "jp-GenShinGothic-Normal": 143, "jp-GenWanMinTW-L": 144, "jp-GenYoGothicTW-B": 145, "jp-GenYoGothicTW-R": 146, "jp-GenYoMinTW-B": 147, "jp-GenYoMinTW-R": 148, "jp-HGBouquet": 149,
"jp-HanaMinA": 150, "jp-HanazomeFont": 151, "jp-HinaMincho-Regular": 152, "jp-Honoka-Antique-Maru": 153, "jp-Honoka-Mincho": 154, "jp-HuiFontP": 155, "jp-IPAexMincho": 156, "jp-JK-Gothic-L": 157, "jp-JK-Gothic-M": 158, "jp-JackeyFont": 159,
"jp-KaiseiTokumin-Bold": 160, "jp-KaiseiTokumin-Regular": 161, "jp-Keifont": 162, "jp-KiwiMaru-Regular": 163, "jp-Koku-Mincho-Regular": 164, "jp-MotoyaLMaru-W3-90ms-RKSJ-H": 165, "jp-NewTegomin-Regular": 166, "jp-NicoKaku": 167, "jp-NicoMoji+": 168, "jp-Otsutome_font-Bold": 169,
"jp-PottaOne-Regular": 170, "jp-RampartOne-Regular": 171, "jp-Senobi-Gothic-Bold": 172, "jp-Senobi-Gothic-Regular": 173, "jp-SmartFontUI-Proportional": 174, "jp-SoukouMincho": 175, "jp-TEST_Klee-DB": 176, "jp-TEST_Klee-M": 177, "jp-TEST_UDMincho-B": 178, "jp-TEST_UDMincho-L": 179,
"jp-TT_Akakane-EB": 180, "jp-Tanuki-Permanent-Marker": 181, "jp-TrainOne-Regular": 182, "jp-TsunagiGothic-Black": 183, "jp-Ume-Hy-Gothic": 184, "jp-Ume-P-Mincho": 185, "jp-WenQuanYiMicroHei": 186, "jp-XANO-mincho-U32": 187, "jp-YOzFontM90-Regular": 188, "jp-Yomogi-Regular": 189,
"jp-YujiBoku-Regular": 190, "jp-YujiSyuku-Regular": 191, "jp-ZenKakuGothicNew-Bold": 192, "jp-ZenKakuGothicNew-Regular": 193, "jp-ZenKurenaido-Regular": 194, "jp-ZenMaruGothic-Bold": 195, "jp-ZenMaruGothic-Regular": 196, "jp-darts-font": 197, "jp-irohakakuC-Bold": 198, "jp-irohakakuC-Medium": 199,
"jp-irohakakuC-Regular": 200, "jp-katyou": 201, "jp-mplus-1m-bold": 202, "jp-mplus-1m-regular": 203, "jp-mplus-1p-bold": 204, "jp-mplus-1p-regular": 205, "jp-rounded-mplus-1p-bold": 206, "jp-rounded-mplus-1p-regular": 207, "jp-timemachine-wa": 208, "jp-ttf-GenEiLateMin-Medium": 209,
"jp-uzura_font": 210, "kr-Arita-buri-Bold_OTF": 0, "kr-Arita-buri-HairLine_OTF": 1, "kr-Arita-buri-Light_OTF": 2, "kr-Arita-buri-Medium_OTF": 3, "kr-Arita-buri-SemiBold_OTF": 4, "kr-Canva_YDSunshineL": 5, "kr-Canva_YDSunshineM": 6, "kr-Canva_YoonGulimPro710": 7, "kr-Canva_YoonGulimPro730": 8, "kr-Canva_YoonGulimPro740": 9,
"kr-Canva_YoonGulimPro760": 10, "kr-Canva_YoonGulimPro770": 11, "kr-Canva_YoonGulimPro790": 12, "kr-CreHappB": 13, "kr-CreHappL": 14, "kr-CreHappM": 15, "kr-CreHappS": 16, "kr-OTAuroraB": 17, "kr-OTAuroraL": 18, "kr-OTAuroraR": 19,
"kr-OTDoldamgilB": 20, "kr-OTDoldamgilL": 21, "kr-OTDoldamgilR": 22, "kr-OTHamsterB": 23, "kr-OTHamsterL": 24, "kr-OTHamsterR": 25, "kr-OTHapchangdanB": 26, "kr-OTHapchangdanL": 27, "kr-OTHapchangdanR": 28, "kr-OTSupersizeBkBOX": 29,
"kr-SourceHanSansKR-Bold": 30, "kr-SourceHanSansKR-ExtraLight": 31, "kr-SourceHanSansKR-Heavy": 32, "kr-SourceHanSansKR-Light": 33, "kr-SourceHanSansKR-Medium": 34, "kr-SourceHanSansKR-Normal": 35, "kr-SourceHanSansKR-Regular": 36, "kr-SourceHanSansSC-Bold": 37, "kr-SourceHanSansSC-ExtraLight": 38, "kr-SourceHanSansSC-Heavy": 39,
"kr-SourceHanSansSC-Light": 40, "kr-SourceHanSansSC-Medium": 41, "kr-SourceHanSansSC-Normal": 42, "kr-SourceHanSansSC-Regular": 43, "kr-SourceHanSerifSC-Bold": 44, "kr-SourceHanSerifSC-SemiBold": 45, "kr-TDTDBubbleBubbleOTF": 46, "kr-TDTDConfusionOTF": 47, "kr-TDTDCuteAndCuteOTF": 48, "kr-TDTDEggTakOTF": 49,
"kr-TDTDEmotionalLetterOTF": 50, "kr-TDTDGalapagosOTF": 51, "kr-TDTDHappyHourOTF": 52, "kr-TDTDLatteOTF": 53, "kr-TDTDMoonLightOTF": 54, "kr-TDTDParkForestOTF": 55, "kr-TDTDPencilOTF": 56, "kr-TDTDSmileOTF": 57, "kr-TDTDSproutOTF": 58, "kr-TDTDSunshineOTF": 59,
"kr-TDTDWaferOTF": 60, "kr-777Chyaochyureu": 61, "kr-ArialUnicodeMS-Bold": 62, "kr-ArialUnicodeMS": 63, "kr-BMHANNA": 64, "kr-Baekmuk-Dotum": 65, "kr-BagelFatOne-Regular": 66, "kr-CoreBandi": 67, "kr-CoreBandiFace": 68, "kr-CoreBori": 69,
"kr-DoHyeon-Regular": 70, "kr-Dokdo-Regular": 71, "kr-Gaegu-Bold": 72, "kr-Gaegu-Light": 73, "kr-Gaegu-Regular": 74, "kr-GamjaFlower-Regular": 75, "kr-GasoekOne-Regular": 76, "kr-GothicA1-Black": 77, "kr-GothicA1-Bold": 78, "kr-GothicA1-ExtraBold": 79,
"kr-GothicA1-ExtraLight": 80, "kr-GothicA1-Light": 81, "kr-GothicA1-Medium": 82, "kr-GothicA1-Regular": 83, "kr-GothicA1-SemiBold": 84, "kr-GothicA1-Thin": 85, "kr-Gugi-Regular": 86, "kr-HiMelody-Regular": 87, "kr-Jua-Regular": 88, "kr-KirangHaerang-Regular": 89,
"kr-NanumBrush": 90, "kr-NanumPen": 91, "kr-NanumSquareRoundB": 92, "kr-NanumSquareRoundEB": 93, "kr-NanumSquareRoundL": 94, "kr-NanumSquareRoundR": 95, "kr-SeH-CB": 96, "kr-SeH-CBL": 97, "kr-SeH-CEB": 98, "kr-SeH-CL": 99,
"kr-SeH-CM": 100, "kr-SeN-CB": 101, "kr-SeN-CBL": 102, "kr-SeN-CEB": 103, "kr-SeN-CL": 104, "kr-SeN-CM": 105, "kr-Sunflower-Bold": 106, "kr-Sunflower-Light": 107, "kr-Sunflower-Medium": 108, "kr-TTClaytoyR": 109,
"kr-TTDalpangiR": 110, "kr-TTMamablockR": 111, "kr-TTNauidongmuR": 112, "kr-TTOktapbangR": 113, "kr-UhBeeMiMi": 114, "kr-UhBeeMiMiBold": 115, "kr-UhBeeSe_hyun": 116, "kr-UhBeeSe_hyunBold": 117, "kr-UhBeenamsoyoung": 118, "kr-UhBeenamsoyoungBold": 119,
"kr-WenQuanYiMicroHei": 120, "kr-YeonSung-Regular": 121}"""
def add_special_token(tokenizer: T5Tokenizer, text_encoder: T5Stack):
"""
Add special tokens for color and font to tokenizer and text encoder.
Args:
tokenizer: Huggingface tokenizer.
text_encoder: Huggingface T5 encoder.
"""
idx_font_dict = json.loads(MULTILINGUAL_10_LANG_IDX_JSON)
idx_color_dict = json.loads(COLOR_IDX_JSON)
font_token = [f"<{font_code[:2]}-font-{idx_font_dict[font_code]}>" for font_code in idx_font_dict]
color_token = [f"<color-{i}>" for i in range(len(idx_color_dict))]
additional_special_tokens = []
additional_special_tokens += color_token
additional_special_tokens += font_token
tokenizer.add_tokens(additional_special_tokens, special_tokens=True)
# Set mean_resizing=False to avoid PyTorch LAPACK dependency
text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False)
def load_byt5(
ckpt_path: str,
dtype: Optional[torch.dtype],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[dict] = None,
) -> Tuple[T5Stack, T5Tokenizer]:
BYT5_CONFIG_JSON = """
{
"_name_or_path": "/home/patrick/t5/byt5-small",
"architectures": [
"T5ForConditionalGeneration"
],
"d_ff": 3584,
"d_kv": 64,
"d_model": 1472,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"feed_forward_proj": "gated-gelu",
"gradient_checkpointing": false,
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 4,
"num_heads": 6,
"num_layers": 12,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"tokenizer_class": "ByT5Tokenizer",
"transformers_version": "4.7.0.dev0",
"use_cache": true,
"vocab_size": 384
}
"""
logger.info(f"Loading BYT5 tokenizer from {BYT5_TOKENIZER_PATH}")
byt5_tokenizer = AutoTokenizer.from_pretrained(BYT5_TOKENIZER_PATH)
logger.info("Initializing BYT5 text encoder")
config = json.loads(BYT5_CONFIG_JSON)
config = T5Config(**config)
with init_empty_weights():
byt5_text_encoder = T5ForConditionalGeneration._from_config(config).get_encoder()
add_special_token(byt5_tokenizer, byt5_text_encoder)
if state_dict is not None:
sd = state_dict
else:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device, disable_mmap=disable_mmap, dtype=dtype)
# remove "encoder." prefix
sd = {k[len("encoder.") :] if k.startswith("encoder.") else k: v for k, v in sd.items()}
sd["embed_tokens.weight"] = sd.pop("shared.weight")
info = byt5_text_encoder.load_state_dict(sd, strict=True, assign=True)
byt5_text_encoder.to(device)
byt5_text_encoder.eval()
logger.info(f"BYT5 text encoder loaded with info: {info}")
return byt5_tokenizer, byt5_text_encoder
def load_qwen2_5_vl(
ckpt_path: str,
dtype: Optional[torch.dtype],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[dict] = None,
) -> tuple[Qwen2Tokenizer, Qwen2_5_VLForConditionalGeneration]:
QWEN2_5_VL_CONFIG_JSON = """
{
"architectures": [
"Qwen2_5_VLForConditionalGeneration"
],
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"hidden_act": "silu",
"hidden_size": 3584,
"image_token_id": 151655,
"initializer_range": 0.02,
"intermediate_size": 18944,
"max_position_embeddings": 128000,
"max_window_layers": 28,
"model_type": "qwen2_5_vl",
"num_attention_heads": 28,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"rms_norm_eps": 1e-06,
"rope_scaling": {
"mrope_section": [
16,
24,
24
],
"rope_type": "default",
"type": "default"
},
"rope_theta": 1000000.0,
"sliding_window": 32768,
"text_config": {
"architectures": [
"Qwen2_5_VLForConditionalGeneration"
],
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"hidden_act": "silu",
"hidden_size": 3584,
"image_token_id": null,
"initializer_range": 0.02,
"intermediate_size": 18944,
"layer_types": [
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention"
],
"max_position_embeddings": 128000,
"max_window_layers": 28,
"model_type": "qwen2_5_vl_text",
"num_attention_heads": 28,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"rms_norm_eps": 1e-06,
"rope_scaling": {
"mrope_section": [
16,
24,
24
],
"rope_type": "default",
"type": "default"
},
"rope_theta": 1000000.0,
"sliding_window": null,
"torch_dtype": "float32",
"use_cache": true,
"use_sliding_window": false,
"video_token_id": null,
"vision_end_token_id": 151653,
"vision_start_token_id": 151652,
"vision_token_id": 151654,
"vocab_size": 152064
},
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.53.1",
"use_cache": true,
"use_sliding_window": false,
"video_token_id": 151656,
"vision_config": {
"depth": 32,
"fullatt_block_indexes": [
7,
15,
23,
31
],
"hidden_act": "silu",
"hidden_size": 1280,
"in_channels": 3,
"in_chans": 3,
"initializer_range": 0.02,
"intermediate_size": 3420,
"model_type": "qwen2_5_vl",
"num_heads": 16,
"out_hidden_size": 3584,
"patch_size": 14,
"spatial_merge_size": 2,
"spatial_patch_size": 14,
"temporal_patch_size": 2,
"tokens_per_second": 2,
"torch_dtype": "float32",
"window_size": 112
},
"vision_end_token_id": 151653,
"vision_start_token_id": 151652,
"vision_token_id": 151654,
"vocab_size": 152064
}
"""
config = json.loads(QWEN2_5_VL_CONFIG_JSON)
config = Qwen2_5_VLConfig(**config)
with init_empty_weights():
qwen2_5_vl = Qwen2_5_VLForConditionalGeneration._from_config(config)
if state_dict is not None:
sd = state_dict
else:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device, disable_mmap=disable_mmap, dtype=dtype)
# convert prefixes
for key in list(sd.keys()):
if key.startswith("model."):
new_key = key.replace("model.", "model.language_model.", 1)
elif key.startswith("visual."):
new_key = key.replace("visual.", "model.visual.", 1)
else:
continue
if key not in sd:
logger.warning(f"Key {key} not found in state dict, skipping.")
continue
sd[new_key] = sd.pop(key)
info = qwen2_5_vl.load_state_dict(sd, strict=True, assign=True)
logger.info(f"Loaded Qwen2.5-VL: {info}")
qwen2_5_vl.to(device)
qwen2_5_vl.eval()
if dtype is not None:
if dtype.itemsize == 1: # fp8
org_dtype = torch.bfloat16 # model weight is fp8 in loading, but original dtype is bfloat16
logger.info(f"prepare Qwen2.5-VL for fp8: set to {dtype} from {org_dtype}")
qwen2_5_vl.to(dtype)
# prepare LLM for fp8
def prepare_fp8(vl_model: Qwen2_5_VLForConditionalGeneration, target_dtype):
def forward_hook(module):
def forward(hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
# return module.weight.to(input_dtype) * hidden_states.to(input_dtype)
return (module.weight.to(torch.float32) * hidden_states.to(torch.float32)).to(input_dtype)
return forward
def decoder_forward_hook(module):
def forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = module.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights = module.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
input_dtype = hidden_states.dtype
hidden_states = residual.to(torch.float32) + hidden_states.to(torch.float32)
hidden_states = hidden_states.to(input_dtype)
# Fully Connected
residual = hidden_states
hidden_states = module.post_attention_layernorm(hidden_states)
hidden_states = module.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
return forward
for module in vl_model.modules():
if module.__class__.__name__ in ["Embedding"]:
# print("set", module.__class__.__name__, "to", target_dtype)
module.to(target_dtype)
if module.__class__.__name__ in ["Qwen2RMSNorm"]:
# print("set", module.__class__.__name__, "hooks")
module.forward = forward_hook(module)
if module.__class__.__name__ in ["Qwen2_5_VLDecoderLayer"]:
# print("set", module.__class__.__name__, "hooks")
module.forward = decoder_forward_hook(module)
if module.__class__.__name__ in ["Qwen2_5_VisionRotaryEmbedding"]:
# print("set", module.__class__.__name__, "hooks")
module.to(target_dtype)
prepare_fp8(qwen2_5_vl, org_dtype)
else:
logger.info(f"Setting Qwen2.5-VL to dtype: {dtype}")
qwen2_5_vl.to(dtype)
# Load tokenizer
logger.info(f"Loading tokenizer from {QWEN_2_5_VL_IMAGE_ID}")
tokenizer = Qwen2Tokenizer.from_pretrained(QWEN_2_5_VL_IMAGE_ID)
return tokenizer, qwen2_5_vl
TOKENIZER_MAX_LENGTH = 1024
PROMPT_TEMPLATE_ENCODE_START_IDX = 34
def get_qwen_prompt_embeds(
tokenizer: Qwen2Tokenizer, vlm: Qwen2_5_VLForConditionalGeneration, prompt: Union[str, list[str]] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
input_ids, mask = get_qwen_tokens(tokenizer, prompt)
return get_qwen_prompt_embeds_from_tokens(vlm, input_ids, mask)
def get_qwen_tokens(tokenizer: Qwen2Tokenizer, prompt: Union[str, list[str]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
tokenizer_max_length = TOKENIZER_MAX_LENGTH
# HunyuanImage-2.1 does not use "<|im_start|>assistant\n" in the prompt template
prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
# \n<|im_start|>assistant\n"
prompt_template_encode_start_idx = PROMPT_TEMPLATE_ENCODE_START_IDX
# default_sample_size = 128
prompt = [prompt] if isinstance(prompt, str) else prompt
template = prompt_template_encode
drop_idx = prompt_template_encode_start_idx
txt = [template.format(e) for e in prompt]
txt_tokens = tokenizer(txt, max_length=tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt")
return txt_tokens.input_ids, txt_tokens.attention_mask
def get_qwen_prompt_embeds_from_tokens(
vlm: Qwen2_5_VLForConditionalGeneration, input_ids: torch.Tensor, attention_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
tokenizer_max_length = TOKENIZER_MAX_LENGTH
drop_idx = PROMPT_TEMPLATE_ENCODE_START_IDX
device = vlm.device
dtype = vlm.dtype
input_ids = input_ids.to(device=device)
attention_mask = attention_mask.to(device=device)
if dtype.itemsize == 1: # fp8
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):
encoder_hidden_states = vlm(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
else:
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=dtype, enabled=True):
encoder_hidden_states = vlm(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
hidden_states = encoder_hidden_states.hidden_states[-3] # use the 3rd last layer's hidden states for HunyuanImage-2.1
if hidden_states.shape[1] > tokenizer_max_length + drop_idx:
logger.warning(f"Hidden states shape {hidden_states.shape} exceeds max length {tokenizer_max_length + drop_idx}")
# --- Unnecessary complicated processing, keep for reference ---
# split_hidden_states = extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
# split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
# attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
# max_seq_len = max([e.size(0) for e in split_hidden_states])
# prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
# encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
# ----------------------------------------------------------
prompt_embeds = hidden_states[:, drop_idx:, :]
encoder_attention_mask = attention_mask[:, drop_idx:]
prompt_embeds = prompt_embeds.to(device=device)
return prompt_embeds, encoder_attention_mask
def format_prompt(texts, styles):
"""
Text "{text}" in {color}, {type}.
"""
prompt = ""
for text, style in zip(texts, styles):
# color and style are always None in official implementation, so we only use text
text_prompt = f'Text "{text}"'
text_prompt += ". "
prompt = prompt + text_prompt
return prompt
BYT5_MAX_LENGTH = 128
def get_glyph_prompt_embeds(
tokenizer: T5Tokenizer, text_encoder: T5Stack, prompt: Optional[str] = None
) -> Tuple[list[bool], torch.Tensor, torch.Tensor]:
byt5_tokens, byt5_text_mask = get_byt5_text_tokens(tokenizer, prompt)
return get_byt5_prompt_embeds_from_tokens(text_encoder, byt5_tokens, byt5_text_mask)
def get_byt5_prompt_embeds_from_tokens(
text_encoder: T5Stack, byt5_text_ids: Optional[torch.Tensor], byt5_text_mask: Optional[torch.Tensor]
) -> Tuple[list[bool], torch.Tensor, torch.Tensor]:
byt5_max_length = BYT5_MAX_LENGTH
if byt5_text_ids is None or byt5_text_mask is None or byt5_text_mask.sum() == 0:
return (
[False],
torch.zeros((1, byt5_max_length, 1472), device=text_encoder.device),
torch.zeros((1, byt5_max_length), device=text_encoder.device, dtype=torch.int64),
)
byt5_text_ids = byt5_text_ids.to(device=text_encoder.device)
byt5_text_mask = byt5_text_mask.to(device=text_encoder.device)
with torch.no_grad(), torch.autocast(device_type=text_encoder.device.type, dtype=text_encoder.dtype, enabled=True):
byt5_prompt_embeds = text_encoder(byt5_text_ids, attention_mask=byt5_text_mask.float())
byt5_emb = byt5_prompt_embeds[0]
return [True], byt5_emb, byt5_text_mask
def get_byt5_text_tokens(tokenizer, prompt):
if not prompt:
return None, None
try:
text_prompt_texts = []
# pattern_quote_single = r"\'(.*?)\'"
pattern_quote_double = r"\"(.*?)\""
pattern_quote_chinese_single = r"(.*?)"
pattern_quote_chinese_double = r"“(.*?)”"
# matches_quote_single = re.findall(pattern_quote_single, prompt)
matches_quote_double = re.findall(pattern_quote_double, prompt)
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, prompt)
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, prompt)
# text_prompt_texts.extend(matches_quote_single)
text_prompt_texts.extend(matches_quote_double)
text_prompt_texts.extend(matches_quote_chinese_single)
text_prompt_texts.extend(matches_quote_chinese_double)
if not text_prompt_texts:
return None, None
text_prompt_style_list = [{"color": None, "font-family": None} for _ in range(len(text_prompt_texts))]
glyph_text_formatted = format_prompt(text_prompt_texts, text_prompt_style_list)
logger.info(f"Glyph text formatted: {glyph_text_formatted}")
byt5_text_inputs = tokenizer(
glyph_text_formatted,
padding="max_length",
max_length=BYT5_MAX_LENGTH,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
byt5_text_ids = byt5_text_inputs.input_ids
byt5_text_mask = byt5_text_inputs.attention_mask
return byt5_text_ids, byt5_text_mask
except Exception as e:
logger.warning(f"Warning: Error in glyph encoding, using fallback: {e}")
return None, None

View File

@@ -0,0 +1,525 @@
# Original work: https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
# Re-implemented for license compliance for sd-scripts.
import math
from typing import Tuple, Union, Optional
import torch
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
MODEL_VERSION_2_1 = "hunyuan-image-2.1"
# region model
def _to_tuple(x, dim=2):
"""
Convert int or sequence to tuple of specified dimension.
Args:
x: Int or sequence to convert.
dim: Target dimension for tuple.
Returns:
Tuple of length dim.
"""
if isinstance(x, int) or isinstance(x, float):
return (x,) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_meshgrid_nd(start, dim=2):
"""
Generate n-dimensional coordinate meshgrid from 0 to grid_size.
Creates coordinate grids for each spatial dimension, useful for
generating position embeddings.
Args:
start: Grid size for each dimension (int or tuple).
dim: Number of spatial dimensions.
Returns:
Coordinate grid tensor [dim, *grid_size].
"""
# Convert start to grid sizes
num = _to_tuple(start, dim=dim)
start = (0,) * dim
stop = num
# Generate coordinate arrays for each dimension
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
return grid
def get_nd_rotary_pos_embed(rope_dim_list, start, theta=10000.0):
"""
Generate n-dimensional rotary position embeddings for spatial tokens.
Creates RoPE embeddings for multi-dimensional positional encoding,
distributing head dimensions across spatial dimensions.
Args:
rope_dim_list: Dimensions allocated to each spatial axis (should sum to head_dim).
start: Spatial grid size for each dimension.
theta: Base frequency for RoPE computation.
Returns:
Tuple of (cos_freqs, sin_freqs) for rotary embedding [H*W, D/2].
"""
grid = get_meshgrid_nd(start, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
# Generate RoPE embeddings for each spatial dimension
embs = []
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta) # 2 x [WHD, rope_dim_list[i]]
embs.append(emb)
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
return cos, sin
def get_1d_rotary_pos_embed(
dim: int, pos: Union[torch.FloatTensor, int], theta: float = 10000.0
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generate 1D rotary position embeddings.
Args:
dim: Embedding dimension (must be even).
pos: Position indices [S] or scalar for sequence length.
theta: Base frequency for sinusoidal encoding.
Returns:
Tuple of (cos_freqs, sin_freqs) tensors [S, D].
"""
if isinstance(pos, int):
pos = torch.arange(pos).float()
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
freqs = torch.outer(pos, freqs) # [S, D/2]
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings for diffusion models.
Converts scalar timesteps to high-dimensional embeddings using
sinusoidal encoding at different frequencies.
Args:
t: Timestep tensor [N].
dim: Output embedding dimension.
max_period: Maximum period for frequency computation.
Returns:
Timestep embeddings [N, dim].
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def modulate(x, shift=None, scale=None):
"""
Apply adaptive layer normalization modulation.
Applies scale and shift transformations for conditioning
in adaptive layer normalization.
Args:
x: Input tensor to modulate.
shift: Additive shift parameter (optional).
scale: Multiplicative scale parameter (optional).
Returns:
Modulated tensor x * (1 + scale) + shift.
"""
if scale is None and shift is None:
return x
elif shift is None:
return x * (1 + scale.unsqueeze(1))
elif scale is None:
return x + shift.unsqueeze(1)
else:
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def apply_gate(x, gate=None, tanh=False):
"""
Apply gating mechanism to tensor.
Multiplies input by gate values, optionally applying tanh activation.
Used in residual connections for adaptive control.
Args:
x: Input tensor to gate.
gate: Gating values (optional).
tanh: Whether to apply tanh to gate values.
Returns:
Gated tensor x * gate (with optional tanh).
"""
if gate is None:
return x
if tanh:
return x * gate.unsqueeze(1).tanh()
else:
return x * gate.unsqueeze(1)
def reshape_for_broadcast(
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
x: torch.Tensor,
head_first=False,
):
"""
Reshape RoPE frequency tensors for broadcasting with attention tensors.
Args:
freqs_cis: Tuple of (cos_freqs, sin_freqs) tensors.
x: Target tensor for broadcasting compatibility.
head_first: Must be False (only supported layout).
Returns:
Reshaped (cos_freqs, sin_freqs) tensors ready for broadcasting.
"""
assert not head_first, "Only head_first=False layout supported."
assert isinstance(freqs_cis, tuple), "Expected tuple of (cos, sin) frequency tensors."
assert x.ndim > 1, f"x should have at least 2 dimensions, but got {x.ndim}"
# Validate frequency tensor dimensions match target tensor
assert freqs_cis[0].shape == (
x.shape[1],
x.shape[-1],
), f"Frequency tensor shape {freqs_cis[0].shape} incompatible with target shape {x.shape}"
shape = [d if i == 1 or i == x.ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
def rotate_half(x):
"""
Rotate half the dimensions for RoPE computation.
Splits the last dimension in half and applies a 90-degree rotation
by swapping and negating components.
Args:
x: Input tensor [..., D] where D is even.
Returns:
Rotated tensor with same shape as input.
"""
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(
xq: torch.Tensor, xk: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor], head_first: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary position embeddings to query and key tensors.
Args:
xq: Query tensor [B, S, H, D].
xk: Key tensor [B, S, H, D].
freqs_cis: Tuple of (cos_freqs, sin_freqs) for rotation.
head_first: Whether head dimension precedes sequence dimension.
Returns:
Tuple of rotated (query, key) tensors.
"""
device = xq.device
dtype = xq.dtype
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
cos, sin = cos.to(device), sin.to(device)
# Apply rotation: x' = x * cos + rotate_half(x) * sin
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).to(dtype)
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).to(dtype)
return xq_out, xk_out
# endregion
# region inference
def get_timesteps_sigmas(sampling_steps: int, shift: float, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generate timesteps and sigmas for diffusion sampling.
Args:
sampling_steps: Number of sampling steps.
shift: Sigma shift parameter for schedule modification.
device: Target device for tensors.
Returns:
Tuple of (timesteps, sigmas) tensors.
"""
sigmas = torch.linspace(1, 0, sampling_steps + 1)
sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas)
sigmas = sigmas.to(torch.float32)
timesteps = (sigmas[:-1] * 1000).to(dtype=torch.float32, device=device)
return timesteps, sigmas
def step(latents, noise_pred, sigmas, step_i):
"""
Perform a single diffusion sampling step.
Args:
latents: Current latent state.
noise_pred: Predicted noise.
sigmas: Noise schedule sigmas.
step_i: Current step index.
Returns:
Updated latents after the step.
"""
return latents.float() - (sigmas[step_i] - sigmas[step_i + 1]) * noise_pred.float()
# endregion
# region AdaptiveProjectedGuidance
class MomentumBuffer:
"""
Exponential moving average buffer for APG momentum.
"""
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def normalized_guidance_apg(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: Optional[MomentumBuffer] = None,
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
):
"""
Apply normalized adaptive projected guidance.
Projects the guidance vector to reduce over-saturation while maintaining
directional control by decomposing into parallel and orthogonal components.
Args:
pred_cond: Conditional prediction.
pred_uncond: Unconditional prediction.
guidance_scale: Guidance scale factor.
momentum_buffer: Optional momentum buffer for temporal smoothing.
eta: Scaling factor for parallel component.
norm_threshold: Maximum norm for guidance vector clipping.
use_original_formulation: Whether to use original APG formulation.
Returns:
Guided prediction tensor.
"""
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))] # All dimensions except batch
# Apply momentum smoothing if available
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
# Apply norm clipping if threshold is set
if norm_threshold > 0:
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(torch.ones_like(diff_norm), norm_threshold / diff_norm)
diff = diff * scale_factor
# Project guidance vector into parallel and orthogonal components
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
# Combine components with different scaling
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale * normalized_update
return pred
class AdaptiveProjectedGuidance:
"""
Adaptive Projected Guidance for classifier-free guidance.
Implements APG which projects the guidance vector to reduce over-saturation
while maintaining directional control.
"""
def __init__(
self,
guidance_scale: float = 7.5,
adaptive_projected_guidance_momentum: Optional[float] = None,
adaptive_projected_guidance_rescale: float = 15.0,
eta: float = 0.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
):
self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.eta = eta
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def __call__(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None, step=None) -> torch.Tensor:
if step == 0 and self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
pred = normalized_guidance_apg(
pred_cond,
pred_uncond,
self.guidance_scale,
self.momentum_buffer,
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
def rescale_noise_cfg(guided_noise, conditional_noise, rescale_factor=0.0):
"""
Rescale guided noise prediction to prevent overexposure and improve image quality.
This implementation addresses the overexposure issue described in "Common Diffusion Noise
Schedules and Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf) (Section 3.4).
The rescaling preserves the statistical properties of the conditional prediction while reducing artifacts.
Args:
guided_noise (torch.Tensor): Noise prediction from classifier-free guidance.
conditional_noise (torch.Tensor): Noise prediction from conditional model.
rescale_factor (float): Interpolation factor between original and rescaled predictions.
0.0 = no rescaling, 1.0 = full rescaling.
Returns:
torch.Tensor: Rescaled noise prediction with reduced overexposure.
"""
if rescale_factor == 0.0:
return guided_noise
# Calculate standard deviation across spatial dimensions for both predictions
spatial_dims = list(range(1, conditional_noise.ndim))
conditional_std = conditional_noise.std(dim=spatial_dims, keepdim=True)
guided_std = guided_noise.std(dim=spatial_dims, keepdim=True)
# Rescale guided noise to match conditional noise statistics
std_ratio = conditional_std / guided_std
rescaled_prediction = guided_noise * std_ratio
# Interpolate between original and rescaled predictions
final_prediction = rescale_factor * rescaled_prediction + (1.0 - rescale_factor) * guided_noise
return final_prediction
def apply_classifier_free_guidance(
noise_pred_text: torch.Tensor,
noise_pred_uncond: torch.Tensor,
is_ocr: bool,
guidance_scale: float,
step: int,
apg_start_step_ocr: int = 38,
apg_start_step_general: int = 5,
cfg_guider_ocr: AdaptiveProjectedGuidance = None,
cfg_guider_general: AdaptiveProjectedGuidance = None,
guidance_rescale: float = 0.0,
):
"""
Apply classifier-free guidance with OCR-aware APG for batch_size=1.
Args:
noise_pred_text: Conditional noise prediction tensor [1, ...].
noise_pred_uncond: Unconditional noise prediction tensor [1, ...].
is_ocr: Whether this sample requires OCR-specific guidance.
guidance_scale: Guidance scale for CFG.
step: Current diffusion step index.
apg_start_step_ocr: Step to start APG for OCR regions.
apg_start_step_general: Step to start APG for general regions.
cfg_guider_ocr: APG guider for OCR regions.
cfg_guider_general: APG guider for general regions.
Returns:
Guided noise prediction tensor [1, ...].
"""
if guidance_scale == 1.0:
return noise_pred_text
# Select appropriate guider and start step based on OCR requirement
if is_ocr:
cfg_guider = cfg_guider_ocr
apg_start_step = apg_start_step_ocr
else:
cfg_guider = cfg_guider_general
apg_start_step = apg_start_step_general
# Apply standard CFG or APG based on current step
if step <= apg_start_step:
# Standard classifier-free guidance
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale)
# Initialize APG guider state
_ = cfg_guider(noise_pred_text, noise_pred_uncond, step=step)
else:
# Use APG for guidance
noise_pred = cfg_guider(noise_pred_text, noise_pred_uncond, step=step)
return noise_pred
# endregion

View File

@@ -0,0 +1,755 @@
from typing import Optional, Tuple
from einops import rearrange
import numpy as np
import torch
from torch import Tensor, nn
from torch.nn import Conv2d
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from library.safetensors_utils import load_safetensors
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
VAE_SCALE_FACTOR = 32 # 32x spatial compression
LATENT_SCALING_FACTOR = 0.75289 # Latent scaling factor for Hunyuan Image-2.1
def swish(x: Tensor) -> Tensor:
"""Swish activation function: x * sigmoid(x)."""
return x * torch.sigmoid(x)
class AttnBlock(nn.Module):
"""Self-attention block using scaled dot-product attention."""
def __init__(self, in_channels: int, chunk_size: Optional[int] = None):
super().__init__()
self.in_channels = in_channels
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
if chunk_size is None or chunk_size <= 0:
self.q = Conv2d(in_channels, in_channels, kernel_size=1)
self.k = Conv2d(in_channels, in_channels, kernel_size=1)
self.v = Conv2d(in_channels, in_channels, kernel_size=1)
self.proj_out = Conv2d(in_channels, in_channels, kernel_size=1)
else:
self.q = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size)
self.k = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size)
self.v = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size)
self.proj_out = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size)
def attention(self, x: Tensor) -> Tensor:
x = self.norm(x)
q = self.q(x)
k = self.k(x)
v = self.v(x)
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b (h w) c").contiguous()
k = rearrange(k, "b c h w -> b (h w) c").contiguous()
v = rearrange(v, "b c h w -> b (h w) c").contiguous()
x = nn.functional.scaled_dot_product_attention(q, k, v)
return rearrange(x, "b (h w) c -> b c h w", h=h, w=w, c=c, b=b)
def forward(self, x: Tensor) -> Tensor:
return x + self.proj_out(self.attention(x))
class ChunkedConv2d(nn.Conv2d):
"""
Convolutional layer that processes input in chunks to reduce memory usage.
Parameters
----------
chunk_size : int, optional
Size of chunks to process at a time. Default is 64.
"""
def __init__(self, *args, **kwargs):
if "chunk_size" in kwargs:
self.chunk_size = kwargs.pop("chunk_size", 64)
super().__init__(*args, **kwargs)
assert self.padding_mode == "zeros", "Only 'zeros' padding mode is supported."
assert self.dilation == (1, 1) and self.stride == (1, 1), "Only dilation=1 and stride=1 are supported."
assert self.groups == 1, "Only groups=1 is supported."
assert self.kernel_size[0] == self.kernel_size[1], "Only square kernels are supported."
assert (
self.padding[0] == self.padding[1] and self.padding[0] == self.kernel_size[0] // 2
), "Only kernel_size//2 padding is supported."
self.original_padding = self.padding
self.padding = (0, 0) # We handle padding manually in forward
def forward(self, x: Tensor) -> Tensor:
# If chunking is not needed, process normally. We chunk only along height dimension.
if self.chunk_size is None or x.shape[1] <= self.chunk_size:
self.padding = self.original_padding
x = super().forward(x)
self.padding = (0, 0)
if torch.cuda.is_available():
torch.cuda.empty_cache()
return x
# Process input in chunks to reduce memory usage
org_shape = x.shape
# If kernel size is not 1, we need to use overlapping chunks
overlap = self.kernel_size[0] // 2 # 1 for kernel size 3
step = self.chunk_size - overlap
y = torch.zeros((org_shape[0], self.out_channels, org_shape[2], org_shape[3]), dtype=x.dtype, device=x.device)
yi = 0
i = 0
while i < org_shape[2]:
si = i if i == 0 else i - overlap
ei = i + self.chunk_size
# Check last chunk. If remaining part is small, include it in last chunk
if ei > org_shape[2] or ei + step // 4 > org_shape[2]:
ei = org_shape[2]
chunk = x[:, :, : ei - si, :]
x = x[:, :, ei - si - overlap * 2 :, :]
# Pad chunk if needed: This is as the original Conv2d with padding
if i == 0: # First chunk
# Pad except bottom
chunk = torch.nn.functional.pad(chunk, (overlap, overlap, overlap, 0), mode="constant", value=0)
elif ei == org_shape[2]: # Last chunk
# Pad except top
chunk = torch.nn.functional.pad(chunk, (overlap, overlap, 0, overlap), mode="constant", value=0)
else:
# Pad left and right only
chunk = torch.nn.functional.pad(chunk, (overlap, overlap), mode="constant", value=0)
chunk = super().forward(chunk)
y[:, :, yi : yi + chunk.shape[2], :] = chunk
yi += chunk.shape[2]
del chunk
if ei == org_shape[2]:
break
i += step
assert yi == org_shape[2], f"yi={yi}, org_shape[2]={org_shape[2]}"
if torch.cuda.is_available():
torch.cuda.empty_cache() # This helps reduce peak memory usage, but slows down a bit
return y
class ResnetBlock(nn.Module):
"""
Residual block with two convolutions, group normalization, and swish activation.
Includes skip connection with optional channel dimension matching.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
"""
def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
if chunk_size is None or chunk_size <= 0:
self.conv1 = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.conv2 = Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
# Skip connection projection for channel dimension mismatch
if self.in_channels != self.out_channels:
self.nin_shortcut = Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv1 = ChunkedConv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
self.conv2 = ChunkedConv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
# Skip connection projection for channel dimension mismatch
if self.in_channels != self.out_channels:
self.nin_shortcut = ChunkedConv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0, chunk_size=chunk_size
)
def forward(self, x: Tensor) -> Tensor:
h = x
# First convolution block
h = self.norm1(h)
h = swish(h)
h = self.conv1(h)
# Second convolution block
h = self.norm2(h)
h = swish(h)
h = self.conv2(h)
# Apply skip connection with optional projection
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
"""
Spatial downsampling block that reduces resolution by 2x using convolution followed by
pixel rearrangement. Includes skip connection with grouped averaging.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels (must be divisible by 4).
"""
def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None):
super().__init__()
factor = 4 # 2x2 spatial reduction factor
assert out_channels % factor == 0
if chunk_size is None or chunk_size <= 0:
self.conv = Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
else:
self.conv = ChunkedConv2d(
in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size
)
self.group_size = factor * in_channels // out_channels
def forward(self, x: Tensor) -> Tensor:
# Apply convolution and rearrange pixels for 2x downsampling
h = self.conv(x)
h = rearrange(h, "b c (h r1) (w r2) -> b (r1 r2 c) h w", r1=2, r2=2)
# Create skip connection with pixel rearrangement
shortcut = rearrange(x, "b c (h r1) (w r2) -> b (r1 r2 c) h w", r1=2, r2=2)
B, C, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size, H, W).mean(dim=2)
return h + shortcut
class Upsample(nn.Module):
"""
Spatial upsampling block that increases resolution by 2x using convolution followed by
pixel rearrangement. Includes skip connection with channel repetition.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
"""
def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None):
super().__init__()
factor = 4 # 2x2 spatial expansion factor
if chunk_size is None or chunk_size <= 0:
self.conv = Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
else:
self.conv = ChunkedConv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
self.repeats = factor * out_channels // in_channels
def forward(self, x: Tensor) -> Tensor:
# Apply convolution and rearrange pixels for 2x upsampling
h = self.conv(x)
h = rearrange(h, "b (r1 r2 c) h w -> b c (h r1) (w r2)", r1=2, r2=2)
# Create skip connection with channel repetition
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
shortcut = rearrange(shortcut, "b (r1 r2 c) h w -> b c (h r1) (w r2)", r1=2, r2=2)
return h + shortcut
class Encoder(nn.Module):
"""
VAE encoder that progressively downsamples input images to a latent representation.
Uses residual blocks, attention, and spatial downsampling.
Parameters
----------
in_channels : int
Number of input image channels (e.g., 3 for RGB).
z_channels : int
Number of latent channels in the output.
block_out_channels : Tuple[int, ...]
Output channels for each downsampling block.
num_res_blocks : int
Number of residual blocks per downsampling stage.
ffactor_spatial : int
Total spatial downsampling factor (e.g., 32 for 32x compression).
"""
def __init__(
self,
in_channels: int,
z_channels: int,
block_out_channels: Tuple[int, ...],
num_res_blocks: int,
ffactor_spatial: int,
chunk_size: Optional[int] = None,
):
super().__init__()
assert block_out_channels[-1] % (2 * z_channels) == 0
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
if chunk_size is None or chunk_size <= 0:
self.conv_in = Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
else:
self.conv_in = ChunkedConv2d(
in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1, chunk_size=chunk_size
)
self.down = nn.ModuleList()
block_in = block_out_channels[0]
# Build downsampling blocks
for i_level, ch in enumerate(block_out_channels):
block = nn.ModuleList()
block_out = ch
# Add residual blocks for this level
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, chunk_size=chunk_size))
block_in = block_out
down = nn.Module()
down.block = block
# Add spatial downsampling if needed
add_spatial_downsample = bool(i_level < np.log2(ffactor_spatial))
if add_spatial_downsample:
assert i_level < len(block_out_channels) - 1
block_out = block_out_channels[i_level + 1]
down.downsample = Downsample(block_in, block_out, chunk_size=chunk_size)
block_in = block_out
self.down.append(down)
# Middle blocks with attention
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)
self.mid.attn_1 = AttnBlock(block_in, chunk_size=chunk_size)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)
# Output layers
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
if chunk_size is None or chunk_size <= 0:
self.conv_out = Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
else:
self.conv_out = ChunkedConv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
def forward(self, x: Tensor) -> Tensor:
# Initial convolution
h = self.conv_in(x)
# Progressive downsampling through blocks
for i_level in range(len(self.block_out_channels)):
# Apply residual blocks at this level
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h)
# Apply spatial downsampling if available
if hasattr(self.down[i_level], "downsample"):
h = self.down[i_level].downsample(h)
# Middle processing with attention
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# Final output layers with skip connection
group_size = self.block_out_channels[-1] // (2 * self.z_channels)
shortcut = rearrange(h, "b (c r) h w -> b c r h w", r=group_size).mean(dim=2)
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
h += shortcut
return h
class Decoder(nn.Module):
"""
VAE decoder that progressively upsamples latent representations back to images.
Uses residual blocks, attention, and spatial upsampling.
Parameters
----------
z_channels : int
Number of latent channels in the input.
out_channels : int
Number of output image channels (e.g., 3 for RGB).
block_out_channels : Tuple[int, ...]
Output channels for each upsampling block.
num_res_blocks : int
Number of residual blocks per upsampling stage.
ffactor_spatial : int
Total spatial upsampling factor (e.g., 32 for 32x expansion).
"""
def __init__(
self,
z_channels: int,
out_channels: int,
block_out_channels: Tuple[int, ...],
num_res_blocks: int,
ffactor_spatial: int,
chunk_size: Optional[int] = None,
):
super().__init__()
assert block_out_channels[0] % z_channels == 0
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
block_in = block_out_channels[0]
if chunk_size is None or chunk_size <= 0:
self.conv_in = Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
else:
self.conv_in = ChunkedConv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
# Middle blocks with attention
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)
self.mid.attn_1 = AttnBlock(block_in, chunk_size=chunk_size)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)
# Build upsampling blocks
self.up = nn.ModuleList()
for i_level, ch in enumerate(block_out_channels):
block = nn.ModuleList()
block_out = ch
# Add residual blocks for this level (extra block for decoder)
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, chunk_size=chunk_size))
block_in = block_out
up = nn.Module()
up.block = block
# Add spatial upsampling if needed
add_spatial_upsample = bool(i_level < np.log2(ffactor_spatial))
if add_spatial_upsample:
assert i_level < len(block_out_channels) - 1
block_out = block_out_channels[i_level + 1]
up.upsample = Upsample(block_in, block_out, chunk_size=chunk_size)
block_in = block_out
self.up.append(up)
# Output layers
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
if chunk_size is None or chunk_size <= 0:
self.conv_out = Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.conv_out = ChunkedConv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
def forward(self, z: Tensor) -> Tensor:
# Initial processing with skip connection
repeats = self.block_out_channels[0] // self.z_channels
h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
# Middle processing with attention
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# Progressive upsampling through blocks
for i_level in range(len(self.block_out_channels)):
# Apply residual blocks at this level
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
# Apply spatial upsampling if available
if hasattr(self.up[i_level], "upsample"):
h = self.up[i_level].upsample(h)
# Final output layers
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h
class HunyuanVAE2D(nn.Module):
"""
VAE model for Hunyuan Image-2.1 with spatial tiling support.
This VAE uses a fixed architecture optimized for the Hunyuan Image-2.1 model,
with 32x spatial compression and optional memory-efficient tiling for large images.
"""
def __init__(self, chunk_size: Optional[int] = None):
super().__init__()
# Fixed configuration for Hunyuan Image-2.1
block_out_channels = (128, 256, 512, 512, 1024, 1024)
in_channels = 3 # RGB input
out_channels = 3 # RGB output
latent_channels = 64
layers_per_block = 2
ffactor_spatial = 32 # 32x spatial compression
sample_size = 384 # Minimum sample size for tiling
scaling_factor = LATENT_SCALING_FACTOR # 0.75289 # Latent scaling factor
self.ffactor_spatial = ffactor_spatial
self.scaling_factor = scaling_factor
self.encoder = Encoder(
in_channels=in_channels,
z_channels=latent_channels,
block_out_channels=block_out_channels,
num_res_blocks=layers_per_block,
ffactor_spatial=ffactor_spatial,
chunk_size=chunk_size,
)
self.decoder = Decoder(
z_channels=latent_channels,
out_channels=out_channels,
block_out_channels=list(reversed(block_out_channels)),
num_res_blocks=layers_per_block,
ffactor_spatial=ffactor_spatial,
chunk_size=chunk_size,
)
# Spatial tiling configuration for memory efficiency
self.use_spatial_tiling = False
self.tile_sample_min_size = sample_size
self.tile_latent_min_size = sample_size // ffactor_spatial
self.tile_overlap_factor = 0.25 # 25% overlap between tiles
@property
def dtype(self):
"""Get the data type of the model parameters."""
return next(self.encoder.parameters()).dtype
@property
def device(self):
"""Get the device of the model parameters."""
return next(self.encoder.parameters()).device
def enable_spatial_tiling(self, use_tiling: bool = True):
"""Enable or disable spatial tiling."""
self.use_spatial_tiling = use_tiling
def disable_spatial_tiling(self):
"""Disable spatial tiling."""
self.use_spatial_tiling = False
def enable_tiling(self, use_tiling: bool = True):
"""Enable or disable spatial tiling (alias for enable_spatial_tiling)."""
self.enable_spatial_tiling(use_tiling)
def disable_tiling(self):
"""Disable spatial tiling (alias for disable_spatial_tiling)."""
self.disable_spatial_tiling()
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
"""
Blend two tensors horizontally with smooth transition.
Parameters
----------
a : torch.Tensor
Left tensor.
b : torch.Tensor
Right tensor.
blend_extent : int
Number of columns to blend.
"""
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
"""
Blend two tensors vertically with smooth transition.
Parameters
----------
a : torch.Tensor
Top tensor.
b : torch.Tensor
Bottom tensor.
blend_extent : int
Number of rows to blend.
"""
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
"""
Encode large images using spatial tiling to reduce memory usage.
Tiles are processed independently and blended at boundaries.
Parameters
----------
x : torch.Tensor
Input tensor of shape (B, C, T, H, W) or (B, C, H, W).
"""
# Handle 5D input (B, C, T, H, W) by removing time dimension
original_ndim = x.ndim
if original_ndim == 5:
x = x.squeeze(2)
B, C, H, W = x.shape
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
rows = []
for i in range(0, H, overlap_size):
row = []
for j in range(0, W, overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
moments = torch.cat(result_rows, dim=-2)
return moments
def spatial_tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
"""
Decode large latents using spatial tiling to reduce memory usage.
Tiles are processed independently and blended at boundaries.
Parameters
----------
z : torch.Tensor
Latent tensor of shape (B, C, H, W).
"""
B, C, H, W = z.shape
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
rows = []
for i in range(0, H, overlap_size):
row = []
for j in range(0, W, overlap_size):
tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=-2)
return dec
def encode(self, x: Tensor) -> DiagonalGaussianDistribution:
"""
Encode input images to latent representation.
Uses spatial tiling for large images if enabled.
Parameters
----------
x : Tensor
Input image tensor of shape (B, C, H, W) or (B, C, T, H, W).
Returns
-------
DiagonalGaussianDistribution
Latent distribution with mean and logvar.
"""
# Handle 5D input (B, C, T, H, W) by removing time dimension
original_ndim = x.ndim
if original_ndim == 5:
x = x.squeeze(2)
# Use tiling for large images to reduce memory usage
if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
h = self.spatial_tiled_encode(x)
else:
h = self.encoder(x)
# Restore time dimension if input was 5D
if original_ndim == 5:
h = h.unsqueeze(2)
posterior = DiagonalGaussianDistribution(h)
return posterior
def decode(self, z: Tensor):
"""
Decode latent representation back to images.
Uses spatial tiling for large latents if enabled.
Parameters
----------
z : Tensor
Latent tensor of shape (B, C, H, W) or (B, C, T, H, W).
Returns
-------
Tensor
Decoded image tensor.
"""
# Handle 5D input (B, C, T, H, W) by removing time dimension
original_ndim = z.ndim
if original_ndim == 5:
z = z.squeeze(2)
# Use tiling for large latents to reduce memory usage
if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
decoded = self.spatial_tiled_decode(z)
else:
decoded = self.decoder(z)
# Restore time dimension if input was 5D
if original_ndim == 5:
decoded = decoded.unsqueeze(2)
return decoded
def load_vae(vae_path: str, device: torch.device, disable_mmap: bool = False, chunk_size: Optional[int] = None) -> HunyuanVAE2D:
logger.info(f"Initializing VAE with chunk_size={chunk_size}")
vae = HunyuanVAE2D(chunk_size=chunk_size)
logger.info(f"Loading VAE from {vae_path}")
state_dict = load_safetensors(vae_path, device=device, disable_mmap=disable_mmap)
info = vae.load_state_dict(state_dict, strict=True, assign=True)
logger.info(f"Loaded VAE: {info}")
vae.to(device)
return vae

246
library/lora_utils.py Normal file
View File

@@ -0,0 +1,246 @@
import os
import re
from typing import Dict, List, Optional, Union
import torch
from tqdm import tqdm
from library.device_utils import synchronize_device
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def filter_lora_state_dict(
weights_sd: Dict[str, torch.Tensor],
include_pattern: Optional[str] = None,
exclude_pattern: Optional[str] = None,
) -> Dict[str, torch.Tensor]:
# apply include/exclude patterns
original_key_count = len(weights_sd.keys())
if include_pattern is not None:
regex_include = re.compile(include_pattern)
weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)}
logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}")
if exclude_pattern is not None:
original_key_count_ex = len(weights_sd.keys())
regex_exclude = re.compile(exclude_pattern)
weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)}
logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}")
if len(weights_sd) != original_key_count:
remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()]))
remaining_keys.sort()
logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
if len(weights_sd) == 0:
logger.warning("No keys left after filtering.")
return weights_sd
def load_safetensors_with_lora_and_fp8(
model_files: Union[str, List[str]],
lora_weights_list: Optional[Dict[str, torch.Tensor]],
lora_multipliers: Optional[List[float]],
fp8_optimization: bool,
calc_device: torch.device,
move_to_device: bool = False,
dit_weight_dtype: Optional[torch.dtype] = None,
target_keys: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
) -> dict[str, torch.Tensor]:
"""
Merge LoRA weights into the state dict of a model with fp8 optimization if needed.
Args:
model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
lora_weights_list (Optional[Dict[str, torch.Tensor]]): Dictionary of LoRA weight tensors to load.
lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
fp8_optimization (bool): Whether to apply FP8 optimization.
calc_device (torch.device): Device to calculate on.
move_to_device (bool): Whether to move tensors to the calculation device after loading.
target_keys (Optional[List[str]]): Keys to target for optimization.
exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
"""
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
if isinstance(model_files, str):
model_files = [model_files]
extended_model_files = []
for model_file in model_files:
basename = os.path.basename(model_file)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
if match:
prefix = basename[: match.start(2)]
count = int(match.group(3))
state_dict = {}
for i in range(count):
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(model_file), filename)
if os.path.exists(filepath):
extended_model_files.append(filepath)
else:
raise FileNotFoundError(f"File {filepath} not found")
else:
extended_model_files.append(model_file)
model_files = extended_model_files
logger.info(f"Loading model files: {model_files}")
# load LoRA weights
weight_hook = None
if lora_weights_list is None or len(lora_weights_list) == 0:
lora_weights_list = []
lora_multipliers = []
list_of_lora_weight_keys = []
else:
list_of_lora_weight_keys = []
for lora_sd in lora_weights_list:
lora_weight_keys = set(lora_sd.keys())
list_of_lora_weight_keys.append(lora_weight_keys)
if lora_multipliers is None:
lora_multipliers = [1.0] * len(lora_weights_list)
while len(lora_multipliers) < len(lora_weights_list):
lora_multipliers.append(1.0)
if len(lora_multipliers) > len(lora_weights_list):
lora_multipliers = lora_multipliers[: len(lora_weights_list)]
# Merge LoRA weights into the state dict
logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")
# make hook for LoRA merging
def weight_hook_func(model_weight_key, model_weight, keep_on_calc_device=False):
nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device
if not model_weight_key.endswith(".weight"):
return model_weight
original_device = model_weight.device
if original_device != calc_device:
model_weight = model_weight.to(calc_device) # to make calculation faster
for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
# check if this weight has LoRA weights
lora_name = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
lora_name = "lora_unet_" + lora_name.replace(".", "_")
down_key = lora_name + ".lora_down.weight"
up_key = lora_name + ".lora_up.weight"
alpha_key = lora_name + ".alpha"
if down_key not in lora_weight_keys or up_key not in lora_weight_keys:
continue
# get LoRA weights
down_weight = lora_sd[down_key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
down_weight = down_weight.to(calc_device)
up_weight = up_weight.to(calc_device)
# W <- W + U * D
if len(model_weight.size()) == 2:
# linear
if len(up_weight.size()) == 4: # use linear projection mismatch
up_weight = up_weight.squeeze(3).squeeze(2)
down_weight = down_weight.squeeze(3).squeeze(2)
model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
model_weight = (
model_weight
+ multiplier
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
model_weight = model_weight + multiplier * conved * scale
# remove LoRA keys from set
lora_weight_keys.remove(down_key)
lora_weight_keys.remove(up_key)
if alpha_key in lora_weight_keys:
lora_weight_keys.remove(alpha_key)
if not keep_on_calc_device and original_device != calc_device:
model_weight = model_weight.to(original_device) # move back to original device
return model_weight
weight_hook = weight_hook_func
state_dict = load_safetensors_with_fp8_optimization_and_hook(
model_files,
fp8_optimization,
calc_device,
move_to_device,
dit_weight_dtype,
target_keys,
exclude_keys,
weight_hook=weight_hook,
)
for lora_weight_keys in list_of_lora_weight_keys:
# check if all LoRA keys are used
if len(lora_weight_keys) > 0:
# if there are still LoRA keys left, it means they are not used in the model
# this is a warning, not an error
logger.warning(f"Warning: not all LoRA keys are used: {', '.join(lora_weight_keys)}")
return state_dict
def load_safetensors_with_fp8_optimization_and_hook(
model_files: list[str],
fp8_optimization: bool,
calc_device: torch.device,
move_to_device: bool = False,
dit_weight_dtype: Optional[torch.dtype] = None,
target_keys: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
weight_hook: callable = None,
) -> dict[str, torch.Tensor]:
"""
Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.
"""
if fp8_optimization:
logger.info(
f"Loading state dict with FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
)
# dit_weight_dtype is not used because we use fp8 optimization
state_dict = load_safetensors_with_fp8_optimization(
model_files, calc_device, target_keys, exclude_keys, move_to_device=move_to_device, weight_hook=weight_hook
)
else:
logger.info(
f"Loading state dict without FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
)
state_dict = {}
for model_file in model_files:
with MemoryEfficientSafeOpen(model_file) as f:
for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
if weight_hook is None and move_to_device:
value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)
else:
value = f.get_tensor(key) # we cannot directly load to device because get_tensor does non-blocking transfer
if weight_hook is not None:
value = weight_hook(key, value, keep_on_calc_device=move_to_device)
if move_to_device:
value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True)
elif dit_weight_dtype is not None:
value = value.to(dit_weight_dtype)
state_dict[key] = value
if move_to_device:
synchronize_device(calc_device)
return state_dict

View File

@@ -18,10 +18,11 @@ from library import lumina_models, strategy_base, strategy_lumina, train_util
from library.flux_models import AutoEncoder
from library.device_utils import init_ipex, clean_memory_on_device
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
from library.safetensors_utils import mem_eff_save_file
init_ipex()
from .utils import setup_logging, mem_eff_save_file
from .utils import setup_logging
setup_logging()
import logging
@@ -474,11 +475,7 @@ def sample_image_inference(
def time_shift(mu: float, sigma: float, t: torch.Tensor):
# the following implementation was original for t=0: clean / t=1: noise
# Since we adopt the reverse, the 1-t operations are needed
t = 1 - t
t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
t = 1 - t
return t
@@ -801,61 +798,42 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor
weighting = torch.ones_like(sigmas)
return weighting
# mainly copied from flux_train_utils.get_noisy_model_input_and_timesteps
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Get noisy model input and timesteps.
Args:
args (argparse.Namespace): Arguments.
noise_scheduler (noise_scheduler): Noise scheduler.
latents (Tensor): Latents.
noise (Tensor): Latent noise.
device (torch.device): Device.
dtype (torch.dtype): Data type
Return:
Tuple[Tensor, Tensor, Tensor]:
noisy model input
timesteps
sigmas
"""
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
sigmas = None
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling
# Simple random sigma-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
t = torch.rand((bsz,), device=device)
sigmas = torch.rand((bsz,), device=device)
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = (
logits_norm * args.sigmoid_scale
) # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * noise + t * latents
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "nextdit_shift":
t = torch.rand((bsz,), device=device)
sigmas = torch.rand((bsz,), device=device)
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
t = time_shift(mu, 1.0, t)
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "flux_shift":
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = sigmas * num_timesteps
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -866,14 +844,24 @@ def get_noisy_model_input_and_timesteps(
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
indices = (u * num_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
# Add noise according to flow matching.
sigmas = get_sigmas(
noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype
)
noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise
# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if args.ip_noise_gamma_random_strength:
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
else:
ip_noise_gamma = args.ip_noise_gamma
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
else:
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
@@ -1048,10 +1036,10 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--timestep_sampling",
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"],
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift", "flux_shift"],
default="shift",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'."
" / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid, Flux.1 and NextDIT.1 shifting. Default is 'shift'."
" / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト、Flux.1、NextDIT.1のシフト。デフォルトは'shift'です。",
)
parser.add_argument(
"--sigmoid_scale",

View File

@@ -12,7 +12,7 @@ from transformers import Gemma2Config, Gemma2Model
from library.utils import setup_logging
from library import lumina_models, flux_models
from library.utils import load_safetensors
from library.safetensors_utils import load_safetensors
import logging
setup_logging()

View File

@@ -0,0 +1,351 @@
import os
import re
import numpy as np
import torch
import json
import struct
from typing import Dict, Any, Union, Optional
from safetensors.torch import load_file
from library.device_utils import synchronize_device
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
"""
memory efficient save file
"""
_TYPES = {
torch.float64: "F64",
torch.float32: "F32",
torch.float16: "F16",
torch.bfloat16: "BF16",
torch.int64: "I64",
torch.int32: "I32",
torch.int16: "I16",
torch.int8: "I8",
torch.uint8: "U8",
torch.bool: "BOOL",
getattr(torch, "float8_e5m2", None): "F8_E5M2",
getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
}
_ALIGN = 256
def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
validated = {}
for key, value in metadata.items():
if not isinstance(key, str):
raise ValueError(f"Metadata key must be a string, got {type(key)}")
if not isinstance(value, str):
print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
validated[key] = str(value)
else:
validated[key] = value
return validated
header = {}
offset = 0
if metadata:
header["__metadata__"] = validate_metadata(metadata)
for k, v in tensors.items():
if v.numel() == 0: # empty tensor
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
else:
size = v.numel() * v.element_size()
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
offset += size
hjson = json.dumps(header).encode("utf-8")
hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
with open(filename, "wb") as f:
f.write(struct.pack("<Q", len(hjson)))
f.write(hjson)
for k, v in tensors.items():
if v.numel() == 0:
continue
if v.is_cuda:
# Direct GPU to disk save
with torch.cuda.device(v.device):
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
tensor_bytes = v.contiguous().view(torch.uint8)
tensor_bytes.cpu().numpy().tofile(f)
else:
# CPU tensor save
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
v.contiguous().view(torch.uint8).numpy().tofile(f)
class MemoryEfficientSafeOpen:
"""Memory-efficient reader for safetensors files.
This class provides a memory-efficient way to read tensors from safetensors files
by using memory mapping for large tensors and avoiding unnecessary copies.
"""
def __init__(self, filename):
"""Initialize the SafeTensor reader.
Args:
filename (str): Path to the safetensors file to read.
"""
self.filename = filename
self.file = open(filename, "rb")
self.header, self.header_size = self._read_header()
def __enter__(self):
"""Enter context manager."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Exit context manager and close file."""
self.file.close()
def keys(self):
"""Get all tensor keys in the file.
Returns:
list: List of tensor names (excludes metadata).
"""
return [k for k in self.header.keys() if k != "__metadata__"]
def metadata(self) -> Dict[str, str]:
"""Get metadata from the file.
Returns:
Dict[str, str]: Metadata dictionary.
"""
return self.header.get("__metadata__", {})
def _read_header(self):
"""Read and parse the header from the safetensors file.
Returns:
tuple: (header_dict, header_size) containing parsed header and its size.
"""
# Read header size (8 bytes, little-endian unsigned long long)
header_size = struct.unpack("<Q", self.file.read(8))[0]
# Read and decode header JSON
header_json = self.file.read(header_size).decode("utf-8")
return json.loads(header_json), header_size
def get_tensor(self, key: str, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
"""Load a tensor from the file with memory-efficient strategies.
**Note:**
If device is 'cuda' , the transfer to GPU is done efficiently using pinned memory and non-blocking transfer.
So you must ensure that the transfer is completed before using the tensor (e.g., by `torch.cuda.synchronize()`).
If the tensor is large (>10MB) and the target device is CUDA, memory mapping with numpy.memmap is used to avoid intermediate copies.
Args:
key (str): Name of the tensor to load.
device (Optional[torch.device]): Target device for the tensor.
dtype (Optional[torch.dtype]): Target dtype for the tensor.
Returns:
torch.Tensor: The loaded tensor.
Raises:
KeyError: If the tensor key is not found in the file.
"""
if key not in self.header:
raise KeyError(f"Tensor '{key}' not found in the file")
metadata = self.header[key]
offset_start, offset_end = metadata["data_offsets"]
num_bytes = offset_end - offset_start
original_dtype = self._get_torch_dtype(metadata["dtype"])
target_dtype = dtype if dtype is not None else original_dtype
# Handle empty tensors
if num_bytes == 0:
return torch.empty(metadata["shape"], dtype=target_dtype, device=device)
# Determine if we should use pinned memory for GPU transfer
non_blocking = device is not None and device.type == "cuda"
# Calculate absolute file offset
tensor_offset = self.header_size + 8 + offset_start # adjust offset by header size
# Memory mapping strategy for large tensors to GPU
# Use memmap for large tensors to avoid intermediate copies.
# If device is cpu, tensor is not copied to gpu, so using memmap locks the file, which is not desired.
# So we only use memmap if device is not cpu.
if num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
# Create memory map for zero-copy reading
mm = np.memmap(self.filename, mode="c", dtype=np.uint8, offset=tensor_offset, shape=(num_bytes,))
byte_tensor = torch.from_numpy(mm) # zero copy
del mm
# Deserialize tensor (view and reshape)
cpu_tensor = self._deserialize_tensor(byte_tensor, metadata) # view and reshape
del byte_tensor
# Transfer to target device and dtype
gpu_tensor = cpu_tensor.to(device=device, dtype=target_dtype, non_blocking=non_blocking)
del cpu_tensor
return gpu_tensor
# Standard file reading strategy for smaller tensors or CPU target
# seek to the specified position
self.file.seek(tensor_offset)
# read directly into a numpy array by numpy.fromfile without intermediate copy
numpy_array = np.fromfile(self.file, dtype=np.uint8, count=num_bytes)
byte_tensor = torch.from_numpy(numpy_array)
del numpy_array
# deserialize (view and reshape)
deserialized_tensor = self._deserialize_tensor(byte_tensor, metadata)
del byte_tensor
# cast to target dtype and move to device
return deserialized_tensor.to(device=device, dtype=target_dtype, non_blocking=non_blocking)
def _deserialize_tensor(self, byte_tensor: torch.Tensor, metadata: Dict):
"""Deserialize byte tensor to the correct shape and dtype.
Args:
byte_tensor (torch.Tensor): Raw byte tensor from file.
metadata (Dict): Tensor metadata containing dtype and shape info.
Returns:
torch.Tensor: Deserialized tensor with correct shape and dtype.
"""
dtype = self._get_torch_dtype(metadata["dtype"])
shape = metadata["shape"]
# Handle special float8 types
if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
return self._convert_float8(byte_tensor, metadata["dtype"], shape)
# Standard conversion: view as target dtype and reshape
return byte_tensor.view(dtype).reshape(shape)
@staticmethod
def _get_torch_dtype(dtype_str):
"""Convert string dtype to PyTorch dtype.
Args:
dtype_str (str): String representation of the dtype.
Returns:
torch.dtype: Corresponding PyTorch dtype.
"""
# Standard dtype mappings
dtype_map = {
"F64": torch.float64,
"F32": torch.float32,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I64": torch.int64,
"I32": torch.int32,
"I16": torch.int16,
"I8": torch.int8,
"U8": torch.uint8,
"BOOL": torch.bool,
}
# Add float8 types if available in PyTorch version
if hasattr(torch, "float8_e5m2"):
dtype_map["F8_E5M2"] = torch.float8_e5m2
if hasattr(torch, "float8_e4m3fn"):
dtype_map["F8_E4M3"] = torch.float8_e4m3fn
return dtype_map.get(dtype_str)
@staticmethod
def _convert_float8(byte_tensor, dtype_str, shape):
"""Convert byte tensor to float8 format if supported.
Args:
byte_tensor (torch.Tensor): Raw byte tensor.
dtype_str (str): Float8 dtype string ("F8_E5M2" or "F8_E4M3").
shape (tuple): Target tensor shape.
Returns:
torch.Tensor: Tensor with float8 dtype.
Raises:
ValueError: If float8 type is not supported in current PyTorch version.
"""
# Convert to specific float8 types if available
if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
return byte_tensor.view(torch.float8_e5m2).reshape(shape)
elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
else:
# Float8 not supported in this PyTorch version
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
) -> dict[str, torch.Tensor]:
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
# use experimental loader
# logger.info(f"Loading without mmap (experimental)")
state_dict = {}
device = torch.device(device) if device is not None else None
with MemoryEfficientSafeOpen(path) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key, device=device, dtype=dtype)
synchronize_device(device)
return state_dict
else:
try:
state_dict = load_file(path, device=device)
except:
state_dict = load_file(path) # prevent device invalid Error
if dtype is not None:
for key in state_dict.keys():
state_dict[key] = state_dict[key].to(dtype=dtype)
return state_dict
def load_split_weights(
file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
) -> Dict[str, torch.Tensor]:
"""
Load split weights from a file. If the file name ends with 00001-of-00004 etc, it will load all files with the same prefix.
dtype is as is, no conversion is done.
"""
device = torch.device(device)
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
basename = os.path.basename(file_path)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
if match:
prefix = basename[: match.start(2)]
count = int(match.group(3))
state_dict = {}
for i in range(count):
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(file_path), filename)
if os.path.exists(filepath):
state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap, dtype=dtype))
else:
raise FileNotFoundError(f"File {filepath} not found")
else:
state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
return state_dict
def find_key(safetensors_file: str, starts_with: Optional[str] = None, ends_with: Optional[str] = None) -> Optional[str]:
"""
Find a key in a safetensors file that starts with `starts_with` and ends with `ends_with`.
If `starts_with` is None, it will match any key.
If `ends_with` is None, it will match any key.
Returns the first matching key or None if no key matches.
"""
with MemoryEfficientSafeOpen(safetensors_file) as f:
for key in f.keys():
if (starts_with is None or key.startswith(starts_with)) and (ends_with is None or key.endswith(ends_with)):
return key
return None

View File

@@ -37,18 +37,16 @@ metadata = {
BASE_METADATA = {
# === MUST ===
"modelspec.sai_model_spec": "1.0.1",
"modelspec.sai_model_spec": "1.0.1",
"modelspec.architecture": None,
"modelspec.implementation": None,
"modelspec.title": None,
"modelspec.resolution": None,
# === SHOULD ===
"modelspec.description": None,
"modelspec.author": None,
"modelspec.date": None,
"modelspec.hash_sha256": None,
# === CAN===
"modelspec.implementation_version": None,
"modelspec.license": None,
@@ -81,6 +79,8 @@ ARCH_FLUX_1_CHROMA = "chroma" # for Flux Chroma
ARCH_FLUX_1_UNKNOWN = "flux-1"
ARCH_LUMINA_2 = "lumina-2"
ARCH_LUMINA_UNKNOWN = "lumina"
ARCH_HUNYUAN_IMAGE_2_1 = "hunyuan-image-2.1"
ARCH_HUNYUAN_IMAGE_UNKNOWN = "hunyuan-image"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
@@ -91,6 +91,7 @@ IMPL_DIFFUSERS = "diffusers"
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma"
IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0"
IMPL_HUNYUAN_IMAGE = "https://github.com/Tencent-Hunyuan/HunyuanImage-2.1"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@@ -102,20 +103,20 @@ class ModelSpecMetadata:
ModelSpec 1.0.1 compliant metadata for safetensors models.
All fields correspond to modelspec.* keys in the final metadata.
"""
# === MUST ===
architecture: str
implementation: str
title: str
resolution: str
sai_model_spec: str = "1.0.1"
# === SHOULD ===
description: str | None = None
author: str | None = None
date: str | None = None
hash_sha256: str | None = None
# === CAN ===
implementation_version: str | None = None
license: str | None = None
@@ -131,14 +132,14 @@ class ModelSpecMetadata:
is_negative_embedding: str | None = None
unet_dtype: str | None = None
vae_dtype: str | None = None
# === Additional metadata ===
additional_fields: dict[str, str] = field(default_factory=dict)
def to_metadata_dict(self) -> dict[str, str]:
"""Convert dataclass to metadata dictionary with modelspec. prefixes."""
metadata = {}
# Add all non-None fields with modelspec prefix
for field_name, value in self.__dict__.items():
if field_name == "additional_fields":
@@ -150,14 +151,14 @@ class ModelSpecMetadata:
metadata[f"modelspec.{key}"] = val
elif value is not None:
metadata[f"modelspec.{field_name}"] = value
return metadata
@classmethod
def from_args(cls, args, **kwargs) -> "ModelSpecMetadata":
"""Create ModelSpecMetadata from argparse Namespace, extracting metadata_* fields."""
metadata_fields = {}
# Extract all metadata_* attributes from args
for attr_name in dir(args):
if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"):
@@ -166,7 +167,7 @@ class ModelSpecMetadata:
# Remove metadata_ prefix
field_name = attr_name[9:] # len("metadata_") = 9
metadata_fields[field_name] = value
# Handle known standard fields
standard_fields = {
"author": metadata_fields.pop("author", None),
@@ -174,30 +175,25 @@ class ModelSpecMetadata:
"license": metadata_fields.pop("license", None),
"tags": metadata_fields.pop("tags", None),
}
# Remove None values
standard_fields = {k: v for k, v in standard_fields.items() if v is not None}
# Merge with kwargs and remaining metadata fields
all_fields = {**standard_fields, **kwargs}
if metadata_fields:
all_fields["additional_fields"] = metadata_fields
return cls(**all_fields)
def determine_architecture(
v2: bool,
v_parameterization: bool,
sdxl: bool,
lora: bool,
textual_inversion: bool,
model_config: dict[str, str] | None = None
v2: bool, v_parameterization: bool, sdxl: bool, lora: bool, textual_inversion: bool, model_config: dict[str, str] | None = None
) -> str:
"""Determine model architecture string from parameters."""
model_config = model_config or {}
if sdxl:
arch = ARCH_SD_XL_V1_BASE
elif "sd3" in model_config:
@@ -218,17 +214,23 @@ def determine_architecture(
arch = ARCH_LUMINA_2
else:
arch = ARCH_LUMINA_UNKNOWN
elif "hunyuan_image" in model_config:
hunyuan_image_type = model_config["hunyuan_image"]
if hunyuan_image_type == "2.1":
arch = ARCH_HUNYUAN_IMAGE_2_1
else:
arch = ARCH_HUNYUAN_IMAGE_UNKNOWN
elif v2:
arch = ARCH_SD_V2_768_V if v_parameterization else ARCH_SD_V2_512
else:
arch = ARCH_SD_V1
# Add adapter suffix
if lora:
arch += f"/{ADAPTER_LORA}"
elif textual_inversion:
arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
return arch
@@ -237,12 +239,12 @@ def determine_implementation(
textual_inversion: bool,
sdxl: bool,
model_config: dict[str, str] | None = None,
is_stable_diffusion_ckpt: bool | None = None
is_stable_diffusion_ckpt: bool | None = None,
) -> str:
"""Determine implementation string from parameters."""
model_config = model_config or {}
if "flux" in model_config:
if model_config["flux"] == "chroma":
return IMPL_CHROMA
@@ -265,16 +267,16 @@ def get_implementation_version() -> str:
capture_output=True,
text=True,
cwd=os.path.dirname(os.path.dirname(__file__)), # Go up to sd-scripts root
timeout=5
timeout=5,
)
if result.returncode == 0:
commit_hash = result.stdout.strip()
return f"sd-scripts/{commit_hash}"
else:
logger.warning("Failed to get git commit hash, using fallback")
return "sd-scripts/unknown"
except (subprocess.TimeoutExpired, subprocess.SubprocessError, FileNotFoundError) as e:
logger.warning(f"Could not determine git commit: {e}")
return "sd-scripts/unknown"
@@ -284,19 +286,19 @@ def file_to_data_url(file_path: str) -> str:
"""Convert a file path to a data URL for embedding in metadata."""
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
# Get MIME type
mime_type, _ = mimetypes.guess_type(file_path)
if mime_type is None:
# Default to binary if we can't detect
mime_type = "application/octet-stream"
# Read file and encode as base64
with open(file_path, "rb") as f:
file_data = f.read()
encoded_data = base64.b64encode(file_data).decode("ascii")
return f"data:{mime_type};base64,{encoded_data}"
@@ -305,12 +307,12 @@ def determine_resolution(
sdxl: bool = False,
model_config: dict[str, str] | None = None,
v2: bool = False,
v_parameterization: bool = False
v_parameterization: bool = False,
) -> str:
"""Determine resolution string from parameters."""
model_config = model_config or {}
if reso is not None:
# Handle comma separated string
if isinstance(reso, str):
@@ -318,21 +320,18 @@ def determine_resolution(
# Handle single int
if isinstance(reso, int):
reso = (reso, reso)
# Handle single-element tuple
# Handle single-element tuple
if len(reso) == 1:
reso = (reso[0], reso[0])
else:
# Determine default resolution based on model type
if (sdxl or
"sd3" in model_config or
"flux" in model_config or
"lumina" in model_config):
if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config:
reso = (1024, 1024)
elif v2 and v_parameterization:
reso = (768, 768)
else:
reso = (512, 512)
return f"{reso[0]}x{reso[1]}"
@@ -388,23 +387,19 @@ def build_metadata_dataclass(
) -> ModelSpecMetadata:
"""
Build ModelSpec 1.0.1 compliant metadata dataclass.
Args:
model_config: Dict containing model type info, e.g. {"flux": "dev"}, {"sd3": "large"}
optional_metadata: Dict of additional metadata fields to include
"""
# Use helper functions for complex logic
architecture = determine_architecture(
v2, v_parameterization, sdxl, lora, textual_inversion, model_config
)
architecture = determine_architecture(v2, v_parameterization, sdxl, lora, textual_inversion, model_config)
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
implementation = determine_implementation(
lora, textual_inversion, sdxl, model_config, is_stable_diffusion_ckpt
)
implementation = determine_implementation(lora, textual_inversion, sdxl, model_config, is_stable_diffusion_ckpt)
if title is None:
if lora:
@@ -421,9 +416,7 @@ def build_metadata_dataclass(
date = datetime.datetime.fromtimestamp(int_ts).isoformat()
# Use helper function for resolution
resolution = determine_resolution(
reso, sdxl, model_config, v2, v_parameterization
)
resolution = determine_resolution(reso, sdxl, model_config, v2, v_parameterization)
# Handle prediction type - Flux models don't use prediction_type
model_config = model_config or {}
@@ -488,7 +481,7 @@ def build_metadata_dataclass(
prediction_type=prediction_type,
timestep_range=timestep_range,
encoder_layer=encoder_layer,
additional_fields=processed_optional_metadata
additional_fields=processed_optional_metadata,
)
return metadata
@@ -518,7 +511,7 @@ def build_metadata(
"""
Build ModelSpec 1.0.1 compliant metadata for safetensors models.
Legacy function that returns dict - prefer build_metadata_dataclass for new code.
Args:
model_config: Dict containing model type info, e.g. {"flux": "dev"}, {"sd3": "large"}
optional_metadata: Dict of additional metadata fields to include
@@ -545,7 +538,7 @@ def build_metadata(
model_config=model_config,
optional_metadata=optional_metadata,
)
return metadata_obj.to_metadata_dict()
@@ -581,7 +574,7 @@ def build_merged_from(models: list[str]) -> str:
def add_model_spec_arguments(parser: argparse.ArgumentParser):
"""Add all ModelSpec metadata arguments to the parser."""
parser.add_argument(
"--metadata_title",
type=str,

View File

@@ -23,7 +23,7 @@ from library import sdxl_model_util
# region models
# TODO remove dependency on flux_utils
from library.utils import load_safetensors
from library.safetensors_utils import load_safetensors
from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl
@@ -246,7 +246,7 @@ def load_vae(
vae_sd = {}
if vae_path:
logger.info(f"Loading VAE from {vae_path}...")
vae_sd = load_safetensors(vae_path, device, disable_mmap)
vae_sd = load_safetensors(vae_path, device, disable_mmap, dtype=vae_dtype)
else:
# remove prefix "first_stage_model."
vae_sd = {}

View File

@@ -327,14 +327,17 @@ def save_sd_model_on_epoch_end_or_stepwise(
def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True):
parser.add_argument(
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
)
parser.add_argument(
"--cache_text_encoder_outputs_to_disk",
action="store_true",
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
)
if support_text_encoder_caching:
parser.add_argument(
"--cache_text_encoder_outputs",
action="store_true",
help="cache text encoder outputs / text encoderの出力をキャッシュする",
)
parser.add_argument(
"--cache_text_encoder_outputs_to_disk",
action="store_true",
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
)
parser.add_argument(
"--disable_mmap_load_safetensors",
action="store_true",
@@ -342,7 +345,7 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_en
)
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
def verify_sdxl_training_args(args: argparse.Namespace, support_text_encoder_caching: bool = True):
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
if args.clip_skip is not None:
@@ -365,7 +368,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
# not hasattr(args, "weighted_captions") or not args.weighted_captions
# ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
if supportTextEncoderCaching:
if support_text_encoder_caching:
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
args.cache_text_encoder_outputs = True
logger.warning(

View File

@@ -626,6 +626,7 @@ class LatentsCachingStrategy:
for key in npz.files:
kwargs[key] = npz[key]
# TODO float() is needed if vae is in bfloat16. Remove it if vae is float16.
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)

View File

@@ -0,0 +1,218 @@
import os
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import AutoTokenizer, Qwen2Tokenizer
from library import hunyuan_image_text_encoder, hunyuan_image_vae, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class HunyuanImageTokenizeStrategy(TokenizeStrategy):
def __init__(self, tokenizer_cache_dir: Optional[str] = None) -> None:
self.vlm_tokenizer = self._load_tokenizer(
Qwen2Tokenizer, hunyuan_image_text_encoder.QWEN_2_5_VL_IMAGE_ID, tokenizer_cache_dir=tokenizer_cache_dir
)
self.byt5_tokenizer = self._load_tokenizer(
AutoTokenizer, hunyuan_image_text_encoder.BYT5_TOKENIZER_PATH, subfolder="", tokenizer_cache_dir=tokenizer_cache_dir
)
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
vlm_tokens, vlm_mask = hunyuan_image_text_encoder.get_qwen_tokens(self.vlm_tokenizer, text)
# byt5_tokens, byt5_mask = hunyuan_image_text_encoder.get_byt5_text_tokens(self.byt5_tokenizer, text)
byt5_tokens = []
byt5_mask = []
for t in text:
tokens, mask = hunyuan_image_text_encoder.get_byt5_text_tokens(self.byt5_tokenizer, t)
if tokens is None:
tokens = torch.zeros((1, 1), dtype=torch.long)
mask = torch.zeros((1, 1), dtype=torch.long)
byt5_tokens.append(tokens)
byt5_mask.append(mask)
max_len = max([m.shape[1] for m in byt5_mask])
byt5_tokens = torch.cat([torch.nn.functional.pad(t, (0, max_len - t.shape[1]), value=0) for t in byt5_tokens], dim=0)
byt5_mask = torch.cat([torch.nn.functional.pad(m, (0, max_len - m.shape[1]), value=0) for m in byt5_mask], dim=0)
return [vlm_tokens, vlm_mask, byt5_tokens, byt5_mask]
class HunyuanImageTextEncodingStrategy(TextEncodingStrategy):
def __init__(self) -> None:
pass
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
vlm_tokens, vlm_mask, byt5_tokens, byt5_mask = tokens
qwen2vlm, byt5 = models
# autocast and no_grad are handled in hunyuan_image_text_encoder
vlm_embed, vlm_mask = hunyuan_image_text_encoder.get_qwen_prompt_embeds_from_tokens(qwen2vlm, vlm_tokens, vlm_mask)
# ocr_mask, byt5_embed, byt5_mask = hunyuan_image_text_encoder.get_byt5_prompt_embeds_from_tokens(
# byt5, byt5_tokens, byt5_mask
# )
ocr_mask, byt5_embed, byt5_updated_mask = [], [], []
for i in range(byt5_tokens.shape[0]):
ocr_m, byt5_e, byt5_m = hunyuan_image_text_encoder.get_byt5_prompt_embeds_from_tokens(
byt5, byt5_tokens[i : i + 1], byt5_mask[i : i + 1]
)
ocr_mask.append(torch.zeros((1,), dtype=torch.long) + (1 if ocr_m[0] else 0)) # 1 or 0
byt5_embed.append(byt5_e)
byt5_updated_mask.append(byt5_m)
ocr_mask = torch.cat(ocr_mask, dim=0).to(torch.bool) # [B]
byt5_embed = torch.cat(byt5_embed, dim=0)
byt5_updated_mask = torch.cat(byt5_updated_mask, dim=0)
return [vlm_embed, vlm_mask, byt5_embed, byt5_updated_mask, ocr_mask]
class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_hi_te.npz"
def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return (
os.path.splitext(image_abs_path)[0]
+ HunyuanImageTextEncoderOutputsCachingStrategy.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
)
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "vlm_embed" not in npz:
return False
if "vlm_mask" not in npz:
return False
if "byt5_embed" not in npz:
return False
if "byt5_mask" not in npz:
return False
if "ocr_mask" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
vln_embed = data["vlm_embed"]
vlm_mask = data["vlm_mask"]
byt5_embed = data["byt5_embed"]
byt5_mask = data["byt5_mask"]
ocr_mask = data["ocr_mask"]
return [vln_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
huyuan_image_text_encoding_strategy: HunyuanImageTextEncodingStrategy = text_encoding_strategy
captions = [info.caption for info in infos]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask = huyuan_image_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks
)
if vlm_embed.dtype == torch.bfloat16:
vlm_embed = vlm_embed.float()
if byt5_embed.dtype == torch.bfloat16:
byt5_embed = byt5_embed.float()
vlm_embed = vlm_embed.cpu().numpy()
vlm_mask = vlm_mask.cpu().numpy()
byt5_embed = byt5_embed.cpu().numpy()
byt5_mask = byt5_mask.cpu().numpy()
ocr_mask = ocr_mask.cpu().numpy()
for i, info in enumerate(infos):
vlm_embed_i = vlm_embed[i]
vlm_mask_i = vlm_mask[i]
byt5_embed_i = byt5_embed[i]
byt5_mask_i = byt5_mask[i]
ocr_mask_i = ocr_mask[i]
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
vlm_embed=vlm_embed_i,
vlm_mask=vlm_mask_i,
byt5_embed=byt5_embed_i,
byt5_mask=byt5_mask_i,
ocr_mask=ocr_mask_i,
)
else:
info.text_encoder_outputs = (vlm_embed_i, vlm_mask_i, byt5_embed_i, byt5_mask_i, ocr_mask_i)
class HunyuanImageLatentsCachingStrategy(LatentsCachingStrategy):
HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX = "_hi.npz"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return HunyuanImageLatentsCachingStrategy.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ HunyuanImageLatentsCachingStrategy.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(32, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
return self._default_load_latents_from_disk(32, npz_path, bucket_reso) # support multi-resolution
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(
self, vae: hunyuan_image_vae.HunyuanVAE2D, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool
):
# encode_by_vae = lambda img_tensor: vae.encode(img_tensor).sample()
def encode_by_vae(img_tensor):
# no_grad is handled in _default_cache_batch_latents
nonlocal vae
with torch.autocast(device_type=vae.device.type, dtype=vae.dtype):
return vae.encode(img_tensor).sample()
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)

View File

@@ -1131,7 +1131,8 @@ class BaseDataset(torch.utils.data.Dataset):
def __eq__(self, other):
return (
self.reso == other.reso
other is not None
and self.reso == other.reso
and self.flip_aug == other.flip_aug
and self.alpha_mask == other.alpha_mask
and self.random_crop == other.random_crop
@@ -1193,6 +1194,8 @@ class BaseDataset(torch.utils.data.Dataset):
if len(batch) > 0 and current_condition != condition:
submit_batch(batch, current_condition)
batch = []
if condition != current_condition and HIGH_VRAM: # even with high VRAM, if shape is changed
clean_memory_on_device(accelerator.device)
if info.image is None:
# load image in parallel
@@ -1205,7 +1208,7 @@ class BaseDataset(torch.utils.data.Dataset):
if len(batch) >= caching_strategy.batch_size:
submit_batch(batch, current_condition)
batch = []
current_condition = None
# current_condition = None # keep current_condition to avoid next `clean_memory_on_device` call
if len(batch) > 0:
submit_batch(batch, current_condition)
@@ -1744,7 +1747,35 @@ class BaseDataset(torch.utils.data.Dataset):
# [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)]
if len(tensors_list) == 0 or tensors_list[0] == None or len(tensors_list[0]) == 0 or tensors_list[0][0] is None:
return None
return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))]
# old implementation without padding: all elements must have same length
# return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))]
# new implementation with padding support
result = []
for i in range(len(tensors_list[0])):
tensors = [x[i] for x in tensors_list]
if tensors[0].ndim == 0:
# scalar value: e.g. ocr mask
result.append(torch.stack([converter(x[i]) for x in tensors_list]))
continue
min_len = min([len(x) for x in tensors])
max_len = max([len(x) for x in tensors])
if min_len == max_len:
# no padding
result.append(torch.stack([converter(x) for x in tensors]))
else:
# padding
tensors = [converter(x) for x in tensors]
if tensors[0].ndim == 1:
# input_ids or mask
result.append(torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors]))
else:
# text encoder outputs
result.append(torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors]))
return result
# set example
example = {}
@@ -2170,6 +2201,23 @@ class FineTuningDataset(BaseDataset):
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
self.batch_size = batch_size
self.size = min(self.width, self.height) # 短いほう
self.latents_cache = None
self.enable_bucket = enable_bucket
if self.enable_bucket:
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
)
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
self.bucket_reso_steps = bucket_reso_steps
self.bucket_no_upscale = bucket_no_upscale
else:
self.min_bucket_reso = None
self.max_bucket_reso = None
self.bucket_reso_steps = None # この情報は使われない
self.bucket_no_upscale = False
self.num_train_images = 0
self.num_reg_images = 0
@@ -2189,9 +2237,25 @@ class FineTuningDataset(BaseDataset):
# メタデータを読み込む
if os.path.exists(subset.metadata_file):
logger.info(f"loading existing metadata: {subset.metadata_file}")
with open(subset.metadata_file, "rt", encoding="utf-8") as f:
metadata = json.load(f)
if subset.metadata_file.endswith(".jsonl"):
logger.info(f"loading existing JSOL metadata: {subset.metadata_file}")
# optional JSONL format
# {"image_path": "/path/to/image1.jpg", "caption": "A caption for image1", "image_size": [width, height]}
metadata = {}
with open(subset.metadata_file, "rt", encoding="utf-8") as f:
for line in f:
line_md = json.loads(line)
image_md = {"caption": line_md.get("caption", "")}
if "image_size" in line_md:
image_md["image_size"] = line_md["image_size"]
if "tags" in line_md:
image_md["tags"] = line_md["tags"]
metadata[line_md["image_path"]] = image_md
else:
# standard JSON format
logger.info(f"loading existing metadata: {subset.metadata_file}")
with open(subset.metadata_file, "rt", encoding="utf-8") as f:
metadata = json.load(f)
else:
raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
@@ -2201,65 +2265,101 @@ class FineTuningDataset(BaseDataset):
)
continue
tags_list = []
for image_key, img_md in metadata.items():
# path情報を作る
abs_path = None
# まず画像を優先して探す
if os.path.exists(image_key):
abs_path = image_key
# Add full path for image
image_dirs = set()
if subset.image_dir is not None:
image_dirs.add(subset.image_dir)
for image_key in metadata.keys():
if not os.path.isabs(image_key):
assert (
subset.image_dir is not None
), f"image_dir is required when image paths are relative / 画像パスが相対パスの場合、image_dirの指定が必要です: {image_key}"
abs_path = os.path.join(subset.image_dir, image_key)
else:
# わりといい加減だがいい方法が思いつかん
paths = glob_images(subset.image_dir, image_key)
if len(paths) > 0:
abs_path = paths[0]
abs_path = image_key
image_dirs.add(os.path.dirname(abs_path))
metadata[image_key]["abs_path"] = abs_path
# なければnpzを探す
if abs_path is None:
if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
abs_path = os.path.splitext(image_key)[0] + ".npz"
else:
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
if os.path.exists(npz_path):
abs_path = npz_path
# Enumerate existing npz files
strategy = LatentsCachingStrategy.get_strategy()
npz_paths = []
for image_dir in image_dirs:
npz_paths.extend(glob.glob(os.path.join(image_dir, "*" + strategy.cache_suffix)))
npz_paths = sorted(npz_paths, key=lambda item: len(os.path.basename(item)), reverse=True) # longer paths first
assert abs_path is not None, f"no image / 画像がありません: {image_key}"
# Match image filename longer to shorter because some images share same prefix
image_keys_sorted_by_length_desc = sorted(metadata.keys(), key=len, reverse=True)
# Collect tags and sizes
tags_list = []
size_set_from_metadata = 0
size_set_from_cache_filename = 0
for image_key in image_keys_sorted_by_length_desc:
img_md = metadata[image_key]
caption = img_md.get("caption")
tags = img_md.get("tags")
image_size = img_md.get("image_size")
abs_path = img_md.get("abs_path")
# search npz if image_size is not given
npz_path = None
if image_size is None:
image_without_ext = os.path.splitext(image_key)[0]
for candidate in npz_paths:
if candidate.startswith(image_without_ext):
npz_path = candidate
break
if npz_path is not None:
npz_paths.remove(npz_path) # remove to avoid matching same file (share prefix)
abs_path = npz_path
if caption is None:
caption = tags # could be multiline
tags = None
caption = ""
if subset.enable_wildcard:
# tags must be single line
# tags must be single line (split by caption separator)
if tags is not None:
tags = tags.replace("\n", subset.caption_separator)
# add tags to each line of caption
if caption is not None and tags is not None:
if tags is not None:
caption = "\n".join(
[f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""]
)
tags_list.append(tags)
else:
# use as is
if tags is not None and len(tags) > 0:
caption = caption + subset.caption_separator + tags
if len(caption) > 0:
caption = caption + subset.caption_separator
caption = caption + tags
tags_list.append(tags)
if caption is None:
caption = ""
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
image_info.image_size = img_md.get("train_resolution")
image_info.resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
)
if not subset.color_aug and not subset.random_crop:
# if npz exists, use them
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
if image_size is not None:
image_info.image_size = tuple(image_size) # width, height
size_set_from_metadata += 1
elif npz_path is not None:
# get image size from npz filename
w, h = strategy.get_image_size_from_disk_cache_path(abs_path, npz_path)
image_info.image_size = (w, h)
size_set_from_cache_filename += 1
self.register_image(image_info, subset)
if size_set_from_cache_filename > 0:
logger.info(
f"set image size from cache files: {size_set_from_cache_filename}/{len(image_keys_sorted_by_length_desc)}"
)
if size_set_from_metadata > 0:
logger.info(f"set image size from metadata: {size_set_from_metadata}/{len(image_keys_sorted_by_length_desc)}")
self.num_train_images += len(metadata) * subset.num_repeats
# TODO do not record tag freq when no tag
@@ -2267,117 +2367,6 @@ class FineTuningDataset(BaseDataset):
subset.img_count = len(metadata)
self.subsets.append(subset)
# check existence of all npz files
use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
if use_npz_latents:
flip_aug_in_subset = False
npz_any = False
npz_all = True
for image_info in self.image_data.values():
subset = self.image_to_subset[image_info.image_key]
has_npz = image_info.latents_npz is not None
npz_any = npz_any or has_npz
if subset.flip_aug:
has_npz = has_npz and image_info.latents_npz_flipped is not None
flip_aug_in_subset = True
npz_all = npz_all and has_npz
if npz_any and not npz_all:
break
if not npz_any:
use_npz_latents = False
logger.warning(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
elif not npz_all:
use_npz_latents = False
logger.warning(
f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します"
)
if flip_aug_in_subset:
logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
# else:
# logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
# check min/max bucket size
sizes = set()
resos = set()
for image_info in self.image_data.values():
if image_info.image_size is None:
sizes = None # not calculated
break
sizes.add(image_info.image_size[0])
sizes.add(image_info.image_size[1])
resos.add(tuple(image_info.image_size))
if sizes is None:
if use_npz_latents:
use_npz_latents = False
logger.warning(
f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します"
)
assert (
resolution is not None
), "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
self.enable_bucket = enable_bucket
if self.enable_bucket:
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
)
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
self.bucket_reso_steps = bucket_reso_steps
self.bucket_no_upscale = bucket_no_upscale
else:
if not enable_bucket:
logger.info("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
logger.info("using bucket info in metadata / メタデータ内のbucket情報を使います")
self.enable_bucket = True
assert (
not bucket_no_upscale
), "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
# bucket情報を初期化しておく、make_bucketsで再作成しない
self.bucket_manager = BucketManager(False, None, None, None, None)
self.bucket_manager.set_predefined_resos(resos)
# npz情報をきれいにしておく
if not use_npz_latents:
for image_info in self.image_data.values():
image_info.latents_npz = image_info.latents_npz_flipped = None
def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
base_name = os.path.splitext(image_key)[0]
npz_file_norm = base_name + ".npz"
if os.path.exists(npz_file_norm):
# image_key is full path
npz_file_flip = base_name + "_flip.npz"
if not os.path.exists(npz_file_flip):
npz_file_flip = None
return npz_file_norm, npz_file_flip
# if not full path, check image_dir. if image_dir is None, return None
if subset.image_dir is None:
return None, None
# image_key is relative path
npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
if not os.path.exists(npz_file_norm):
npz_file_norm = None
npz_file_flip = None
elif not os.path.exists(npz_file_flip):
npz_file_flip = None
return npz_file_norm, npz_file_flip
class ControlNetDataset(BaseDataset):
def __init__(
@@ -3588,6 +3577,7 @@ def get_sai_model_spec_dataclass(
sd3: str = None,
flux: str = None,
lumina: str = None,
hunyuan_image: str = None,
optional_metadata: dict[str, str] | None = None,
) -> sai_model_spec.ModelSpecMetadata:
"""
@@ -3617,6 +3607,8 @@ def get_sai_model_spec_dataclass(
model_config["flux"] = flux
if lumina is not None:
model_config["lumina"] = lumina
if hunyuan_image is not None:
model_config["hunyuan_image"] = hunyuan_image
# Use the dataclass function directly
return sai_model_spec.build_metadata_dataclass(
@@ -3987,11 +3979,21 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
choices=["no", "fp16", "bf16"],
help="use mixed precision / 混合精度を使う場合、その精度",
)
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
parser.add_argument(
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
"--full_fp16",
action="store_true",
help="fp16 training including gradients, some models are not supported / 勾配も含めてfp16で学習する、一部のモデルではサポートされていません",
)
parser.add_argument(
"--full_bf16",
action="store_true",
help="bf16 training including gradients, some models are not supported / 勾配も含めてbf16で学習する、一部のモデルではサポートされていません",
) # TODO move to SDXL training, because it is not supported by SD1/2
parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
parser.add_argument(
"--fp8_base",
action="store_true",
help="use fp8 for base model, some models are not supported / base modelにfp8を使う、一部のモデルではサポートされていません",
)
parser.add_argument(
"--ddp_timeout",
@@ -6305,6 +6307,11 @@ def line_to_prompt_dict(line: str) -> dict:
prompt_dict["renorm_cfg"] = float(m.group(1))
continue
m = re.match(r"fs (.+)", parg, re.IGNORECASE)
if m:
prompt_dict["flow_shift"] = m.group(1)
continue
except ValueError as ex:
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(ex)

View File

@@ -2,8 +2,6 @@ import logging
import sys
import threading
from typing import *
import json
import struct
import torch
import torch.nn as nn
@@ -14,7 +12,7 @@ from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncest
import cv2
from PIL import Image
import numpy as np
from safetensors.torch import load_file
def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
@@ -88,6 +86,7 @@ def setup_logging(args=None, log_level=None, reset=False):
logger = logging.getLogger(__name__)
logger.info(msg_init)
setup_logging()
logger = logging.getLogger(__name__)
@@ -190,190 +189,6 @@ def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None)
raise ValueError(f"Unsupported dtype: {s}")
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
"""
memory efficient save file
"""
_TYPES = {
torch.float64: "F64",
torch.float32: "F32",
torch.float16: "F16",
torch.bfloat16: "BF16",
torch.int64: "I64",
torch.int32: "I32",
torch.int16: "I16",
torch.int8: "I8",
torch.uint8: "U8",
torch.bool: "BOOL",
getattr(torch, "float8_e5m2", None): "F8_E5M2",
getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
}
_ALIGN = 256
def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
validated = {}
for key, value in metadata.items():
if not isinstance(key, str):
raise ValueError(f"Metadata key must be a string, got {type(key)}")
if not isinstance(value, str):
print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
validated[key] = str(value)
else:
validated[key] = value
return validated
print(f"Using memory efficient save file: {filename}")
header = {}
offset = 0
if metadata:
header["__metadata__"] = validate_metadata(metadata)
for k, v in tensors.items():
if v.numel() == 0: # empty tensor
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
else:
size = v.numel() * v.element_size()
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
offset += size
hjson = json.dumps(header).encode("utf-8")
hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
with open(filename, "wb") as f:
f.write(struct.pack("<Q", len(hjson)))
f.write(hjson)
for k, v in tensors.items():
if v.numel() == 0:
continue
if v.is_cuda:
# Direct GPU to disk save
with torch.cuda.device(v.device):
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
tensor_bytes = v.contiguous().view(torch.uint8)
tensor_bytes.cpu().numpy().tofile(f)
else:
# CPU tensor save
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
v.contiguous().view(torch.uint8).numpy().tofile(f)
class MemoryEfficientSafeOpen:
def __init__(self, filename):
self.filename = filename
self.file = open(filename, "rb")
self.header, self.header_size = self._read_header()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.file.close()
def keys(self):
return [k for k in self.header.keys() if k != "__metadata__"]
def metadata(self) -> Dict[str, str]:
return self.header.get("__metadata__", {})
def get_tensor(self, key):
if key not in self.header:
raise KeyError(f"Tensor '{key}' not found in the file")
metadata = self.header[key]
offset_start, offset_end = metadata["data_offsets"]
if offset_start == offset_end:
tensor_bytes = None
else:
# adjust offset by header size
self.file.seek(self.header_size + 8 + offset_start)
tensor_bytes = self.file.read(offset_end - offset_start)
return self._deserialize_tensor(tensor_bytes, metadata)
def _read_header(self):
header_size = struct.unpack("<Q", self.file.read(8))[0]
header_json = self.file.read(header_size).decode("utf-8")
return json.loads(header_json), header_size
def _deserialize_tensor(self, tensor_bytes, metadata):
dtype = self._get_torch_dtype(metadata["dtype"])
shape = metadata["shape"]
if tensor_bytes is None:
byte_tensor = torch.empty(0, dtype=torch.uint8)
else:
tensor_bytes = bytearray(tensor_bytes) # make it writable
byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
# process float8 types
if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
return self._convert_float8(byte_tensor, metadata["dtype"], shape)
# convert to the target dtype and reshape
return byte_tensor.view(dtype).reshape(shape)
@staticmethod
def _get_torch_dtype(dtype_str):
dtype_map = {
"F64": torch.float64,
"F32": torch.float32,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I64": torch.int64,
"I32": torch.int32,
"I16": torch.int16,
"I8": torch.int8,
"U8": torch.uint8,
"BOOL": torch.bool,
}
# add float8 types if available
if hasattr(torch, "float8_e5m2"):
dtype_map["F8_E5M2"] = torch.float8_e5m2
if hasattr(torch, "float8_e4m3fn"):
dtype_map["F8_E4M3"] = torch.float8_e4m3fn
return dtype_map.get(dtype_str)
@staticmethod
def _convert_float8(byte_tensor, dtype_str, shape):
if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
return byte_tensor.view(torch.float8_e5m2).reshape(shape)
elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
else:
# # convert to float16 if float8 is not supported
# print(f"Warning: {dtype_str} is not supported in this PyTorch version. Converting to float16.")
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
) -> dict[str, torch.Tensor]:
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
# use experimental loader
# logger.info(f"Loading without mmap (experimental)")
state_dict = {}
with MemoryEfficientSafeOpen(path) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
return state_dict
else:
try:
state_dict = load_file(path, device=device)
except:
state_dict = load_file(path) # prevent device invalid Error
if dtype is not None:
for key in state_dict.keys():
state_dict[key] = state_dict[key].to(dtype=dtype)
return state_dict
# endregion
# region Image utils
@@ -398,7 +213,14 @@ def pil_resize(image, size, interpolation):
return resized_cv2
def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None):
def resize_image(
image: np.ndarray,
width: int,
height: int,
resized_width: int,
resized_height: int,
resize_interpolation: Optional[str] = None,
):
"""
Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS.
@@ -449,29 +271,30 @@ def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
"""
if interpolation is None:
return None
return None
if interpolation == "lanczos" or interpolation == "lanczos4":
# Lanczos interpolation over 8x8 neighborhood
# Lanczos interpolation over 8x8 neighborhood
return cv2.INTER_LANCZOS4
elif interpolation == "nearest":
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
return cv2.INTER_NEAREST_EXACT
elif interpolation == "bilinear" or interpolation == "linear":
# bilinear interpolation
return cv2.INTER_LINEAR
elif interpolation == "bicubic" or interpolation == "cubic":
# bicubic interpolation
# bicubic interpolation
return cv2.INTER_CUBIC
elif interpolation == "area":
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
return cv2.INTER_AREA
elif interpolation == "box":
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
return cv2.INTER_AREA
else:
return None
def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
"""
Convert interpolation value to PIL interpolation
@@ -479,7 +302,7 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp
https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters
"""
if interpolation is None:
return None
return None
if interpolation == "lanczos":
return Image.Resampling.LANCZOS
@@ -493,7 +316,7 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp
# For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used.
return Image.Resampling.BICUBIC
elif interpolation == "area":
# Image.Resampling.BOX may be more appropriate if upscaling
# Image.Resampling.BOX may be more appropriate if upscaling
# Area interpolation is related to cv2.INTER_AREA
# Produces a sharper image than Resampling.BILINEAR, doesnt have dislocations on local level like with Resampling.BOX.
return Image.Resampling.HAMMING
@@ -503,12 +326,14 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp
else:
return None
def validate_interpolation_fn(interpolation_str: str) -> bool:
"""
Check if a interpolation function is supported
"""
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]
# endregion
# TODO make inf_utils.py
@@ -642,7 +467,9 @@ class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
elif self.config.prediction_type == "sample":
raise NotImplementedError("prediction_type not implemented yet: sample")
else:
raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
)
sigma_from = self.sigmas[self.step_index]
sigma_to = self.sigmas[self.step_index + 1]

View File

@@ -743,7 +743,7 @@ def train(args):
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = nextdit(
x=noisy_model_input, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
t=1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
cap_mask=gemma2_attn_mask.to(
dtype=torch.int32

View File

@@ -268,7 +268,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
# NextDiT forward expects (x, t, cap_feats, cap_mask)
model_pred = dit(
x=img, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
t=1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask
)

View File

@@ -0,0 +1,88 @@
import argparse
from safetensors.torch import save_file
from safetensors import safe_open
import torch
from library import train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def main(args):
# load source safetensors
logger.info(f"Loading source file {args.src_path}")
state_dict = {}
with safe_open(args.src_path, framework="pt") as f:
metadata = f.metadata()
for k in f.keys():
state_dict[k] = f.get_tensor(k)
logger.info(f"Converting...")
# Key mapping tables: (sd-scripts format, ComfyUI format)
double_blocks_mappings = [
("img_mlp_fc1", "img_mlp_0"),
("img_mlp_fc2", "img_mlp_2"),
("img_mod_linear", "img_mod_lin"),
("txt_mlp_fc1", "txt_mlp_0"),
("txt_mlp_fc2", "txt_mlp_2"),
("txt_mod_linear", "txt_mod_lin"),
]
single_blocks_mappings = [
("modulation_linear", "modulation_lin"),
]
keys = list(state_dict.keys())
count = 0
for k in keys:
new_k = k
if "double_blocks" in k:
mappings = double_blocks_mappings
elif "single_blocks" in k:
mappings = single_blocks_mappings
else:
continue
# Apply mappings based on conversion direction
for src_key, dst_key in mappings:
if args.reverse:
# ComfyUI to sd-scripts: swap src and dst
new_k = new_k.replace(dst_key, src_key)
else:
# sd-scripts to ComfyUI: use as-is
new_k = new_k.replace(src_key, dst_key)
if new_k != k:
state_dict[new_k] = state_dict.pop(k)
count += 1
# print(f"Renamed {k} to {new_k}")
logger.info(f"Converted {count} keys")
# Calculate hash
if metadata is not None:
logger.info(f"Calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
# save destination safetensors
logger.info(f"Saving destination file {args.dst_path}")
save_file(state_dict, args.dst_path, metadata=metadata)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert LoRA format")
parser.add_argument("src_path", type=str, default=None, help="source path, sd-scripts format")
parser.add_argument("dst_path", type=str, default=None, help="destination path, ComfyUI format")
parser.add_argument("--reverse", action="store_true", help="reverse conversion direction")
args = parser.parse_args()
main(args)

View File

@@ -10,9 +10,8 @@ import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm
from library import flux_utils, sai_model_spec, model_util, sdxl_model_util
import lora
from library.utils import MemoryEfficientSafeOpen
from library import flux_utils, sai_model_spec
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.utils import setup_logging
from networks import lora_flux
@@ -140,7 +139,9 @@ def svd(
if not no_metadata:
title = os.path.splitext(os.path.basename(save_to))[0]
sai_metadata = sai_model_spec.build_metadata(lora_sd, False, False, False, True, False, time.time(), title, flux="dev")
sai_metadata = sai_model_spec.build_metadata(
lora_sd, False, False, False, True, False, time.time(), title, model_config={"flux": "dev"}
)
metadata.update(sai_metadata)
save_to_file(save_to, lora_sd, metadata, save_dtype)

View File

@@ -9,7 +9,8 @@ from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file
from library.utils import setup_logging, str_to_dtype
from library.safetensors_utils import MemoryEfficientSafeOpen, mem_eff_save_file
setup_logging()
import logging
@@ -618,7 +619,16 @@ def merge(args):
merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models)
title = os.path.splitext(os.path.basename(args.save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev"
None,
False,
False,
False,
False,
False,
time.time(),
title=title,
merged_from=merged_from,
model_config={"flux": "dev"},
)
if flux_state_dict is not None and len(flux_state_dict) > 0:
@@ -646,7 +656,16 @@ def merge(args):
merged_from = sai_model_spec.build_merged_from(args.models)
title = os.path.splitext(os.path.basename(args.save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
flux_state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev"
flux_state_dict,
False,
False,
False,
True,
False,
time.time(),
title=title,
merged_from=merged_from,
model_config={"flux": "dev"},
)
metadata.update(sai_metadata)

View File

@@ -713,6 +713,10 @@ class LoRANetwork(torch.nn.Module):
LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible
@classmethod
def get_qkv_mlp_split_dims(cls) -> List[int]:
return [3072] * 3 + [12288]
def __init__(
self,
text_encoders: Union[List[CLIPTextModel], CLIPTextModel],
@@ -842,7 +846,7 @@ class LoRANetwork(torch.nn.Module):
break
# if modules_dim is None, we use default lora_dim. if modules_dim is not None, we use the specified dim (no default)
if dim is None and modules_dim is None:
if dim is None and modules_dim is None:
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha
@@ -901,9 +905,9 @@ class LoRANetwork(torch.nn.Module):
split_dims = None
if is_flux and split_qkv:
if "double" in lora_name and "qkv" in lora_name:
split_dims = [3072] * 3
(split_dims,) = self.get_qkv_mlp_split_dims()[:3] # qkv only
elif "single" in lora_name and "linear1" in lora_name:
split_dims = [3072] * 3 + [12288]
split_dims = self.get_qkv_mlp_split_dims() # qkv + mlp
lora = module_class(
lora_name,
@@ -1036,9 +1040,9 @@ class LoRANetwork(torch.nn.Module):
# split qkv
for key in list(state_dict.keys()):
if "double" in key and "qkv" in key:
split_dims = [3072] * 3
split_dims = self.get_qkv_mlp_split_dims()[:3] # qkv only
elif "single" in key and "linear1" in key:
split_dims = [3072] * 3 + [12288]
split_dims = self.get_qkv_mlp_split_dims() # qkv + mlp
else:
continue
@@ -1092,9 +1096,9 @@ class LoRANetwork(torch.nn.Module):
new_state_dict = {}
for key in list(state_dict.keys()):
if "double" in key and "qkv" in key:
split_dims = [3072] * 3
split_dims = self.get_qkv_mlp_split_dims()[:3] # qkv only
elif "single" in key and "linear1" in key:
split_dims = [3072] * 3 + [12288]
split_dims = self.get_qkv_mlp_split_dims() # qkv + mlp
else:
new_state_dict[key] = state_dict[key]
continue

View File

@@ -0,0 +1,378 @@
# temporary minimum implementation of LoRA
# FLUX doesn't have Conv2d, so we ignore it
# TODO commonize with the original implementation
# LoRA network module
# reference:
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
import os
from typing import Dict, List, Optional, Type, Union
import torch
import torch.nn as nn
from torch import Tensor
import re
from networks import lora_flux
from library.hunyuan_image_vae import HunyuanVAE2D
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
NUM_DOUBLE_BLOCKS = 20
NUM_SINGLE_BLOCKS = 40
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae: HunyuanVAE2D,
text_encoders: List[nn.Module],
flux,
neuron_dropout: Optional[float] = None,
**kwargs,
):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
network_alpha = 1.0
# extract dim/alpha for conv2d, and block dim
conv_dim = kwargs.get("conv_dim", None)
conv_alpha = kwargs.get("conv_alpha", None)
if conv_dim is not None:
conv_dim = int(conv_dim)
if conv_alpha is None:
conv_alpha = 1.0
else:
conv_alpha = float(conv_alpha)
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
if rank_dropout is not None:
rank_dropout = float(rank_dropout)
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
module_dropout = float(module_dropout)
# split qkv
split_qkv = kwargs.get("split_qkv", False)
if split_qkv is not None:
split_qkv = True if split_qkv == "True" else False
ggpo_beta = kwargs.get("ggpo_beta", None)
ggpo_sigma = kwargs.get("ggpo_sigma", None)
if ggpo_beta is not None:
ggpo_beta = float(ggpo_beta)
if ggpo_sigma is not None:
ggpo_sigma = float(ggpo_sigma)
# verbose
verbose = kwargs.get("verbose", False)
if verbose is not None:
verbose = True if verbose == "True" else False
# regex-specific learning rates
def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]:
"""
Parse a string of key-value pairs separated by commas.
"""
pairs = {}
for pair in kv_pair_str.split(","):
pair = pair.strip()
if not pair:
continue
if "=" not in pair:
logger.warning(f"Invalid format: {pair}, expected 'key=value'")
continue
key, value = pair.split("=", 1)
key = key.strip()
value = value.strip()
try:
pairs[key] = int(value) if is_int else float(value)
except ValueError:
logger.warning(f"Invalid value for {key}: {value}")
return pairs
# parse regular expression based learning rates
network_reg_lrs = kwargs.get("network_reg_lrs", None)
if network_reg_lrs is not None:
reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False)
else:
reg_lrs = None
# regex-specific dimensions (ranks)
network_reg_dims = kwargs.get("network_reg_dims", None)
if network_reg_dims is not None:
reg_dims = parse_kv_pairs(network_reg_dims, is_int=True)
else:
reg_dims = None
# Too many arguments ( ^ω^)・・・
network = HunyuanImageLoRANetwork(
text_encoders,
flux,
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
dropout=neuron_dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
split_qkv=split_qkv,
reg_dims=reg_dims,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
reg_lrs=reg_lrs,
verbose=verbose,
)
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
return network
# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weights_sd=None, for_inference=False, **kwargs):
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# get dim/alpha mapping, and train t5xxl
modules_dim = {}
modules_alpha = {}
for key, value in weights_sd.items():
if "." not in key:
continue
lora_name = key.split(".")[0]
if "alpha" in key:
modules_alpha[lora_name] = value
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# logger.info(lora_name, value.size(), dim)
split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined
module_class = lora_flux.LoRAInfModule if for_inference else lora_flux.LoRAModule
network = HunyuanImageLoRANetwork(
text_encoders,
flux,
multiplier=multiplier,
modules_dim=modules_dim,
modules_alpha=modules_alpha,
module_class=module_class,
split_qkv=split_qkv,
)
return network, weights_sd
class HunyuanImageLoRANetwork(lora_flux.LoRANetwork):
TARGET_REPLACE_MODULE_DOUBLE = ["MMDoubleStreamBlock"]
TARGET_REPLACE_MODULE_SINGLE = ["MMSingleStreamBlock"]
LORA_PREFIX_HUNYUAN_IMAGE_DIT = "lora_unet" # make ComfyUI compatible
@classmethod
def get_qkv_mlp_split_dims(cls) -> List[int]:
return [3584] * 3 + [14336]
def __init__(
self,
text_encoders: list[nn.Module],
unet,
multiplier: float = 1.0,
lora_dim: int = 4,
alpha: float = 1,
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
module_class: Type[object] = lora_flux.LoRAModule,
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
split_qkv: bool = False,
reg_dims: Optional[Dict[str, int]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
reg_lrs: Optional[Dict[str, float]] = None,
verbose: Optional[bool] = False,
) -> None:
nn.Module.__init__(self)
self.multiplier = multiplier
self.lora_dim = lora_dim
self.alpha = alpha
self.conv_lora_dim = conv_lora_dim
self.conv_alpha = conv_alpha
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.split_qkv = split_qkv
self.reg_dims = reg_dims
self.reg_lrs = reg_lrs
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
logger.info(f"create LoRA network from weights")
self.in_dims = [0] * 5 # create in_dims
# verbose = True
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
# if self.conv_lora_dim is not None:
# logger.info(
# f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
# )
if ggpo_beta is not None and ggpo_sigma is not None:
logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")
if self.split_qkv:
logger.info(f"split qkv for LoRA")
# create module instances
def create_modules(
is_dit: bool,
text_encoder_idx: Optional[int],
root_module: torch.nn.Module,
target_replace_modules: List[str],
filter: Optional[str] = None,
default_dim: Optional[int] = None,
) -> List[lora_flux.LoRAModule]:
assert is_dit, "only DIT is supported now"
prefix = self.LORA_PREFIX_HUNYUAN_IMAGE_DIT
loras = []
skipped = []
for name, module in root_module.named_modules():
if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
if target_replace_modules is None: # dirty hack for all modules
module = root_module # search all modules
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
lora_name = prefix + "." + (name + "." if name else "") + child_name
lora_name = lora_name.replace(".", "_")
if filter is not None and not filter in lora_name:
continue
dim = None
alpha = None
if modules_dim is not None:
# モジュール指定あり
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
elif self.reg_dims is not None:
for reg, d in self.reg_dims.items():
if re.search(reg, lora_name):
dim = d
alpha = self.alpha
logger.info(f"LoRA {lora_name} matched with regex {reg}, using dim: {dim}")
break
# if modules_dim is None, we use default lora_dim. if modules_dim is not None, we use the specified dim (no default)
if dim is None and modules_dim is None:
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha
elif self.conv_lora_dim is not None:
dim = self.conv_lora_dim
alpha = self.conv_alpha
if dim is None or dim == 0:
# skipした情報を出力
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None):
skipped.append(lora_name)
continue
# qkv split
split_dims = None
if is_dit and split_qkv:
if "double" in lora_name and "qkv" in lora_name:
split_dims = self.get_qkv_mlp_split_dims()[:3] # qkv only
elif "single" in lora_name and "linear1" in lora_name:
split_dims = self.get_qkv_mlp_split_dims() # qkv + mlp
lora = module_class(
lora_name,
child_module,
self.multiplier,
dim,
alpha,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
split_dims=split_dims,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
)
loras.append(lora)
if target_replace_modules is None:
break # all modules are searched
return loras, skipped
# create LoRA for U-Net
target_replace_modules = (
HunyuanImageLoRANetwork.TARGET_REPLACE_MODULE_DOUBLE + HunyuanImageLoRANetwork.TARGET_REPLACE_MODULE_SINGLE
)
self.unet_loras: List[Union[lora_flux.LoRAModule, lora_flux.LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules)
self.text_encoder_loras = []
logger.info(f"create LoRA for HunyuanImage-2.1: {len(self.unet_loras)} modules.")
if verbose:
for lora in self.unet_loras:
logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")
skipped = skipped_un
if verbose and len(skipped) > 0:
logger.warning(
f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
)
for name in skipped:
logger.info(f"\t{name}")
# assertion
names = set()
for lora in self.text_encoder_loras + self.unet_loras:
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)

View File

@@ -9,7 +9,7 @@ einops==0.7.0
bitsandbytes
lion-pytorch==0.2.3
schedulefree==1.4
pytorch-optimizer==3.7.0
pytorch-optimizer==3.9.0
prodigy-plus-schedule-free==1.9.2
prodigyopt==1.1.2
tensorboard

View File

@@ -28,7 +28,7 @@ import logging
logger = logging.getLogger(__name__)
from library import sd3_models, sd3_utils, strategy_sd3
from library.utils import load_safetensors
from library.safetensors_utils import load_safetensors
def get_noise(seed, latent, device="cpu"):

View File

@@ -14,6 +14,7 @@ from tqdm import tqdm
import torch
from library import utils
from library.device_utils import init_ipex, clean_memory_on_device
from library.safetensors_utils import load_safetensors
init_ipex()
@@ -206,7 +207,7 @@ def train(args):
# t5xxl_dtype = weight_dtype
model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx)
if args.clip_l is None:
sd3_state_dict = utils.load_safetensors(
sd3_state_dict = load_safetensors(
args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype
)
else:
@@ -322,7 +323,7 @@ def train(args):
# load VAE for caching latents
if sd3_state_dict is None:
logger.info(f"load state dict for MMDiT and VAE from {args.pretrained_model_name_or_path}")
sd3_state_dict = utils.load_safetensors(
sd3_state_dict = load_safetensors(
args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype
)

View File

@@ -8,6 +8,7 @@ import torch
from accelerate import Accelerator
from library import sd3_models, strategy_sd3, utils
from library.device_utils import init_ipex, clean_memory_on_device
from library.safetensors_utils import load_safetensors
init_ipex()
@@ -77,7 +78,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
loading_dtype = None if args.fp8_base else weight_dtype
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
state_dict = utils.load_safetensors(
state_dict = load_safetensors(
args.pretrained_model_name_or_path, "cpu", disable_mmap=args.disable_mmap_load_safetensors, dtype=loading_dtype
)
mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu")

View File

@@ -20,7 +20,8 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
# super().assert_extra_args(args, train_dataset_group) # do not call parent because it checks reso steps with 64
sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)
train_dataset_group.verify_bucket_reso_steps(32)
if val_dataset_group is not None:

View File

@@ -19,11 +19,7 @@ from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
def test_batchify():
# Test case with no batch size specified
prompts = [
{"prompt": "test1"},
{"prompt": "test2"},
{"prompt": "test3"}
]
prompts = [{"prompt": "test1"}, {"prompt": "test2"}, {"prompt": "test3"}]
batchified = list(batchify(prompts))
assert len(batchified) == 1
assert len(batchified[0]) == 3
@@ -38,7 +34,7 @@ def test_batchify():
prompts_with_params = [
{"prompt": "test1", "width": 512, "height": 512},
{"prompt": "test2", "width": 512, "height": 512},
{"prompt": "test3", "width": 1024, "height": 1024}
{"prompt": "test3", "width": 1024, "height": 1024},
]
batchified_params = list(batchify(prompts_with_params))
assert len(batchified_params) == 2
@@ -61,7 +57,7 @@ def test_time_shift():
# Test with edge cases
t_edges = torch.tensor([0.0, 1.0])
result_edges = time_shift(1.0, 1.0, t_edges)
# Check that results are bounded within [0, 1]
assert torch.all(result_edges >= 0)
assert torch.all(result_edges <= 1)
@@ -93,10 +89,7 @@ def test_get_schedule():
# Test with shift disabled
unshifted_schedule = get_schedule(num_steps=10, image_seq_len=256, shift=False)
assert torch.allclose(
torch.tensor(unshifted_schedule),
torch.linspace(1, 1/10, 10)
)
assert torch.allclose(torch.tensor(unshifted_schedule), torch.linspace(1, 1 / 10, 10))
def test_compute_density_for_timestep_sampling():
@@ -106,16 +99,12 @@ def test_compute_density_for_timestep_sampling():
assert torch.all((uniform_samples >= 0) & (uniform_samples <= 1))
# Test logit normal sampling
logit_normal_samples = compute_density_for_timestep_sampling(
"logit_normal", batch_size=100, logit_mean=0.0, logit_std=1.0
)
logit_normal_samples = compute_density_for_timestep_sampling("logit_normal", batch_size=100, logit_mean=0.0, logit_std=1.0)
assert len(logit_normal_samples) == 100
assert torch.all((logit_normal_samples >= 0) & (logit_normal_samples <= 1))
# Test mode sampling
mode_samples = compute_density_for_timestep_sampling(
"mode", batch_size=100, mode_scale=0.5
)
mode_samples = compute_density_for_timestep_sampling("mode", batch_size=100, mode_scale=0.5)
assert len(mode_samples) == 100
assert torch.all((mode_samples >= 0) & (mode_samples <= 1))
@@ -123,20 +112,20 @@ def test_compute_density_for_timestep_sampling():
def test_get_sigmas():
# Create a mock noise scheduler
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
device = torch.device('cpu')
device = torch.device("cpu")
# Test with default parameters
timesteps = torch.tensor([100, 500, 900])
sigmas = get_sigmas(scheduler, timesteps, device)
# Check shape and basic properties
assert sigmas.shape[0] == 3
assert torch.all(sigmas >= 0)
# Test with different n_dim
sigmas_4d = get_sigmas(scheduler, timesteps, device, n_dim=4)
assert sigmas_4d.ndim == 4
# Test with different dtype
sigmas_float16 = get_sigmas(scheduler, timesteps, device, dtype=torch.float16)
assert sigmas_float16.dtype == torch.float16
@@ -145,17 +134,17 @@ def test_get_sigmas():
def test_compute_loss_weighting_for_sd3():
# Prepare some mock sigmas
sigmas = torch.tensor([0.1, 0.5, 1.0])
# Test sigma_sqrt weighting
sqrt_weighting = compute_loss_weighting_for_sd3("sigma_sqrt", sigmas)
assert torch.allclose(sqrt_weighting, 1 / (sigmas**2), rtol=1e-5)
# Test cosmap weighting
cosmap_weighting = compute_loss_weighting_for_sd3("cosmap", sigmas)
bot = 1 - 2 * sigmas + 2 * sigmas**2
expected_cosmap = 2 / (math.pi * bot)
assert torch.allclose(cosmap_weighting, expected_cosmap, rtol=1e-5)
# Test default weighting
default_weighting = compute_loss_weighting_for_sd3("unknown", sigmas)
assert torch.all(default_weighting == 1)
@@ -166,22 +155,22 @@ def test_apply_model_prediction_type():
class MockArgs:
model_prediction_type = "raw"
weighting_scheme = "sigma_sqrt"
args = MockArgs()
model_pred = torch.tensor([1.0, 2.0, 3.0])
noisy_model_input = torch.tensor([0.5, 1.0, 1.5])
sigmas = torch.tensor([0.1, 0.5, 1.0])
# Test raw prediction type
raw_pred, raw_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
assert torch.all(raw_pred == model_pred)
assert raw_weighting is None
# Test additive prediction type
args.model_prediction_type = "additive"
additive_pred, _ = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
assert torch.all(additive_pred == model_pred + noisy_model_input)
# Test sigma scaled prediction type
args.model_prediction_type = "sigma_scaled"
sigma_scaled_pred, sigma_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
@@ -192,12 +181,12 @@ def test_apply_model_prediction_type():
def test_retrieve_timesteps():
# Create a mock scheduler
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
# Test with num_inference_steps
timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=50)
assert len(timesteps) == 50
assert n_steps == 50
# Test error handling with simultaneous timesteps and sigmas
with pytest.raises(ValueError):
retrieve_timesteps(scheduler, timesteps=[1, 2, 3], sigmas=[0.1, 0.2, 0.3])
@@ -210,32 +199,30 @@ def test_get_noisy_model_input_and_timesteps():
weighting_scheme = "sigma_sqrt"
sigmoid_scale = 1.0
discrete_flow_shift = 6.0
ip_noise_gamma = True
ip_noise_gamma_random_strength = 0.01
args = MockArgs()
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
device = torch.device('cpu')
device = torch.device("cpu")
# Prepare mock latents and noise
latents = torch.randn(4, 16, 64, 64)
noise = torch.randn_like(latents)
# Test uniform sampling
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
args, scheduler, latents, noise, device, torch.float32
)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, scheduler, latents, noise, device, torch.float32)
# Validate output shapes and types
assert noisy_input.shape == latents.shape
assert timesteps.shape[0] == latents.shape[0]
assert noisy_input.dtype == torch.float32
assert timesteps.dtype == torch.float32
# Test different sampling methods
sampling_methods = ["sigmoid", "shift", "nextdit_shift"]
for method in sampling_methods:
args.timestep_sampling = method
noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(
args, scheduler, latents, noise, device, torch.float32
)
noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(args, scheduler, latents, noise, device, torch.float32)
assert noisy_input.shape == latents.shape
assert timesteps.shape[0] == latents.shape[0]

View File

@@ -4,7 +4,7 @@ import torch.nn as nn
from unittest.mock import patch, MagicMock
from library.custom_offloading_utils import (
synchronize_device,
_synchronize_device,
swap_weight_devices_cuda,
swap_weight_devices_no_cuda,
weighs_to_device,
@@ -50,21 +50,21 @@ class SimpleModel(nn.Module):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_synchronize(mock_cuda_sync):
device = torch.device('cuda')
synchronize_device(device)
_synchronize_device(device)
mock_cuda_sync.assert_called_once()
@patch('torch.xpu.synchronize')
@pytest.mark.skipif(not torch.xpu.is_available(), reason="XPU not available")
def test_xpu_synchronize(mock_xpu_sync):
device = torch.device('xpu')
synchronize_device(device)
_synchronize_device(device)
mock_xpu_sync.assert_called_once()
@patch('torch.mps.synchronize')
@pytest.mark.skipif(not torch.xpu.is_available(), reason="MPS not available")
def test_mps_synchronize(mock_mps_sync):
device = torch.device('mps')
synchronize_device(device)
_synchronize_device(device)
mock_mps_sync.assert_called_once()
@@ -111,7 +111,7 @@ def test_swap_weight_devices_cuda():
@patch('library.custom_offloading_utils.synchronize_device')
@patch('library.custom_offloading_utils._synchronize_device')
def test_swap_weight_devices_no_cuda(mock_sync_device):
device = torch.device('cpu')
layer_to_cpu = SimpleModel()
@@ -121,7 +121,7 @@ def test_swap_weight_devices_no_cuda(mock_sync_device):
with patch('torch.Tensor.copy_'):
swap_weight_devices_no_cuda(device, layer_to_cpu, layer_to_cuda)
# Verify synchronize_device was called twice
# Verify _synchronize_device was called twice
assert mock_sync_device.call_count == 2
@@ -279,8 +279,8 @@ def test_backward_hook_execution(mock_wait, mock_submit):
@patch('library.custom_offloading_utils.weighs_to_device')
@patch('library.custom_offloading_utils.synchronize_device')
@patch('library.custom_offloading_utils.clean_memory_on_device')
@patch('library.custom_offloading_utils._synchronize_device')
@patch('library.custom_offloading_utils._clean_memory_on_device')
def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weights_to_device, model_offloader):
model = SimpleModel(4)
blocks = model.blocks
@@ -291,7 +291,7 @@ def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weight
# Check that weighs_to_device was called for each block
assert mock_weights_to_device.call_count == 4
# Check that synchronize_device and clean_memory_on_device were called
# Check that _synchronize_device and _clean_memory_on_device were called
mock_sync.assert_called_once_with(model_offloader.device)
mock_clean.assert_called_once_with(model_offloader.device)

View File

@@ -30,7 +30,8 @@ import torch
from tqdm import tqdm
from library import flux_utils
from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file
from library.utils import setup_logging, str_to_dtype
from library.safetensors_utils import MemoryEfficientSafeOpen, mem_eff_save_file
setup_logging()
import logging
@@ -56,7 +57,7 @@ def convert(args):
save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None
# make reverse map from diffusers map
diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map()
diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map(19, 38)
# iterate over three safetensors files to reduce memory usage
flux_sd = {}

View File

@@ -6,7 +6,8 @@ import torch
from safetensors.torch import safe_open
from library.utils import setup_logging
from library.utils import load_safetensors, mem_eff_save_file, str_to_dtype
from library.utils import str_to_dtype
from library.safetensors_utils import load_safetensors, mem_eff_save_file
setup_logging()
import logging

View File

@@ -1,3 +1,4 @@
import gc
import importlib
import argparse
import math
@@ -10,11 +11,11 @@ import time
import json
from multiprocessing import Value
import numpy as np
import toml
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.types import Number
from library.device_utils import init_ipex, clean_memory_on_device
@@ -175,7 +176,7 @@ class NetworkTrainer:
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(64)
def load_target_model(self, args, weight_dtype, accelerator) -> tuple:
def load_target_model(self, args, weight_dtype, accelerator) -> tuple[str, nn.Module, nn.Module, Optional[nn.Module]]:
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
# モデルに xformers とか memory efficient attention を組み込む
@@ -185,6 +186,9 @@ class NetworkTrainer:
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, List[nn.Module]]:
raise NotImplementedError()
def get_tokenize_strategy(self, args):
return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
@@ -475,6 +479,15 @@ class NetworkTrainer:
return loss.mean()
def cast_text_encoder(self, args):
return True # default for other than HunyuanImage
def cast_vae(self, args):
return True # default for other than HunyuanImage
def cast_unet(self, args):
return True # default for other than HunyuanImage
def train(self, args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
@@ -583,37 +596,18 @@ class NetworkTrainer:
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
vae_dtype = (torch.float32 if args.no_half_vae else weight_dtype) if self.cast_vae(args) else None
# モデルを読み込む
# load target models: unet may be None for lazy loading
model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
if vae_dtype is None:
vae_dtype = vae.dtype
logger.info(f"vae_dtype is set to {vae_dtype} by the model since cast_vae() is false")
# text_encoder is List[CLIPTextModel] or CLIPTextModel
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
# 差分追加学習のためにモデルを読み込む
sys.path.append(os.path.dirname(__file__))
accelerator.print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
if args.base_weights is not None:
# base_weights が指定されている場合は、指定された重みを読み込みマージする
for i, weight_path in enumerate(args.base_weights):
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
multiplier = 1.0
else:
multiplier = args.base_weights_multiplier[i]
accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
module, weights_sd = network_module.create_network_from_weights(
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
)
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
# 学習を準備する
# prepare dataset for latents caching if needed
if cache_latents:
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
@@ -640,6 +634,32 @@ class NetworkTrainer:
if val_dataset_group is not None:
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype)
if unet is None:
# lazy load unet if needed. text encoders may be freed or replaced with dummy models for saving memory
unet, text_encoders = self.load_unet_lazily(args, weight_dtype, accelerator, text_encoders)
# 差分追加学習のためにモデルを読み込む
sys.path.append(os.path.dirname(__file__))
accelerator.print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
if args.base_weights is not None:
# base_weights が指定されている場合は、指定された重みを読み込みマージする
for i, weight_path in enumerate(args.base_weights):
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
multiplier = 1.0
else:
multiplier = args.base_weights_multiplier[i]
accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
module, weights_sd = network_module.create_network_from_weights(
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
)
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
# prepare network
net_kwargs = {}
if args.network_args is not None:
@@ -669,7 +689,7 @@ class NetworkTrainer:
return
network_has_multiplier = hasattr(network, "set_multiplier")
# TODO remove `hasattr`s by setting up methods if not defined in the network like (hacky but works):
# TODO remove `hasattr` by setting up methods if not defined in the network like below (hacky but will work):
# if not hasattr(network, "prepare_network"):
# network.prepare_network = lambda args: None
@@ -827,12 +847,13 @@ class NetworkTrainer:
unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator
unet.requires_grad_(False)
unet.to(dtype=unet_weight_dtype)
if self.cast_unet(args):
unet.to(dtype=unet_weight_dtype)
for i, t_enc in enumerate(text_encoders):
t_enc.requires_grad_(False)
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
if t_enc.device.type != "cpu":
if t_enc.device.type != "cpu" and self.cast_text_encoder(args):
t_enc.to(dtype=te_weight_dtype)
# nn.Embedding not support FP8
@@ -858,7 +879,8 @@ class NetworkTrainer:
# default implementation is: unet = accelerator.prepare(unet)
unet = self.prepare_unet_with_accelerator(args, accelerator, unet) # accelerator does some magic here
else:
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
# move to device because unet is not prepared by accelerator
unet.to(accelerator.device, dtype=unet_weight_dtype if self.cast_unet(args) else None)
if train_text_encoder:
text_encoders = [
(accelerator.prepare(t_enc) if flag else t_enc)
@@ -1302,6 +1324,8 @@ class NetworkTrainer:
del t_enc
text_encoders = []
text_encoder = None
gc.collect()
clean_memory_on_device(accelerator.device)
# For --sample_at_first
optimizer_eval_fn()