Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-06 13:47:06 +00:00)
Compare commits — 50 commits, v0.10.0...6c5c307f94 (author and date columns were empty in the rendered table; only the SHAs survive):

6c5c307f94, fa53f71ec0, 1dae34b0af, dd7a666727, b2c330407b, c018765583,
3cb9025b4b, adf4b7b9c0, b637c31365, 7cbae516c1, 5fb3172baf, 5cdad10de5,
89b246f3f6, 4be0e94fad, 0e168dd1eb, 2723a75f91, 5f793fb0f4, feb38356ea,
cdb49f9fe7, bd19e4c15d, 343c929e39, b2abe873a5, 7c159291e9, 1cd95b2d8b,
1bd0b0faf1, d633b51126, 1a3ec9ea74, e1aedceffa, 2217704ce1, f90fa1a89a,
98a42e4cd6, 892f8be78f, 50694df3cf, 609d1292f6, 48d368fa55, 3265f2edfb,
ef051427df, 573a7fa06c, ae72efb92b, 449e70b4cf, b237b8deb3, 34e7138b6a,
9144463f7b, 1640e53392, e21a7736f8, 8b5ce3e641, da07e4c617, 872124c5e1,
ca6b68ef7d, 9c1168a088
@@ -21,6 +21,9 @@ Each supported model family has a consistent structure:

- **SDXL**: `sdxl_train*.py`, `library/sdxl_*`
- **SD3**: `sd3_train*.py`, `library/sd3_*`
- **FLUX.1**: `flux_train*.py`, `library/flux_*`
- **Lumina Image 2.0**: `lumina_train*.py`, `library/lumina_*`
- **HunyuanImage-2.1**: `hunyuan_image_train*.py`, `library/hunyuan_image_*`
- **Anima-Preview**: `anima_train*.py`, `library/anima_*`

### Key Components
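The naming convention above means a script's model family can be recovered mechanically from its filename; a minimal sketch (the helper and the example file names are illustrative, not part of the repo):

```python
from fnmatch import fnmatch
from typing import Optional

# Script-name patterns per model family, as listed above (subset for brevity).
FAMILY_PATTERNS = {
    "SDXL": "sdxl_train*.py",
    "SD3": "sd3_train*.py",
    "FLUX.1": "flux_train*.py",
    "Anima-Preview": "anima_train*.py",
}


def family_of(script_name: str) -> Optional[str]:
    """Return the model family a training script belongs to, or None."""
    for family, pattern in FAMILY_PATTERNS.items():
        if fnmatch(script_name, pattern):
            return family
    return None


print(family_of("anima_train_network.py"))  # Anima-Preview
print(family_of("train_network.py"))  # None (family-agnostic script)
```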
2  .gitignore  vendored
@@ -11,3 +11,5 @@ GEMINI.md
.claude
.gemini
MagicMock
.codex-tmp
references
55  README-ja.md
@@ -8,25 +8,25 @@
<summary>Click to expand</summary>

- [Introduction](#はじめに)
- [Sponsors](#スポンサー)
  - [Call for Sponsors](#スポンサー募集のお知らせ)
- [Change History](#更新履歴)
- [Supported Models](#サポートモデル)
- [Features](#機能)
- [Documentation](#ドキュメント)
  - [Training Documentation (English and Japanese)](#学習ドキュメント英語および日本語)
  - [Other Documentation](#その他のドキュメント)
  - [Legacy Documentation (Japanese)](#旧ドキュメント日本語)
- [For Developers Using AI Coding Agents](#aiコーディングエージェントを使う開発者の方へ)
- [Windows Installation](#windows環境でのインストール)
  - [Required Programs for Windows](#windowsでの動作に必要なプログラム)
  - [Installation Steps](#インストール手順)
  - [About requirements.txt and PyTorch](#requirementstxtとpytorchについて)
  - [Installing xformers (optional)](#xformersのインストールオプション)
- [Linux/WSL2 Installation](#linuxwsl2環境でのインストール)
  - [Installing DeepSpeed (experimental, Linux or WSL2 only)](#deepspeedのインストール実験的linuxまたはwsl2のみ)
- [Upgrade](#アップグレード)
  - [Upgrading PyTorch](#pytorchのアップグレード)
- [Acknowledgements](#謝意)
- [License](#ライセンス)

@@ -50,10 +50,31 @@ Training of image generation models such as Stable Diffusion, and image generation with models

### Change History

- **Version 0.10.3 (2026-04-02):**
  - Stability when training Anima with fp16 has been further improved. [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) Our sincere thanks to everyone who reported the issue.

- **Version 0.10.2 (2026-03-30):**
  - LECO training for SD/SDXL is now supported. [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) Many thanks to umisetokikaze.
    - See the [documentation](./docs/train_leco.md) for details.
  - `networks/resize_lora.py` now supports `torch.svd_lowrank`, making it significantly faster. [PR #2240](https://github.com/kohya-ss/sd-scripts/pull/2240) and [PR #2296](https://github.com/kohya-ss/sd-scripts/pull/2296) Many thanks to woct0rdho.
    - It is enabled by default. The number of iterations can be set with the `--svd_lowrank_niter` option (default is 2; more iterations improve accuracy). Setting it to 0 reverts to the previous method. See `--help` for details.
  - LoKr/LoHa is now supported for SDXL/Anima. [PR #2275](https://github.com/kohya-ss/sd-scripts/pull/2275)
    - See the [documentation](./docs/loha_lokr.md) for details.
  - Multi-resolution datasets (the same image resized to multiple bucket sizes) are now supported in SD/SDXL training. [PR #2269](https://github.com/kohya-ss/sd-scripts/pull/2269) We also addressed the issue of duplicate images with the same resolution being used in multi-resolution datasets. [PR #2273](https://github.com/kohya-ss/sd-scripts/pull/2273)
    - Thanks to woct0rdho.
    - See the [English documentation](./docs/config_README-en.md#behavior-when-there-are-duplicate-subsets) / [Japanese documentation](./docs/config_README-ja.md#重複したサブセットが存在する時の挙動).
  - Stability when training Anima with fp16 has been improved. [PR #2297](https://github.com/kohya-ss/sd-scripts/pull/2297) However, it still appears unstable in some cases; if you run into problems, please share the details in an Issue.
  - Other minor bug fixes and improvements were made.

- **Version 0.10.1 (2026-02-13):**
  - LoRA training and fine-tuning of the [Anima Preview](https://huggingface.co/circlestone-labs/Anima) model are now supported. [PR #2260](https://github.com/kohya-ss/sd-scripts/pull/2260) and [PR #2261](https://github.com/kohya-ss/sd-scripts/pull/2261)
    - Our sincere thanks to CircleStone Labs for releasing this excellent model, and to duongve13112002 for submitting PR #2260.
    - See the [documentation](./docs/anima_train_network.md) for details.

- **Version 0.10.0 (2026-01-19):**
  - The `sd3` branch has been merged into the `main` branch. From this version, FLUX.1, SD3/SD3.5, and other models are supported in the `main` branch.
    - The documentation is still incomplete; please report anything you notice via Issues.
    - For the time being, the `sd3` branch will be kept in sync with the `dev` branch and maintained as a development branch.

### Supported Models
45  README.md
@@ -7,23 +7,23 @@
<summary>Click to expand</summary>

- [Introduction](#introduction)
- [Supported Models](#supported-models)
- [Features](#features)
- [Sponsors](#sponsors)
  - [Support the Project](#support-the-project)
- [Documentation](#documentation)
  - [Training Documentation (English and Japanese)](#training-documentation-english-and-japanese)
  - [Other Documentation (English and Japanese)](#other-documentation-english-and-japanese)
- [For Developers Using AI Coding Agents](#for-developers-using-ai-coding-agents)
- [Windows Installation](#windows-installation)
  - [Windows Required Dependencies](#windows-required-dependencies)
  - [Installation Steps](#installation-steps)
  - [About requirements.txt and PyTorch](#about-requirementstxt-and-pytorch)
  - [xformers installation (optional)](#xformers-installation-optional)
- [Linux/WSL2 Installation](#linuxwsl2-installation)
  - [DeepSpeed installation (experimental, Linux or WSL2 only)](#deepspeed-installation-experimental-linux-or-wsl2-only)
- [Upgrade](#upgrade)
  - [Upgrade PyTorch](#upgrade-pytorch)
- [Credits](#credits)
- [License](#license)

@@ -47,6 +47,27 @@ If you find this project helpful, please consider supporting its development via

### Change History

- **Version 0.10.3 (2026-04-02):**
  - Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.

- **Version 0.10.2 (2026-03-30):**
  - LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
    - Please refer to the [documentation](./docs/train_leco.md) for details.
  - `networks/resize_lora.py` has been updated to use `torch.svd_lowrank`, resulting in a significant speedup. Many thanks to woct0rdho for [PR #2240](https://github.com/kohya-ss/sd-scripts/pull/2240) and [PR #2296](https://github.com/kohya-ss/sd-scripts/pull/2296).
    - It is enabled by default. You can specify the number of iterations with the `--svd_lowrank_niter` option (default is 2; more iterations improve accuracy). Setting it to 0 reverts to the previous method. Please check `--help` for details.
  - LoKr/LoHa is now supported for SDXL/Anima. See [PR #2275](https://github.com/kohya-ss/sd-scripts/pull/2275) for details.
    - Please refer to the [documentation](./docs/loha_lokr.md) for details.
  - Multi-resolution datasets (using the same image resized to multiple bucket sizes) are now supported in SD/SDXL training. We also addressed the issue of duplicate images with the same resolution being used in multi-resolution datasets. See [PR #2269](https://github.com/kohya-ss/sd-scripts/pull/2269) and [PR #2273](https://github.com/kohya-ss/sd-scripts/pull/2273) for details.
    - Thanks to woct0rdho for the contribution.
    - Please refer to the [English documentation](./docs/config_README-en.md#behavior-when-there-are-duplicate-subsets) / [Japanese documentation](./docs/config_README-ja.md#重複したサブセットが存在する時の挙動) for details.
  - Stability when training with fp16 on Anima has been improved. See [PR #2297](https://github.com/kohya-ss/sd-scripts/pull/2297) for details. However, it still seems to be unstable in some cases. If you encounter any issues, please let us know the details via Issues.
  - Other minor bug fixes and improvements were made.

- **Version 0.10.1 (2026-02-13):**
  - [Anima Preview](https://huggingface.co/circlestone-labs/Anima) model LoRA training and fine-tuning are now supported. See [PR #2260](https://github.com/kohya-ss/sd-scripts/pull/2260) and [PR #2261](https://github.com/kohya-ss/sd-scripts/pull/2261).
    - Many thanks to CircleStone Labs for releasing this amazing model, and to duongve13112002 for submitting PR #2260.
    - For details, please refer to the [documentation](./docs/anima_train_network.md).

- **Version 0.10.0 (2026-01-19):**
  - `sd3` branch is merged to `main` branch. From this version, FLUX.1 and SD3/SD3.5 etc. are supported in the `main` branch.
    - There are still some missing parts in the documentation, so please let us know if you find any issues via Issues etc.
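The `torch.svd_lowrank`-based resizing noted in the changelog above can be illustrated with a small, self-contained sketch (the matrix and rank here are arbitrary stand-ins, not the script's actual tensors):

```python
import torch

torch.manual_seed(0)

# Stand-in for a LoRA weight delta: a 256x256 matrix of rank <= 64.
W = torch.randn(256, 64) @ torch.randn(64, 256)

# Randomized low-rank SVD; `niter` plays the role the --svd_lowrank_niter
# option describes: more subspace iterations, better accuracy.
rank = 32
U, S, V = torch.svd_lowrank(W, q=rank, niter=2)
W_approx = U @ torch.diag(S) @ V.T  # rank-32 reconstruction

rel_err = ((W - W_approx).norm() / W.norm()).item()
print(f"relative error at rank {rank}: {rel_err:.3f}")
```

Compared with a full `torch.linalg.svd`, the randomized variant only computes the top `q` singular directions, which is where the reported speedup comes from.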
@@ -32,6 +32,7 @@ hime="hime"
OT="OT"
byt="byt"
tak="tak"
temperal="temperal"

[files]
extend-exclude = ["_typos.toml", "venv"]
extend-exclude = ["_typos.toml", "venv", "configs"]
1082  anima_minimal_inference.py  Normal file
File diff suppressed because it is too large. Load Diff

759  anima_train.py  Normal file
@@ -0,0 +1,759 @@
# Anima full finetune training script

import argparse
from concurrent.futures import ThreadPoolExecutor
import copy
import gc
import math
import os
from multiprocessing import Value
from typing import List
import toml

from tqdm import tqdm

import torch
from library import flux_train_utils, qwen_image_autoencoder_kl
from library.device_utils import init_ipex, clean_memory_on_device
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler

init_ipex()

from accelerate.utils import set_seed
from library import deepspeed_utils, anima_models, anima_train_utils, anima_utils, strategy_base, strategy_anima, sai_model_spec

import library.train_util as train_util

from library.utils import setup_logging, add_logging_arguments

setup_logging()
import logging

logger = logging.getLogger(__name__)

import library.config_util as config_util

from library.config_util import (
    ConfigSanitizer,
    BlueprintGenerator,
)
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments

def train(args):
    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, True)
    deepspeed_utils.prepare_deepspeed_args(args)
    setup_logging(args, reset=True)

    # backward compatibility
    if not args.skip_cache_check:
        args.skip_cache_check = args.skip_latents_validity_check

    if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
        logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
        args.cache_text_encoder_outputs = True

    if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
        logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
        args.gradient_checkpointing = True

    if args.unsloth_offload_checkpointing:
        if not args.gradient_checkpointing:
            logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
            args.gradient_checkpointing = True
        assert not args.cpu_offload_checkpointing, "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"

    assert (
        args.blocks_to_swap is None or args.blocks_to_swap == 0
    ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"

    assert (
        args.blocks_to_swap is None or args.blocks_to_swap == 0
    ) or not args.unsloth_offload_checkpointing, "blocks_to_swap is not supported with unsloth_offload_checkpointing"

    cache_latents = args.cache_latents
    use_dreambooth_method = args.in_json is None

    if args.seed is not None:
        set_seed(args.seed)

    # prepare caching strategy: must be set before preparing dataset
    if args.cache_latents:
        latents_caching_strategy = strategy_anima.AnimaLatentsCachingStrategy(
            args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
        )
        strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

    # prepare dataset
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
        if args.dataset_config is not None:
            logger.info(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
                logger.warning("ignore following options because config file is found: {0}".format(", ".join(ignored)))
        else:
            if use_dreambooth_method:
                logger.info("Using DreamBooth method.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
                                args.train_data_dir, args.reg_data_dir
                            )
                        }
                    ]
                }
            else:
                logger.info("Training with captions.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": [
                                {
                                    "image_dir": args.train_data_dir,
                                    "metadata_file": args.in_json,
                                }
                            ]
                        }
                    ]
                }

        blueprint = blueprint_generator.generate(user_config, args)
        train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args)
        val_dataset_group = None

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

    train_dataset_group.verify_bucket_reso_steps(16)  # Qwen-Image VAE spatial downscale = 8 * patch size = 2

    if args.debug_dataset:
        if args.cache_text_encoder_outputs:
            strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
                strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
                    args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
                )
            )
        train_dataset_group.set_current_strategies()
        train_util.debug_dataset(train_dataset_group, True)
        return
    if len(train_dataset_group) == 0:
        logger.error("No data found. Please verify the metadata file and train_data_dir option.")
        return

    if cache_latents:
        assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used"

    if args.cache_text_encoder_outputs:
        assert train_dataset_group.is_text_encoder_output_cacheable(
            cache_supports_dropout=True
        ), "when caching text encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
    # prepare accelerator
    logger.info("prepare accelerator")
    accelerator = train_util.prepare_accelerator(args)

    # mixed precision dtype
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    # Load tokenizers and set strategies
    logger.info("Loading tokenizers...")
    qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
    t5_tokenizer = anima_utils.load_t5_tokenizer(args.t5_tokenizer_path)

    # Set tokenize strategy
    tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
        qwen3_tokenizer=qwen3_tokenizer,
        t5_tokenizer=t5_tokenizer,
        qwen3_max_length=args.qwen3_max_token_length,
        t5_max_length=args.t5_max_token_length,
    )
    strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)

    text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
    strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)

    # Prepare text encoder (always frozen for Anima)
    qwen3_text_encoder.to(weight_dtype)
    qwen3_text_encoder.requires_grad_(False)

    # Cache text encoder outputs
    sample_prompts_te_outputs = None
    if args.cache_text_encoder_outputs:
        qwen3_text_encoder.to(accelerator.device)
        qwen3_text_encoder.eval()

        text_encoder_caching_strategy = strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
            args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, is_partial=False
        )
        strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)

        with accelerator.autocast():
            train_dataset_group.new_cache_text_encoder_outputs([qwen3_text_encoder], accelerator)

        # cache sample prompt embeddings
        if args.sample_prompts is not None:
            logger.info(f"Cache Text Encoder outputs for sample prompts: {args.sample_prompts}")
            prompts = train_util.load_prompts(args.sample_prompts)
            sample_prompts_te_outputs = {}
            with accelerator.autocast(), torch.no_grad():
                for prompt_dict in prompts:
                    for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
                        if p not in sample_prompts_te_outputs:
                            logger.info(f"  cache TE outputs for: {p}")
                            tokens_and_masks = tokenize_strategy.tokenize(p)
                            sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
                                tokenize_strategy, [qwen3_text_encoder], tokens_and_masks
                            )

        accelerator.wait_for_everyone()

        # free text encoder memory
        qwen3_text_encoder = None
        gc.collect()  # Force garbage collection to free memory
        clean_memory_on_device(accelerator.device)

    # Load VAE and cache latents
    logger.info("Loading Anima VAE...")
    vae = qwen_image_autoencoder_kl.load_vae(
        args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
    )

    if cache_latents:
        vae.to(accelerator.device, dtype=weight_dtype)
        vae.requires_grad_(False)
        vae.eval()

        train_dataset_group.new_cache_latents(vae, accelerator)

        vae.to("cpu")
        clean_memory_on_device(accelerator.device)
        accelerator.wait_for_everyone()

    # Load DiT (MiniTrainDIT + optional LLM Adapter)
    logger.info("Loading Anima DiT...")
    dit = anima_utils.load_anima_model(
        "cpu", args.pretrained_model_name_or_path, args.attn_mode, args.split_attn, "cpu", dit_weight_dtype=None
    )

    if args.gradient_checkpointing:
        dit.enable_gradient_checkpointing(
            cpu_offload=args.cpu_offload_checkpointing,
            unsloth_offload=args.unsloth_offload_checkpointing,
        )

    train_dit = args.learning_rate != 0
    dit.requires_grad_(train_dit)
    if not train_dit:
        dit.to(accelerator.device, dtype=weight_dtype)

    # Block swap
    is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
    if is_swapping_blocks:
        logger.info(f"Enable block swap: blocks_to_swap={args.blocks_to_swap}")
        dit.enable_block_swap(args.blocks_to_swap, accelerator.device)

    if not cache_latents:
        vae.requires_grad_(False)
        vae.eval()
        vae.to(accelerator.device, dtype=weight_dtype)

    # Setup optimizer with parameter groups
    if train_dit:
        param_groups = anima_train_utils.get_anima_param_groups(
            dit,
            base_lr=args.learning_rate,
            self_attn_lr=args.self_attn_lr,
            cross_attn_lr=args.cross_attn_lr,
            mlp_lr=args.mlp_lr,
            mod_lr=args.mod_lr,
            llm_adapter_lr=args.llm_adapter_lr,
        )
    else:
        param_groups = []
    training_models = []
    if train_dit:
        training_models.append(dit)

    # calculate trainable parameters
    n_params = 0
    for group in param_groups:
        for p in group["params"]:
            n_params += p.numel()

    accelerator.print(f"train dit: {train_dit}")
    accelerator.print(f"number of training models: {len(training_models)}")
    accelerator.print(f"number of trainable parameters: {n_params:,}")

    # prepare optimizer
    accelerator.print("prepare optimizer, data loader etc.")

    if args.fused_backward_pass:
        # Pass per-component param_groups directly to preserve per-component LRs
        _, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups)
        optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
    else:
        _, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups)
        optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)

    # prepare dataloader
    train_dataset_group.set_current_strategies()

    n_workers = min(args.max_data_loader_n_workers, os.cpu_count())
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collator,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # calculate training steps
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * math.ceil(
            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
        )
        accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs: {args.max_train_steps}")

    train_dataset_group.set_max_train_steps(args.max_train_steps)

    # lr scheduler
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    # full fp16/bf16 training
    dit_weight_dtype = weight_dtype
    if args.full_fp16:
        assert args.mixed_precision == "fp16", "full_fp16 requires mixed_precision='fp16'"
        accelerator.print("enable full fp16 training.")
    elif args.full_bf16:
        assert args.mixed_precision == "bf16", "full_bf16 requires mixed_precision='bf16'"
        accelerator.print("enable full bf16 training.")
    else:
        dit_weight_dtype = torch.float32  # If neither full_fp16 nor full_bf16, the model weights should be in float32
    dit.to(dit_weight_dtype)  # convert dit to target weight dtype

    # move text encoder to GPU if not cached
    if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
        qwen3_text_encoder.to(accelerator.device)

    clean_memory_on_device(accelerator.device)

    # Prepare with accelerator
    # Temporarily move non-training models off GPU to reduce memory during DDP init
    # if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
    #     qwen3_text_encoder.to("cpu")
    # if not cache_latents and vae is not None:
    #     vae.to("cpu")
    # clean_memory_on_device(accelerator.device)

    if args.deepspeed:
        ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=dit)
        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            ds_model, optimizer, train_dataloader, lr_scheduler
        )
        training_models = [ds_model]
    else:
        if train_dit:
            dit = accelerator.prepare(dit, device_placement=[not is_swapping_blocks])
            if is_swapping_blocks:
                accelerator.unwrap_model(dit).move_to_device_except_swap_blocks(accelerator.device)
        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

    # Move non-training models back to GPU
    if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
        qwen3_text_encoder.to(accelerator.device)
    if not cache_latents and vae is not None:
        vae.to(accelerator.device, dtype=weight_dtype)

    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)

    # resume
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

    if args.fused_backward_pass:
        # use fused optimizer for backward pass: other optimizers will be supported in the future
        import library.adafactor_fused

        library.adafactor_fused.patch_adafactor_fused(optimizer)

        for param_group in optimizer.param_groups:
            for parameter in param_group["params"]:
                if parameter.requires_grad:

                    def create_grad_hook(p_group):
                        def grad_hook(tensor: torch.Tensor):
                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                                accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
                            optimizer.step_param(tensor, p_group)
                            tensor.grad = None

                        return grad_hook

                    parameter.register_post_accumulate_grad_hook(create_grad_hook(param_group))
    # Training loop
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

    accelerator.print("running training / 学習開始")
    accelerator.print(f"  num examples / サンプル数: {train_dataset_group.num_train_images}")
    accelerator.print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
    accelerator.print(f"  num epochs / epoch数: {num_train_epochs}")
    accelerator.print(
        f"  batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
    )
    accelerator.print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
    accelerator.print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
    # Copy for noise and timestep generation, because noise_scheduler may be changed during training in future
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)

    if accelerator.is_main_process:
        init_kwargs = {}
        if args.wandb_run_name:
            init_kwargs["wandb"] = {"name": args.wandb_run_name}
        if args.log_tracker_config is not None:
            init_kwargs = toml.load(args.log_tracker_config)
        accelerator.init_trackers(
            "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
            config=train_util.get_sanitized_config_or_none(args),
            init_kwargs=init_kwargs,
        )

    if "wandb" in [tracker.name for tracker in accelerator.trackers]:
        import wandb

        wandb.define_metric("epoch")
        wandb.define_metric("loss/epoch", step_metric="epoch")

    if is_swapping_blocks:
        accelerator.unwrap_model(dit).prepare_block_swap_before_forward()

    # For --sample_at_first
    optimizer_eval_fn()
    anima_train_utils.sample_images(
        accelerator,
        args,
        0,
        global_step,
        dit,
        vae,
        qwen3_text_encoder,
        tokenize_strategy,
        text_encoding_strategy,
        sample_prompts_te_outputs,
    )
    optimizer_train_fn()
    if len(accelerator.trackers) > 0:
        accelerator.log({}, step=0)

    # Show model info
    unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None
    if unwrapped_dit is not None:
        logger.info(f"dit device: {unwrapped_dit.device}, dtype: {unwrapped_dit.dtype}")
    if qwen3_text_encoder is not None:
        logger.info(f"qwen3 device: {qwen3_text_encoder.device}")
    if vae is not None:
        logger.info(f"vae device: {vae.device}")

    loss_recorder = train_util.LossRecorder()
    epoch = 0
    for epoch in range(num_train_epochs):
        accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
        current_epoch.value = epoch + 1

        for m in training_models:
            m.train()

        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step

            with accelerator.accumulate(*training_models):
                # Get latents
                if "latents" in batch and batch["latents"] is not None:
                    latents = batch["latents"].to(accelerator.device, dtype=dit_weight_dtype)
                    if latents.ndim == 5:  # Fallback for 5D latents (old cache)
                        latents = latents.squeeze(2)  # (B, C, 1, H, W) -> (B, C, H, W)
                else:
                    with torch.no_grad():
                        # images are already [-1, 1] from IMAGE_TRANSFORMS, add temporal dim
                        images = batch["images"].to(accelerator.device, dtype=weight_dtype)
                        latents = vae.encode_pixels_to_latents(images).to(accelerator.device, dtype=dit_weight_dtype)

                    if torch.any(torch.isnan(latents)):
                        accelerator.print("NaN found in latents, replacing with zeros")
                        latents = torch.nan_to_num(latents, 0, out=latents)

                # Get text encoder outputs
                text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
                if text_encoder_outputs_list is not None:
                    # Cached outputs
                    caption_dropout_rates = text_encoder_outputs_list[-1]
                    text_encoder_outputs_list = text_encoder_outputs_list[:-1]

                    # Apply caption dropout to cached outputs
                    text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(
                        *text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
                    )
                    prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list
                else:
                    # Encode on-the-fly
                    input_ids_list = batch["input_ids_list"]
                    with torch.no_grad():
                        prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
                            tokenize_strategy, [qwen3_text_encoder], input_ids_list
                        )

                # Move to device
                prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit_weight_dtype)
                attn_mask = attn_mask.to(accelerator.device)
                t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
                t5_attn_mask = t5_attn_mask.to(accelerator.device)
|
||||
|
||||
# Noise and timesteps
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
# Get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler_copy, latents, noise, accelerator.device, dit_weight_dtype
|
||||
)
|
||||
timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
|
||||
|
||||
# NaN checks
|
||||
if torch.any(torch.isnan(noisy_model_input)):
|
||||
accelerator.print("NaN found in noisy_model_input, replacing with zeros")
|
||||
noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input)
|
||||
|
||||
# Create padding mask
|
||||
# padding_mask: (B, 1, H_latent, W_latent)
|
||||
bs = latents.shape[0]
|
||||
h_latent = latents.shape[-2]
|
||||
w_latent = latents.shape[-1]
|
||||
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=dit_weight_dtype, device=accelerator.device)
|
||||
|
||||
# DiT forward (LLM adapter runs inside forward for DDP gradient sync)
|
||||
noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, (B, C, 1, H, W)
|
||||
with accelerator.autocast():
|
||||
model_pred = dit(
|
||||
noisy_model_input,
|
||||
timesteps,
|
||||
prompt_embeds,
|
||||
padding_mask=padding_mask,
|
||||
source_attention_mask=attn_mask,
|
||||
t5_input_ids=t5_input_ids,
|
||||
t5_attn_mask=t5_attn_mask,
|
||||
)
|
||||
model_pred = model_pred.squeeze(2) # 5D to 4D, (B, C, H, W)
|
||||
|
||||
# Compute loss (rectified flow: target = noise - latents)
|
||||
target = noise - latents
|
||||
|
||||
# Weighting
|
||||
weighting = anima_train_utils.compute_loss_weighting_for_anima(
|
||||
weighting_scheme=args.weighting_scheme, sigmas=sigmas
|
||||
)
|
||||
|
||||
# Loss
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, None)
|
||||
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3]) # (B, C, H, W) -> (B,)
|
||||
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
|
||||
loss_weights = batch["loss_weights"]
|
||||
loss = loss * loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
if not args.fused_backward_pass:
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = []
|
||||
for m in training_models:
|
||||
params_to_clip.extend(m.parameters())
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
else:
|
||||
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
|
||||
lr_scheduler.step()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
optimizer_eval_fn()
|
||||
anima_train_utils.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
None,
|
||||
global_step,
|
||||
dit,
|
||||
vae,
|
||||
qwen3_text_encoder,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
sample_prompts_te_outputs,
|
||||
)
|
||||
|
||||
# Save at specific steps
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
anima_train_utils.save_anima_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
False,
|
||||
accelerator,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(dit) if train_dit else None,
|
||||
)
|
||||
optimizer_train_fn()
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if len(accelerator.trackers) > 0:
|
||||
logs = {"loss": current_loss}
|
||||
train_util.append_lr_to_logs_with_names(
|
||||
logs,
|
||||
lr_scheduler,
|
||||
args.optimizer_type,
|
||||
["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else [],
|
||||
)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if len(accelerator.trackers) > 0:
|
||||
logs = {"loss/epoch": loss_recorder.moving_average, "epoch": epoch + 1}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
optimizer_eval_fn()
|
||||
if args.save_every_n_epochs is not None:
|
||||
if accelerator.is_main_process:
|
||||
anima_train_utils.save_anima_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
True,
|
||||
accelerator,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(dit) if train_dit else None,
|
||||
)
|
||||
|
||||
anima_train_utils.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
epoch + 1,
|
||||
global_step,
|
||||
dit,
|
||||
vae,
|
||||
qwen3_text_encoder,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
sample_prompts_te_outputs,
|
||||
)
|
||||
|
||||
# End training
|
||||
is_main_process = accelerator.is_main_process
|
||||
dit = accelerator.unwrap_model(dit)
|
||||
|
||||
accelerator.end_training()
|
||||
optimizer_eval_fn()
|
||||
|
||||
if args.save_state or args.save_state_on_train_end:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator
|
||||
|
||||
if is_main_process and train_dit:
|
||||
anima_train_utils.save_anima_model_on_train_end(
|
||||
args,
|
||||
save_dtype,
|
||||
epoch,
|
||||
global_step,
|
||||
dit,
|
||||
)
|
||||
logger.info("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_masked_loss_arguments(parser)
|
||||
deepspeed_utils.add_deepspeed_arguments(parser)
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
add_custom_train_arguments(parser)
|
||||
train_util.add_dit_training_arguments(parser)
|
||||
anima_train_utils.add_anima_training_arguments(parser)
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--cpu_offload_checkpointing",
|
||||
action="store_true",
|
||||
help="offload gradient checkpointing to CPU (reduces VRAM at cost of speed)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--unsloth_offload_checkpointing",
|
||||
action="store_true",
|
||||
help="offload activations to CPU RAM using async non-blocking transfers (faster than --cpu_offload_checkpointing). "
|
||||
"Cannot be used with --cpu_offload_checkpointing or --blocks_to_swap.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_latents_validity_check",
|
||||
action="store_true",
|
||||
help="[Deprecated] use 'skip_cache_check' instead",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
train_util.verify_command_line_training_args(args)
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
if args.attn_mode == "sdpa":
|
||||
args.attn_mode = "torch" # backward compatibility
|
||||
|
||||
train(args)
|
||||
anima_train_network.py (new file, 451 lines)
@@ -0,0 +1,451 @@
# Anima LoRA training script

import argparse
from typing import Any, Optional, Union

import torch
import torch.nn as nn
from accelerate import Accelerator
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from library import (
    anima_models,
    anima_train_utils,
    anima_utils,
    flux_train_utils,
    qwen_image_autoencoder_kl,
    sd3_train_utils,
    strategy_anima,
    strategy_base,
    train_util,
)
import train_network
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


class AnimaNetworkTrainer(train_network.NetworkTrainer):
    def __init__(self):
        super().__init__()
        self.sample_prompts_te_outputs = None

    def assert_extra_args(
        self,
        args,
        train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
        val_dataset_group: Optional[train_util.DatasetGroup],
    ):
        if args.fp8_base or args.fp8_base_unet:
            logger.warning("fp8_base and fp8_base_unet are not supported. / fp8_baseとfp8_base_unetはサポートされていません。")
            args.fp8_base = False
            args.fp8_base_unet = False
        args.fp8_scaled = False  # Anima DiT does not support fp8_scaled

        if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
            logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
            args.cache_text_encoder_outputs = True

        if args.cache_text_encoder_outputs:
            assert train_dataset_group.is_text_encoder_output_cacheable(
                cache_supports_dropout=True
            ), "when caching Text Encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"

        assert (
            args.network_train_unet_only or not args.cache_text_encoder_outputs
        ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"

        assert (
            args.blocks_to_swap is None or args.blocks_to_swap == 0
        ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"

        if args.unsloth_offload_checkpointing:
            if not args.gradient_checkpointing:
                logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
                args.gradient_checkpointing = True
            assert (
                not args.cpu_offload_checkpointing
            ), "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
            assert (
                args.blocks_to_swap is None or args.blocks_to_swap == 0
            ), "blocks_to_swap is not supported with unsloth_offload_checkpointing"

        train_dataset_group.verify_bucket_reso_steps(16)  # WanVAE spatial downscale = 8 and patch size = 2
        if val_dataset_group is not None:
            val_dataset_group.verify_bucket_reso_steps(16)

    def load_target_model(self, args, weight_dtype, accelerator):
        self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0

        # Load Qwen3 text encoder (tokenizers already loaded in get_tokenize_strategy)
        logger.info("Loading Qwen3 text encoder...")
        qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
        qwen3_text_encoder.eval()

        # Load VAE
        logger.info("Loading Anima VAE...")
        vae = qwen_image_autoencoder_kl.load_vae(
            args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
        )
        vae.to(weight_dtype)
        vae.eval()

        # Return format: (model_type, text_encoders, vae, unet)
        return "anima", [qwen3_text_encoder], vae, None  # unet loaded lazily

    def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, list[nn.Module]]:
        loading_dtype = None if args.fp8_scaled else weight_dtype
        loading_device = "cpu" if self.is_swapping_blocks else accelerator.device

        attn_mode = "torch"
        if args.xformers:
            attn_mode = "xformers"
        if args.attn_mode is not None:
            attn_mode = args.attn_mode

        # Load DiT
        logger.info(f"Loading Anima DiT model with attn_mode={attn_mode}, split_attn: {args.split_attn}...")
        model = anima_utils.load_anima_model(
            accelerator.device,
            args.pretrained_model_name_or_path,
            attn_mode,
            args.split_attn,
            loading_device,
            loading_dtype,
            args.fp8_scaled,
        )

        # Store unsloth preference so that when the base NetworkTrainer calls
        # dit.enable_gradient_checkpointing(cpu_offload=...), we can override to use unsloth.
        # The base trainer only passes cpu_offload, so we store the flag on the model.
        self._use_unsloth_offload_checkpointing = args.unsloth_offload_checkpointing

        # Block swap
        self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
        if self.is_swapping_blocks:
            logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
            model.enable_block_swap(args.blocks_to_swap, accelerator.device)

        return model, text_encoders

    def get_tokenize_strategy(self, args):
        # Load tokenizers from paths (called before load_target_model, so self.qwen3_tokenizer isn't set yet)
        tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
            qwen3_path=args.qwen3,
            t5_tokenizer_path=args.t5_tokenizer_path,
            qwen3_max_length=args.qwen3_max_token_length,
            t5_max_length=args.t5_max_token_length,
        )
        return tokenize_strategy

    def get_tokenizers(self, tokenize_strategy: strategy_anima.AnimaTokenizeStrategy):
        return [tokenize_strategy.qwen3_tokenizer]

    def get_latents_caching_strategy(self, args):
        return strategy_anima.AnimaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check)

    def get_text_encoding_strategy(self, args):
        return strategy_anima.AnimaTextEncodingStrategy()

    def post_process_network(self, args, accelerator, network, text_encoders, unet):
        pass

    def get_models_for_text_encoding(self, args, accelerator, text_encoders):
        if args.cache_text_encoder_outputs:
            return None  # no text encoders needed for encoding
        return text_encoders

    def get_text_encoder_outputs_caching_strategy(self, args):
        if args.cache_text_encoder_outputs:
            return strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
                args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
            )
        return None

    def cache_text_encoder_outputs_if_needed(
        self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
    ):
        if args.cache_text_encoder_outputs:
            if not args.lowram:
                # We cannot move DiT to CPU because of block swap, so only move VAE
                logger.info("move vae to cpu to save memory")
                org_vae_device = vae.device
                vae.to("cpu")
                clean_memory_on_device(accelerator.device)

            logger.info("move text encoder to gpu")
            text_encoders[0].to(accelerator.device)

            with accelerator.autocast():
                dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)

            # cache sample prompts
            if args.sample_prompts is not None:
                logger.info(f"cache Text Encoder outputs for sample prompts: {args.sample_prompts}")

                tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
                text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()

                prompts = train_util.load_prompts(args.sample_prompts)
                sample_prompts_te_outputs = {}
                with accelerator.autocast(), torch.no_grad():
                    for prompt_dict in prompts:
                        for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
                            if p not in sample_prompts_te_outputs:
                                logger.info(f"  cache TE outputs for: {p}")
                                tokens_and_masks = tokenize_strategy.tokenize(p)
                                sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
                                    tokenize_strategy, text_encoders, tokens_and_masks
                                )
                self.sample_prompts_te_outputs = sample_prompts_te_outputs

            accelerator.wait_for_everyone()

            # move text encoder back to cpu
            logger.info("move text encoder back to cpu")
            text_encoders[0].to("cpu")

            if not args.lowram:
                logger.info("move vae back to original device")
                vae.to(org_vae_device)

            clean_memory_on_device(accelerator.device)
        else:
            # move text encoder to device for encoding during training/validation
            text_encoders[0].to(accelerator.device)

    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
        text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]  # compatibility
        te = self.get_models_for_text_encoding(args, accelerator, text_encoders)
        qwen3_te = te[0] if te is not None else None

        text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
        tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
        anima_train_utils.sample_images(
            accelerator,
            args,
            epoch,
            global_step,
            unet,
            vae,
            qwen3_te,
            tokenize_strategy,
            text_encoding_strategy,
            self.sample_prompts_te_outputs,
        )

    def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
        noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
        return noise_scheduler

    def encode_images_to_latents(self, args, vae, images):
        vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage
        return vae.encode_pixels_to_latents(images)  # Keep 4D for input/output

    def shift_scale_latents(self, args, latents):
        # Latents already normalized by vae.encode with scale
        return latents

    def get_noise_pred_and_target(
        self,
        args,
        accelerator,
        noise_scheduler,
        latents,
        batch,
        text_encoder_conds,
        unet,
        network,
        weight_dtype,
        train_unet,
        is_train=True,
    ):
        anima: anima_models.Anima = unet

        # Sample noise
        if latents.ndim == 5:  # Fallback for 5D latents (old cache)
            latents = latents.squeeze(2)  # [B, C, 1, H, W] -> [B, C, H, W]
        noise = torch.randn_like(latents)

        # Get noisy model input and timesteps
        noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
            args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
        )
        timesteps = timesteps / 1000.0  # scale to [0, 1] range. timesteps is float32

        # Gradient checkpointing support
        if args.gradient_checkpointing:
            noisy_model_input.requires_grad_(True)
            for t in text_encoder_conds:
                if t is not None and t.dtype.is_floating_point:
                    t.requires_grad_(True)

        # Unpack text encoder conditions
        prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds[
            :4
        ]  # ignore caption_dropout_rate which is not needed for training step

        # Move to device
        prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype)
        attn_mask = attn_mask.to(accelerator.device)
        t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
        t5_attn_mask = t5_attn_mask.to(accelerator.device)

        # Create padding mask
        bs = latents.shape[0]
        h_latent = latents.shape[-2]
        w_latent = latents.shape[-1]
        padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device)

        # Call model
        noisy_model_input = noisy_model_input.unsqueeze(2)  # 4D to 5D, [B, C, H, W] -> [B, C, 1, H, W]
        with torch.set_grad_enabled(is_train), accelerator.autocast():
            model_pred = anima(
                noisy_model_input,
                timesteps,
                prompt_embeds,
                padding_mask=padding_mask,
                target_input_ids=t5_input_ids,
                target_attention_mask=t5_attn_mask,
                source_attention_mask=attn_mask,
            )
        model_pred = model_pred.squeeze(2)  # 5D to 4D, [B, C, 1, H, W] -> [B, C, H, W]

        # Rectified flow target: noise - latents
        target = noise - latents

        # Loss weighting
        weighting = anima_train_utils.compute_loss_weighting_for_anima(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

        return model_pred, target, timesteps, weighting

    def process_batch(
        self,
        batch,
        text_encoders,
        unet,
        network,
        vae,
        noise_scheduler,
        vae_dtype,
        weight_dtype,
        accelerator,
        args,
        text_encoding_strategy,
        tokenize_strategy,
        is_train=True,
        train_text_encoder=True,
        train_unet=True,
    ) -> torch.Tensor:
        """Override base process_batch for caption dropout with cached text encoder outputs."""

        # Text encoder conditions
        text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
        anima_text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = text_encoding_strategy
        if text_encoder_outputs_list is not None:
            caption_dropout_rates = text_encoder_outputs_list[-1]
            text_encoder_outputs_list = text_encoder_outputs_list[:-1]

            # Apply caption dropout to cached outputs
            text_encoder_outputs_list = anima_text_encoding_strategy.drop_cached_text_encoder_outputs(
                *text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
            )
            # Add the caption dropout rates back to the list for validation dataset (which is re-used batch items)
            batch["text_encoder_outputs_list"] = text_encoder_outputs_list + [caption_dropout_rates]

        return super().process_batch(
            batch,
            text_encoders,
            unet,
            network,
            vae,
            noise_scheduler,
            vae_dtype,
            weight_dtype,
            accelerator,
            args,
            text_encoding_strategy,
            tokenize_strategy,
            is_train,
            train_text_encoder,
            train_unet,
        )

    def post_process_loss(self, loss, args, timesteps, noise_scheduler):
        return loss

    def get_sai_model_spec(self, args):
        return train_util.get_sai_model_spec_dataclass(None, args, False, True, False, anima="preview").to_metadata_dict()

    def update_metadata(self, metadata, args):
        metadata["ss_weighting_scheme"] = args.weighting_scheme
        metadata["ss_logit_mean"] = args.logit_mean
        metadata["ss_logit_std"] = args.logit_std
        metadata["ss_mode_scale"] = args.mode_scale
        metadata["ss_timestep_sampling"] = args.timestep_sampling
        metadata["ss_sigmoid_scale"] = args.sigmoid_scale
        metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift

    def is_text_encoder_not_needed_for_training(self, args):
        return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)

    def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
        # Set first parameter's requires_grad to True to workaround Accelerate gradient checkpointing bug
        first_param = next(text_encoder.parameters())
        first_param.requires_grad_(True)

    def prepare_unet_with_accelerator(
        self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
    ) -> torch.nn.Module:
        # The base NetworkTrainer only calls enable_gradient_checkpointing(cpu_offload=True/False),
        # so we re-apply with unsloth_offload if needed (after base has already enabled it).
        if self._use_unsloth_offload_checkpointing and args.gradient_checkpointing:
            unet.enable_gradient_checkpointing(unsloth_offload=True)

        if not self.is_swapping_blocks:
            return super().prepare_unet_with_accelerator(args, accelerator, unet)

        model = unet
        model = accelerator.prepare(model, device_placement=[not self.is_swapping_blocks])
        accelerator.unwrap_model(model).move_to_device_except_swap_blocks(accelerator.device)
        accelerator.unwrap_model(model).prepare_block_swap_before_forward()

        return model

    def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
        if self.is_swapping_blocks:
            # prepare for next forward: because backward pass is not called, we need to prepare it here
            accelerator.unwrap_model(unet).prepare_block_swap_before_forward()


def setup_parser() -> argparse.ArgumentParser:
    parser = train_network.setup_parser()
    train_util.add_dit_training_arguments(parser)
    anima_train_utils.add_anima_training_arguments(parser)
    # parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
    parser.add_argument(
        "--unsloth_offload_checkpointing",
        action="store_true",
        help="offload activations to CPU RAM using async non-blocking transfers (faster than --cpu_offload_checkpointing). "
        "Cannot be used with --cpu_offload_checkpointing or --blocks_to_swap.",
    )
    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    if args.attn_mode == "sdpa":
        args.attn_mode = "torch"  # backward compatibility

    trainer = AnimaNetworkTrainer()
    trainer.train(args)
configs/qwen3_06b/config.json (new file, 30 lines)
@@ -0,0 +1,30 @@
{
  "architectures": [
    "Qwen3ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151643,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 32768,
  "max_window_layers": 28,
  "model_type": "qwen3",
  "num_attention_heads": 16,
  "num_hidden_layers": 28,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000,
  "sliding_window": null,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.51.0",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}
configs/qwen3_06b/merges.txt (new file, 151388 lines)
File diff suppressed because it is too large
configs/qwen3_06b/tokenizer.json (new file, 303282 lines)
File diff suppressed because it is too large
configs/qwen3_06b/tokenizer_config.json (new file, 239 lines)
@@ -0,0 +1,239 @@
{
  "add_bos_token": false,
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "151643": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151644": {
      "content": "<|im_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151645": {
      "content": "<|im_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151646": {
      "content": "<|object_ref_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151647": {
      "content": "<|object_ref_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151648": {
      "content": "<|box_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151649": {
      "content": "<|box_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151650": {
      "content": "<|quad_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151651": {
      "content": "<|quad_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151652": {
      "content": "<|vision_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151653": {
      "content": "<|vision_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151654": {
      "content": "<|vision_pad|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151655": {
      "content": "<|image_pad|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151656": {
      "content": "<|video_pad|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151657": {
      "content": "<tool_call>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151658": {
      "content": "</tool_call>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151659": {
      "content": "<|fim_prefix|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151660": {
      "content": "<|fim_middle|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151661": {
      "content": "<|fim_suffix|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151662": {
      "content": "<|fim_pad|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151663": {
      "content": "<|repo_name|>",
      "lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151664": {
|
||||
"content": "<|file_sep|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151665": {
|
||||
"content": "<tool_response>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151666": {
|
||||
"content": "</tool_response>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151667": {
|
||||
"content": "<think>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151668": {
|
||||
"content": "</think>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [
|
||||
"<|im_start|>",
|
||||
"<|im_end|>",
|
||||
"<|object_ref_start|>",
|
||||
"<|object_ref_end|>",
|
||||
"<|box_start|>",
|
||||
"<|box_end|>",
|
||||
"<|quad_start|>",
|
||||
"<|quad_end|>",
|
||||
"<|vision_start|>",
|
||||
"<|vision_end|>",
|
||||
"<|vision_pad|>",
|
||||
"<|image_pad|>",
|
||||
"<|video_pad|>"
|
||||
],
|
||||
"bos_token": null,
|
||||
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in message.content %}\n {%- set content = message.content.split('</think>')[-1].lstrip('\\n') %}\n {%- set reasoning_content = 
message.content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
|
||||
"clean_up_tokenization_spaces": false,
|
||||
"eos_token": "<|endoftext|>",
|
||||
"errors": "replace",
|
||||
"model_max_length": 131072,
|
||||
"pad_token": "<|endoftext|>",
|
||||
"split_special_tokens": false,
|
||||
"tokenizer_class": "Qwen2Tokenizer",
|
||||
"unk_token": null
|
||||
}
|
||||
1
configs/qwen3_06b/vocab.json
Normal file
File diff suppressed because one or more lines are too long
51
configs/t5_old/config.json
Normal file
@@ -0,0 +1,51 @@
{
  "architectures": [
    "T5WithLMHeadModel"
  ],
  "d_ff": 65536,
  "d_kv": 128,
  "d_model": 1024,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_heads": 128,
  "num_layers": 24,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 200,
      "min_length": 30,
      "no_repeat_ngram_size": 3,
      "num_beams": 4,
      "prefix": "summarize: "
    },
    "translation_en_to_de": {
      "early_stopping": true,
      "max_length": 300,
      "num_beams": 4,
      "prefix": "translate English to German: "
    },
    "translation_en_to_fr": {
      "early_stopping": true,
      "max_length": 300,
      "num_beams": 4,
      "prefix": "translate English to French: "
    },
    "translation_en_to_ro": {
      "early_stopping": true,
      "max_length": 300,
      "num_beams": 4,
      "prefix": "translate English to Romanian: "
    }
  },
  "vocab_size": 32128
}
BIN
configs/t5_old/spiece.model
Normal file
Binary file not shown.
1
configs/t5_old/tokenizer.json
Normal file
File diff suppressed because one or more lines are too long
655
docs/anima_train_network.md
Normal file
@@ -0,0 +1,655 @@
# LoRA Training Guide for Anima using `anima_train_network.py` / `anima_train_network.py` を用いたAnima モデルのLoRA学習ガイド

This document explains how to train LoRA (Low-Rank Adaptation) models for Anima using `anima_train_network.py` in the `sd-scripts` repository.

<details>
<summary>日本語</summary>

このドキュメントでは、`sd-scripts`リポジトリに含まれる`anima_train_network.py`を使用して、Anima モデルに対するLoRA (Low-Rank Adaptation) モデルを学習する基本的な手順について解説します。

</details>

## 1. Introduction / はじめに

`anima_train_network.py` trains additional networks such as LoRA for Anima models. Anima adopts a DiT (Diffusion Transformer) architecture based on the MiniTrainDIT design with Rectified Flow training. It uses a Qwen3-0.6B text encoder, an LLM Adapter (a 6-layer transformer bridge from Qwen3 to a T5-compatible space), and the Qwen-Image VAE (16-channel, 8x spatial downscale).

The VAE used by Anima and the Qwen-Image VAE share the same architecture; the [official Anima weight is named for the Qwen-Image VAE](https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae).

This guide assumes you already understand the basics of LoRA training. For common usage and options, see the [train_network.py guide](train_network.md). Some parameters are similar to those in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md).

**Prerequisites:**

* The `sd-scripts` repository has been cloned and the Python environment is ready.
* A training dataset has been prepared. See the [Dataset Configuration Guide](./config_README-en.md).
* Anima model files for training are available.

<details>
<summary>日本語</summary>

`anima_train_network.py`は、Anima モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。AnimaはMiniTrainDIT設計に基づくDiT (Diffusion Transformer) アーキテクチャを採用しており、Rectified Flow学習を使用します。テキストエンコーダーとしてQwen3-0.6B、LLM Adapter (Qwen3からT5互換空間への6層Transformerブリッジ)、およびQwen-Image VAE (16チャンネル、8倍空間ダウンスケール) を使用します。

AnimaのVAEとQwen-Image VAEは同じアーキテクチャですが、[Anima公式の重みはQwen-Image VAE用](https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae)のようです。

このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sd3_train_network.py`](sd3_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。

**前提条件:**

* `sd-scripts`リポジトリのクローンとPython環境のセットアップが完了していること。
* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](./config_README-en.md)を参照してください)
* 学習対象のAnimaモデルファイルが準備できていること。

</details>

## 2. Differences from `train_network.py` / `train_network.py` との違い

`anima_train_network.py` is based on `train_network.py` but modified for Anima. The main differences are:

* **Target models:** Anima DiT models.
* **Model structure:** Uses a MiniTrainDIT (Transformer based) instead of U-Net. Employs a single text encoder (Qwen3-0.6B), an LLM Adapter that bridges Qwen3 embeddings to a T5-compatible cross-attention space, and the Qwen-Image VAE (16-channel latent space with 8x spatial downscale).
* **Arguments:** Uses the common `--pretrained_model_name_or_path` for the DiT model path, `--qwen3` for the Qwen3 text encoder, and `--vae` for the Qwen-Image VAE. The LLM adapter and T5 tokenizer can be specified separately with `--llm_adapter_path` and `--t5_tokenizer_path`.
* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used. `--fp8_base` is not supported.
* **Timestep sampling:** Uses the same `--timestep_sampling` options as FLUX training (`sigma`, `uniform`, `sigmoid`, `shift`, `flux_shift`).
* **LoRA:** Uses regex-based module selection and per-module rank/learning rate control (`network_reg_dims`, `network_reg_lrs`) instead of per-component arguments. Module exclusion/inclusion is controlled by `exclude_patterns` and `include_patterns`.

<details>
<summary>日本語</summary>

`anima_train_network.py`は`train_network.py`をベースに、Anima モデルに対応するための変更が加えられています。主な違いは以下の通りです。

* **対象モデル:** Anima DiTモデルを対象とします。
* **モデル構造:** U-Netの代わりにMiniTrainDIT (Transformerベース) を使用します。テキストエンコーダーとしてQwen3-0.6B、Qwen3埋め込みをT5互換のクロスアテンション空間に変換するLLM Adapter、およびQwen-Image VAE (16チャンネル潜在空間、8倍空間ダウンスケール) を使用します。
* **引数:** DiTモデルのパスには共通引数`--pretrained_model_name_or_path`を、Qwen3テキストエンコーダーには`--qwen3`を、Qwen-Image VAEには`--vae`を使用します。LLM AdapterとT5トークナイザーはそれぞれ`--llm_adapter_path`、`--t5_tokenizer_path`で個別に指定できます。
* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)は使用されません。`--fp8_base`はサポートされていません。
* **タイムステップサンプリング:** FLUX学習と同じ`--timestep_sampling`オプション(`sigma`、`uniform`、`sigmoid`、`shift`、`flux_shift`)を使用します。
* **LoRA:** コンポーネント別の引数の代わりに、正規表現ベースのモジュール選択とモジュール単位のランク/学習率制御(`network_reg_dims`、`network_reg_lrs`)を使用します。モジュールの除外/包含は`exclude_patterns`と`include_patterns`で制御します。

</details>

## 3. Preparation / 準備

The following files are required before starting training:

1. **Training script:** `anima_train_network.py`
2. **Anima DiT model file:** `.safetensors` file for the base DiT model.
3. **Qwen3-0.6B text encoder:** Either a HuggingFace model directory, or a single `.safetensors` file (uses the bundled config files in `configs/qwen3_06b/`).
4. **Qwen-Image VAE model file:** `.safetensors` or `.pth` file for the VAE.
5. **LLM Adapter model file (optional):** `.safetensors` file. If not provided separately, the adapter is loaded from the DiT file if the key `llm_adapter.out_proj.weight` exists.
6. **T5 Tokenizer (optional):** If not specified, uses the bundled tokenizer at `configs/t5_old/`.
7. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](./config_README-en.md).) In this document we use `my_anima_dataset_config.toml` as an example.

Model files can be obtained from the [Anima HuggingFace repository](https://huggingface.co/circlestone-labs/Anima).

**Notes:**

* The T5 tokenizer only needs the tokenizer files (not the T5 model weights). It uses the vocabulary from `google/t5-v1_1-xxl`.

<details>
<summary>日本語</summary>

学習を開始する前に、以下のファイルが必要です。

1. **学習スクリプト:** `anima_train_network.py`
2. **Anima DiTモデルファイル:** ベースとなるDiTモデルの`.safetensors`ファイル。
3. **Qwen3-0.6Bテキストエンコーダー:** HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイル(バンドル版の`configs/qwen3_06b/`の設定ファイルが使用されます)。
4. **Qwen-Image VAEモデルファイル:** VAEの`.safetensors`または`.pth`ファイル。
5. **LLM Adapterモデルファイル(オプション):** `.safetensors`ファイル。個別に指定しない場合、DiTファイル内に`llm_adapter.out_proj.weight`キーが存在すればそこから読み込まれます。
6. **T5トークナイザー(オプション):** 指定しない場合、`configs/t5_old/`のバンドル版トークナイザーを使用します。
7. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](./config_README-en.md)を参照してください)。例として`my_anima_dataset_config.toml`を使用します。

モデルファイルは[HuggingFaceのAnimaリポジトリ](https://huggingface.co/circlestone-labs/Anima)から入手できます。

**注意:**

* T5トークナイザーを別途指定する場合、トークナイザーファイルのみ必要です(T5モデルの重みは不要)。`google/t5-v1_1-xxl`の語彙を使用します。

</details>

## 4. Running the Training / 学習の実行

Execute `anima_train_network.py` from the terminal to start training. The overall command-line format is the same as `train_network.py`, but Anima specific options must be supplied.

Example command:

```bash
accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
  --pretrained_model_name_or_path="<path to Anima DiT model>" \
  --qwen3="<path to Qwen3-0.6B model or directory>" \
  --vae="<path to Qwen-Image VAE model>" \
  --dataset_config="my_anima_dataset_config.toml" \
  --output_dir="<output directory>" \
  --output_name="my_anima_lora" \
  --save_model_as=safetensors \
  --network_module=networks.lora_anima \
  --network_dim=8 \
  --learning_rate=1e-4 \
  --optimizer_type="AdamW8bit" \
  --lr_scheduler="constant" \
  --timestep_sampling="sigmoid" \
  --discrete_flow_shift=1.0 \
  --max_train_epochs=10 \
  --save_every_n_epochs=1 \
  --mixed_precision="bf16" \
  --gradient_checkpointing \
  --cache_latents \
  --cache_text_encoder_outputs \
  --vae_chunk_size=64 \
  --vae_disable_cache
```

*(Write the command on one line or use `\` or `^` for line breaks.)*

The learning rate of `1e-4` is just an example. Adjust it according to your dataset and objectives. This value is for `alpha=1.0` (default). If increasing `--network_alpha`, consider lowering the learning rate.

If loss becomes NaN, make sure you are using PyTorch version 2.5 or higher.

**Note:** `--vae_chunk_size` and `--vae_disable_cache` are custom options in this repository to reduce memory usage of the Qwen-Image VAE.

<details>
<summary>日本語</summary>

学習は、ターミナルから`anima_train_network.py`を実行することで開始します。基本的なコマンドラインの構造は`train_network.py`と同様ですが、Anima特有の引数を指定する必要があります。

コマンドラインの例は英語のドキュメントを参照してください。

※実際には1行で書くか、適切な改行文字(`\` または `^`)を使用してください。

学習率1e-4はあくまで一例です。データセットや目的に応じて適切に調整してください。またこの値はalpha=1.0(デフォルト)での値です。`--network_alpha`を増やす場合は学習率を下げることを検討してください。

lossがNaNになる場合は、PyTorchのバージョンが2.5以上であることを確認してください。

注意: `--vae_chunk_size`および`--vae_disable_cache`は当リポジトリ独自のオプションで、Qwen-Image VAEのメモリ使用量を削減するために使用します。

</details>

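For reference, a minimal dataset definition file in the TOML format used by this repository might look like the sketch below. The directory path, resolution, and repeat count are placeholders for illustration; the full set of available keys is described in the [Dataset Configuration Guide](./config_README-en.md).

```toml
# my_anima_dataset_config.toml -- hypothetical example; adjust paths and sizes
[general]
resolution = 1024            # training resolution (bucketed if enabled)
caption_extension = ".txt"   # caption file extension next to each image

[[datasets]]
batch_size = 1

  [[datasets.subsets]]
  image_dir = "/path/to/train_images"  # images plus matching caption files
  num_repeats = 10                     # how many times to repeat this subset per epoch
```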
### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説

Besides the arguments explained in the [train_network.py guide](train_network.md), specify the following Anima specific options. For shared options (`--output_dir`, `--output_name`, `--network_module`, etc.), see that guide.

#### Model Options [Required] / モデル関連 [必須]

* `--pretrained_model_name_or_path="<path to Anima DiT model>"` **[Required]**
  - Path to the Anima DiT model `.safetensors` file. The model config (channels, blocks, heads) is auto-detected from the state dict. ComfyUI format with a `net.` prefix is supported.
* `--qwen3="<path to Qwen3-0.6B model>"` **[Required]**
  - Path to the Qwen3-0.6B text encoder. Can be a HuggingFace model directory or a single `.safetensors` file. The base text encoder weights are always frozen during training.
* `--vae="<path to Qwen-Image VAE model>"` **[Required]**
  - Path to the Qwen-Image VAE model `.safetensors` or `.pth` file. Fixed config: `dim=96, z_dim=16`.

#### Model Options [Optional] / モデル関連 [オプション]

* `--llm_adapter_path="<path to LLM adapter>"` *[Optional]*
  - Path to a separate LLM adapter weights file. If omitted, the adapter is loaded from the DiT file when the key `llm_adapter.out_proj.weight` exists.
* `--t5_tokenizer_path="<path to T5 tokenizer>"` *[Optional]*
  - Path to the T5 tokenizer directory. If omitted, uses the bundled config at `configs/t5_old/`.

#### Anima Training Parameters / Anima 学習パラメータ

* `--timestep_sampling=<choice>`
  - Timestep sampling method. Choose from `sigma`, `uniform`, `sigmoid` (default), `shift`, `flux_shift`. Same options as FLUX training. See the [flux_train_network.py guide](flux_train_network.md) for details on each method.
* `--discrete_flow_shift=<float>`
  - Shift for the timestep distribution in Rectified Flow training. Default `1.0`. This value is used when `--timestep_sampling` is set to **`shift`**. The shift formula is `t_shifted = (t * shift) / (1 + (shift - 1) * t)`.
* `--sigmoid_scale=<float>`
  - Scale factor when `--timestep_sampling` is set to `sigmoid`, `shift`, or `flux_shift`. Default `1.0`.
* `--qwen3_max_token_length=<integer>`
  - Maximum token length for the Qwen3 tokenizer. Default `512`.
* `--t5_max_token_length=<integer>`
  - Maximum token length for the T5 tokenizer. Default `512`.
* `--attn_mode=<choice>`
  - Attention implementation to use. Choose from `torch` (default), `xformers`, `flash`, `sageattn`. `xformers` requires `--split_attn`. `sageattn` does not support training (inference only). This option overrides `--xformers`.
* `--split_attn`
  - Split attention computation to reduce memory usage. Required when using `--attn_mode xformers`.

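To make the timestep options concrete, here is a small Python sketch. The shift formula is the one quoted above for `--discrete_flow_shift`; the sigmoid sampler mirrors the common Rectified Flow formulation and is an illustration of the idea, not the script's exact code.

```python
import math
import random

def shift_timestep(t: float, shift: float) -> float:
    """Apply the discrete flow shift: t_shifted = (t * shift) / (1 + (shift - 1) * t)."""
    return (t * shift) / (1.0 + (shift - 1.0) * t)

def sample_sigmoid_timestep(scale: float = 1.0) -> float:
    """Sample t in (0, 1) by pushing a standard normal draw through a scaled sigmoid."""
    x = random.gauss(0.0, 1.0)
    return 1.0 / (1.0 + math.exp(-scale * x))

# shift=1.0 leaves timesteps unchanged; larger shifts push t toward 1 (noisier steps).
print(shift_timestep(0.5, 1.0))  # 0.5
print(shift_timestep(0.5, 3.0))  # 0.75
```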
#### Component-wise Learning Rates / コンポーネント別学習率

These options set separate learning rates for each component of the Anima model. They are primarily used for full fine-tuning. Set to `0` to freeze a component:

* `--self_attn_lr=<float>` - Learning rate for self-attention layers. Default: same as `--learning_rate`.
* `--cross_attn_lr=<float>` - Learning rate for cross-attention layers. Default: same as `--learning_rate`.
* `--mlp_lr=<float>` - Learning rate for MLP layers. Default: same as `--learning_rate`.
* `--mod_lr=<float>` - Learning rate for AdaLN modulation layers. Default: same as `--learning_rate`. Note: modulation layers are not included in LoRA by default.
* `--llm_adapter_lr=<float>` - Learning rate for LLM adapter layers. Default: same as `--learning_rate`.

For LoRA training, use `network_reg_lrs` in `--network_args` instead. See [Section 5.2](#52-regex-based-rank-and-learning-rate-control--正規表現によるランク学習率の制御).

#### Memory and Speed / メモリ・速度関連

* `--blocks_to_swap=<integer>`
  - Number of Transformer blocks to swap between CPU and GPU. More blocks reduce VRAM but slow training. Maximum values depend on model size:
    - 28-block model: max **26** (Anima-Preview)
    - 36-block model: max **34**
    - 20-block model: max **18**
  - Cannot be used with `--cpu_offload_checkpointing` or `--unsloth_offload_checkpointing`.
* `--unsloth_offload_checkpointing`
  - Offload activations to CPU RAM using async non-blocking transfers (faster than `--cpu_offload_checkpointing`). Cannot be combined with `--cpu_offload_checkpointing` or `--blocks_to_swap`.
* `--cache_text_encoder_outputs`
  - Cache Qwen3 text encoder outputs to reduce VRAM usage. Recommended when not training text encoder LoRA.
* `--cache_text_encoder_outputs_to_disk`
  - Cache text encoder outputs to disk. Auto-enables `--cache_text_encoder_outputs`.
* `--cache_latents`, `--cache_latents_to_disk`
  - Cache Qwen-Image VAE latent outputs.
* `--vae_chunk_size=<integer>`
  - Chunk size for Qwen-Image VAE processing. Reduces VRAM usage at the cost of speed. Default is no chunking.
* `--vae_disable_cache`
  - Disable internal caching in Qwen-Image VAE to reduce VRAM usage.

#### Incompatible or Unsupported Options / 非互換・非サポートの引数

* `--v2`, `--v_parameterization`, `--clip_skip` - Options for Stable Diffusion v1/v2 that are not used for Anima training.
* `--fp8_base` - Not supported for Anima. If specified, it will be disabled with a warning.

<details>
<summary>日本語</summary>

[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のAnima特有の引数を指定します。共通の引数については、上記ガイドを参照してください。

#### モデル関連 [必須]

* `--pretrained_model_name_or_path="<path to Anima DiT model>"` **[必須]** - Anima DiTモデルの`.safetensors`ファイルのパスを指定します。モデルの設定はstate dictから自動検出されます。`net.`プレフィックス付きのComfyUIフォーマットもサポートしています。
* `--qwen3="<path to Qwen3-0.6B model>"` **[必須]** - Qwen3-0.6Bテキストエンコーダーのパスを指定します。HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイルが使用できます。
* `--vae="<path to Qwen-Image VAE model>"` **[必須]** - Qwen-Image VAEモデルのパスを指定します。

#### モデル関連 [オプション]

* `--llm_adapter_path="<path to LLM adapter>"` *[オプション]* - 個別のLLM Adapterの重みファイルのパス。
* `--t5_tokenizer_path="<path to T5 tokenizer>"` *[オプション]* - T5トークナイザーディレクトリのパス。

#### Anima 学習パラメータ

* `--timestep_sampling` - タイムステップのサンプリング方法。`sigma`、`uniform`、`sigmoid`(デフォルト)、`shift`、`flux_shift`から選択。FLUX学習と同じオプションです。各方法の詳細は[flux_train_network.pyのガイド](flux_train_network.md)を参照してください。
* `--discrete_flow_shift` - Rectified Flow学習のタイムステップ分布シフト。デフォルト`1.0`。`--timestep_sampling`が`shift`の場合に使用されます。
* `--sigmoid_scale` - `sigmoid`、`shift`、`flux_shift`タイムステップサンプリングのスケール係数。デフォルト`1.0`。
* `--qwen3_max_token_length` - Qwen3トークナイザーの最大トークン長。デフォルト`512`。
* `--t5_max_token_length` - T5トークナイザーの最大トークン長。デフォルト`512`。
* `--attn_mode` - 使用するAttentionの実装。`torch`(デフォルト)、`xformers`、`flash`、`sageattn`から選択。`xformers`は`--split_attn`の指定が必要です。`sageattn`はトレーニングをサポートしていません(推論のみ)。
* `--split_attn` - メモリ使用量を減らすためにattention時にバッチを分割します。`--attn_mode xformers`使用時に必要です。

#### コンポーネント別学習率

これらのオプションは、Animaモデルの各コンポーネントに個別の学習率を設定します。主にフルファインチューニング用です。`0`に設定するとそのコンポーネントをフリーズします:

* `--self_attn_lr` - Self-attention層の学習率。
* `--cross_attn_lr` - Cross-attention層の学習率。
* `--mlp_lr` - MLP層の学習率。
* `--mod_lr` - AdaLNモジュレーション層の学習率。モジュレーション層はデフォルトではLoRAに含まれません。
* `--llm_adapter_lr` - LLM Adapter層の学習率。

LoRA学習の場合は、`--network_args`の`network_reg_lrs`を使用してください。[セクション5.2](#52-regex-based-rank-and-learning-rate-control--正規表現によるランク学習率の制御)を参照。

#### メモリ・速度関連

* `--blocks_to_swap` - TransformerブロックをCPUとGPUでスワップしてVRAMを節約。`--cpu_offload_checkpointing`および`--unsloth_offload_checkpointing`とは併用できません。
* `--unsloth_offload_checkpointing` - 非同期転送でアクティベーションをCPU RAMにオフロード。`--cpu_offload_checkpointing`および`--blocks_to_swap`とは併用できません。
* `--cache_text_encoder_outputs` - Qwen3の出力をキャッシュしてメモリ使用量を削減。
* `--cache_latents`, `--cache_latents_to_disk` - Qwen-Image VAEの出力をキャッシュ。
* `--vae_chunk_size` - Qwen-Image VAEのチャンク処理サイズ。メモリ使用量を削減しますが速度が低下します。デフォルトはチャンク処理なし。
* `--vae_disable_cache` - Qwen-Image VAEの内部キャッシュを無効化してメモリ使用量を削減します。

#### 非互換・非サポートの引数

* `--v2`, `--v_parameterization`, `--clip_skip` - Stable Diffusion v1/v2向けの引数。Animaの学習では使用されません。
* `--fp8_base` - Animaではサポートされていません。指定した場合、警告とともに無効化されます。

</details>

### 4.2. Starting Training / 学習の開始

After setting the required arguments, run the command to begin training. The overall flow and how to check logs are the same as in the [train_network.py guide](train_network.md#32-starting-the-training--学習の開始).

<details>
<summary>日本語</summary>

必要な引数を設定したら、コマンドを実行して学習を開始します。全体の流れやログの確認方法は、[train_network.pyのガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。

</details>

## 5. LoRA Target Modules / LoRAの学習対象モジュール

When training LoRA with `anima_train_network.py`, the following modules are targeted by default:

* **DiT Blocks (`Block`)**: Self-attention (`self_attn`), cross-attention (`cross_attn`), and MLP (`mlp`) layers within each transformer block. Modulation (`adaln_modulation`), norm, embedder, and final layers are excluded by default.
* **Embedding layers (`PatchEmbed`, `TimestepEmbedding`) and Final layer (`FinalLayer`)**: Excluded by default but can be included using `include_patterns`.
* **LLM Adapter Blocks (`LLMAdapterTransformerBlock`)**: Only when `--network_args "train_llm_adapter=True"` is specified.
* **Text Encoder (Qwen3)**: Only when `--network_train_unet_only` is NOT specified and `--cache_text_encoder_outputs` is NOT used.

The LoRA network module is `networks.lora_anima`.

### 5.1. Module Selection with Patterns / パターンによるモジュール選択

By default, the following modules are excluded from LoRA via the built-in exclude pattern:

```
.*(_modulation|_norm|_embedder|final_layer).*
```

You can customize which modules are included or excluded using regex patterns in `--network_args`:

* `exclude_patterns` - Exclude modules matching these patterns (in addition to the default exclusion).
* `include_patterns` - Force-include modules matching these patterns, overriding exclusion.

Patterns are matched against the full module name using `re.fullmatch()`.

Example to include the final layer:

```
--network_args "include_patterns=['.*final_layer.*']"
```

Example to additionally exclude MLP layers:

```
--network_args "exclude_patterns=['.*mlp.*']"
```

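The selection logic can be illustrated with a few lines of Python. The module names below are hypothetical examples, and the helper is a sketch of the decision rule rather than the script's exact implementation, but the default exclude pattern and the `re.fullmatch()` semantics are as described above.

```python
import re

# Built-in default exclusion for Anima LoRA module selection.
DEFAULT_EXCLUDE = r".*(_modulation|_norm|_embedder|final_layer).*"

def is_lora_target(name: str, include_patterns=(), exclude_patterns=()) -> bool:
    """Decide whether a module gets LoRA: include_patterns override all exclusions."""
    if any(re.fullmatch(p, name) for p in include_patterns):
        return True
    if re.fullmatch(DEFAULT_EXCLUDE, name):
        return False
    if any(re.fullmatch(p, name) for p in exclude_patterns):
        return False
    return True

print(is_lora_target("blocks.0.self_attn.q_proj"))    # True
print(is_lora_target("blocks.0.adaln_modulation.1"))  # False (default exclusion)
print(is_lora_target("final_layer.linear",
                     include_patterns=[r".*final_layer.*"]))  # True (forced include)
```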
### 5.2. Regex-based Rank and Learning Rate Control / 正規表現によるランク・学習率の制御

You can specify different ranks (network_dim) and learning rates for modules matching specific regex patterns:

* `network_reg_dims`: Specify ranks for modules matching a regular expression. The format is a comma-separated string of `pattern=rank`.
  * Example: `--network_args "network_reg_dims=.*self_attn.*=8,.*cross_attn.*=4,.*mlp.*=8"`
  * This sets the rank to 8 for self-attention modules, 4 for cross-attention modules, and 8 for MLP modules.
* `network_reg_lrs`: Specify learning rates for modules matching a regular expression. The format is a comma-separated string of `pattern=lr`.
  * Example: `--network_args "network_reg_lrs=.*self_attn.*=1e-4,.*cross_attn.*=5e-5"`
  * This sets the learning rate to `1e-4` for self-attention modules and `5e-5` for cross-attention modules.

**Notes:**

* Settings via `network_reg_dims` and `network_reg_lrs` take precedence over the global `--network_dim` and `--learning_rate` settings.
* Patterns are matched using `re.fullmatch()` against the module's original name (e.g., `blocks.0.self_attn.q_proj`).

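The `pattern=rank` format can be handled with ordinary string splitting. The helper below is a hypothetical sketch showing how a module name would resolve to a rank; here the first matching pattern wins, which is an assumption — the script's actual precedence rules may differ.

```python
import re

def parse_reg_dims(spec: str):
    """Parse 'pattern=rank,pattern=rank,...' into a list of (pattern, rank) pairs."""
    pairs = []
    for item in spec.split(","):
        pattern, _, rank = item.rpartition("=")  # split on the LAST '=' in each item
        pairs.append((pattern, int(rank)))
    return pairs

def rank_for(name: str, pairs, default: int) -> int:
    """Return the rank of the first pattern that fully matches the module name."""
    for pattern, rank in pairs:
        if re.fullmatch(pattern, name):
            return rank
    return default

pairs = parse_reg_dims(".*self_attn.*=8,.*cross_attn.*=4,.*mlp.*=8")
print(rank_for("blocks.0.cross_attn.k_proj", pairs, default=16))  # 4
print(rank_for("blocks.0.norm1", pairs, default=16))              # 16 (no pattern matched)
```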
### 5.3. LLM Adapter LoRA / LLM Adapter LoRA
|
||||
|
||||
To apply LoRA to the LLM Adapter blocks:
|
||||
|
||||
```
|
||||
--network_args "train_llm_adapter=True"
|
||||
```
|
||||
|
||||
In preliminary tests, lowering the learning rate for the LLM Adapter seems to improve stability. Adjust it using something like: `"network_reg_lrs=.*llm_adapter.*=5e-5"`.
|
||||
|
||||
### 5.4. Other Network Args / その他のネットワーク引数

* `--network_args "verbose=True"` - Print all LoRA module names and their dimensions.
* `--network_args "rank_dropout=0.1"` - Rank dropout rate.
* `--network_args "module_dropout=0.1"` - Module dropout rate.
* `--network_args "loraplus_lr_ratio=2.0"` - LoRA+ learning rate ratio.
* `--network_args "loraplus_unet_lr_ratio=2.0"` - LoRA+ learning rate ratio for DiT only.
* `--network_args "loraplus_text_encoder_lr_ratio=2.0"` - LoRA+ learning rate ratio for text encoder only.

<details>

<summary>日本語</summary>

`anima_train_network.py`でLoRAを学習させる場合、デフォルトでは以下のモジュールが対象となります。

* **DiTブロック (`Block`)**: 各Transformerブロック内のSelf-attention(`self_attn`)、Cross-attention(`cross_attn`)、MLP(`mlp`)層。モジュレーション(`adaln_modulation`)、norm、embedder、final layerはデフォルトで除外されます。
* **埋め込み層 (`PatchEmbed`, `TimestepEmbedding`) と最終層 (`FinalLayer`)**: デフォルトで除外されますが、`include_patterns`で含めることができます。
* **LLM Adapterブロック (`LLMAdapterTransformerBlock`)**: `--network_args "train_llm_adapter=True"`を指定した場合のみ。
* **テキストエンコーダー (Qwen3)**: `--network_train_unet_only`を指定せず、かつ`--cache_text_encoder_outputs`を使用しない場合のみ。

### 5.1. パターンによるモジュール選択

デフォルトでは以下のモジュールが組み込みの除外パターンによりLoRAから除外されます:

```
.*(_modulation|_norm|_embedder|final_layer).*
```

`--network_args`で正規表現パターンを使用して、含めるモジュールと除外するモジュールをカスタマイズできます:

* `exclude_patterns` - これらのパターンにマッチするモジュールを除外(デフォルトの除外に追加)。
* `include_patterns` - これらのパターンにマッチするモジュールを強制的に含める(除外を上書き)。

パターンは`re.fullmatch()`を使用して完全なモジュール名に対してマッチングされます。

### 5.2. 正規表現によるランク・学習率の制御

正規表現にマッチするモジュールに対して、異なるランクや学習率を指定できます:

* `network_reg_dims`: 正規表現にマッチするモジュールに対してランクを指定します。`pattern=rank`形式の文字列をカンマで区切って指定します。
  * 例: `--network_args "network_reg_dims=.*self_attn.*=8,.*cross_attn.*=4,.*mlp.*=8"`
* `network_reg_lrs`: 正規表現にマッチするモジュールに対して学習率を指定します。`pattern=lr`形式の文字列をカンマで区切って指定します。
  * 例: `--network_args "network_reg_lrs=.*self_attn.*=1e-4,.*cross_attn.*=5e-5"`

**注意点:**

* `network_reg_dims`および`network_reg_lrs`での設定は、全体設定である`--network_dim`や`--learning_rate`よりも優先されます。
* パターンはモジュールのオリジナル名(例: `blocks.0.self_attn.q_proj`)に対して`re.fullmatch()`でマッチングされます。

### 5.3. LLM Adapter LoRA

LLM AdapterブロックにLoRAを適用するには:`--network_args "train_llm_adapter=True"`

簡易な検証ではLLM Adapterの学習率はある程度下げた方が安定するようです。`"network_reg_lrs=.*llm_adapter.*=5e-5"`などで調整してください。

### 5.4. その他のネットワーク引数

* `verbose=True` - 全LoRAモジュール名とdimを表示
* `rank_dropout` - ランクドロップアウト率
* `module_dropout` - モジュールドロップアウト率
* `loraplus_lr_ratio` - LoRA+学習率比率
* `loraplus_unet_lr_ratio` - DiT専用のLoRA+学習率比率
* `loraplus_text_encoder_lr_ratio` - テキストエンコーダー専用のLoRA+学習率比率

</details>

## 6. Using the Trained Model / 学習済みモデルの利用

When training finishes, a LoRA model file (e.g. `my_anima_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support Anima, such as ComfyUI with appropriate nodes.

<details>
<summary>日本語</summary>

学習が完了すると、指定した`output_dir`にLoRAモデルファイル(例: `my_anima_lora.safetensors`)が保存されます。このファイルは、Anima モデルに対応した推論環境(例: ComfyUI + 適切なノード)で使用できます。

</details>

## 7. Advanced Settings / 高度な設定

### 7.1. VRAM Usage Optimization / VRAM使用量の最適化

Anima models can be large, so GPUs with limited VRAM may require optimization:

#### Key VRAM Reduction Options

- **`--blocks_to_swap <number>`**: Swaps blocks between CPU and GPU to reduce VRAM usage. Higher numbers save more VRAM but reduce training speed. See model-specific max values in section 4.1.
- **`--unsloth_offload_checkpointing`**: Offloads gradient checkpoints to CPU using async non-blocking transfers. Faster than `--cpu_offload_checkpointing`. Cannot be combined with `--blocks_to_swap`.
- **`--gradient_checkpointing`**: Standard gradient checkpointing to reduce VRAM at the cost of compute.
- **`--cache_text_encoder_outputs`**: Caches Qwen3 outputs so the text encoder can be freed from VRAM during training.
- **`--cache_latents`**: Caches Qwen-Image VAE outputs so the VAE can be freed from VRAM during training.
- **Using Adafactor optimizer**: Can reduce VRAM usage:

```
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0
```

<details>
<summary>日本語</summary>

Animaモデルは大きい場合があるため、VRAMが限られたGPUでは最適化が必要です。

主要なVRAM削減オプション:

- `--blocks_to_swap`: CPUとGPU間でブロックをスワップ
- `--unsloth_offload_checkpointing`: 非同期転送でアクティベーションをCPUにオフロード
- `--gradient_checkpointing`: 標準的な勾配チェックポイント
- `--cache_text_encoder_outputs`: Qwen3の出力をキャッシュ
- `--cache_latents`: Qwen-Image VAEの出力をキャッシュ
- Adafactorオプティマイザの使用

</details>

### 7.2. Training Settings / 学習設定

#### Timestep Sampling

The `--timestep_sampling` option specifies how timesteps are sampled. The available methods are the same as for FLUX training:

- `sigma`: Sigma-based sampling, as in SD3.
- `uniform`: Uniform random sampling from [0, 1].
- `sigmoid` (default): Sample from Normal(0, 1), multiply by `sigmoid_scale`, and apply the sigmoid function. A good general-purpose option.
- `shift`: Like `sigmoid`, but applies the discrete flow shift formula: `t_shifted = (t * shift) / (1 + (shift - 1) * t)`.
- `flux_shift`: Resolution-dependent shift used in FLUX training.

See the [flux_train_network.py guide](flux_train_network.md) for detailed descriptions.

#### Discrete Flow Shift

The `--discrete_flow_shift` option (default `1.0`) only applies when `--timestep_sampling` is set to `shift`. The formula is:

```
t_shifted = (t * shift) / (1 + (shift - 1) * t)
```

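The shift formula above can be sanity-checked directly (a small sketch; `shift=1.0` is the identity, while larger values push samples toward the noisier end):

```python
def shift_timestep(t, shift):
    # Discrete flow shift: t_shifted = (t * shift) / (1 + (shift - 1) * t)
    return (t * shift) / (1 + (shift - 1) * t)

print(shift_timestep(0.5, 1.0))  # 0.5  (shift=1.0 leaves timesteps unchanged)
print(shift_timestep(0.5, 3.0))  # 0.75 (samples pushed toward t=1)
```

The endpoints t=0 and t=1 are fixed points for any shift value.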
#### Loss Weighting

The `--weighting_scheme` option specifies loss weighting by timestep:

- `uniform` (default): Equal weight for all timesteps.
- `sigma_sqrt`: Weight by `sigma^(-2)`.
- `cosmap`: Weight by `2 / (pi * (1 - 2*sigma + 2*sigma^2))`.
- `none`: Same as uniform.
- `logit_normal`, `mode`: Additional schemes from SD3 training. See the [`sd3_train_network.md` guide](sd3_train_network.md) for details.

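A sketch of these per-timestep weights (hypothetical helper; `logit_normal` and `mode` are omitted here since they are described in the SD3 guide):

```python
import math

def loss_weight(sigma, scheme):
    if scheme in ("uniform", "none"):
        return 1.0
    if scheme == "sigma_sqrt":
        return sigma ** -2.0
    if scheme == "cosmap":
        return 2.0 / (math.pi * (1.0 - 2.0 * sigma + 2.0 * sigma * sigma))
    raise ValueError(f"unknown scheme: {scheme}")

print(loss_weight(0.5, "uniform"))           # 1.0
print(loss_weight(0.5, "sigma_sqrt"))        # 4.0
print(round(loss_weight(0.5, "cosmap"), 4))  # 1.2732
```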
#### Caption Dropout

Caption dropout uses the `caption_dropout_rate` setting from the dataset configuration (per-subset in TOML). When using `--cache_text_encoder_outputs`, the dropout rate is stored with each cached entry and applied during training, so caption dropout is compatible with text encoder output caching.

**If you change the `caption_dropout_rate` setting, you must delete and regenerate the cache.**

Note: Currently, only Anima supports combining `caption_dropout_rate` with text encoder output caching.

<details>
<summary>日本語</summary>

#### タイムステップサンプリング

`--timestep_sampling`でタイムステップのサンプリング方法を指定します。FLUX学習と同じ方法が利用できます:

- `sigma`: SD3と同様のシグマベースサンプリング。
- `uniform`: [0, 1]の一様分布からサンプリング。
- `sigmoid`(デフォルト): 正規分布からサンプリングし、sigmoidを適用。汎用的なオプション。
- `shift`: `sigmoid`と同様だが、離散フローシフトの式を適用。
- `flux_shift`: FLUX学習で使用される解像度依存のシフト。

詳細は[flux_train_network.pyのガイド](flux_train_network.md)を参照してください。

#### 離散フローシフト

`--discrete_flow_shift`(デフォルト`1.0`)は`--timestep_sampling`が`shift`の場合のみ適用されます。

#### 損失の重み付け

`--weighting_scheme`でタイムステップごとの損失の重み付けを指定します。

#### キャプションドロップアウト

キャプションドロップアウトにはデータセット設定(TOMLでのサブセット単位)の`caption_dropout_rate`を使用します。`--cache_text_encoder_outputs`使用時は、ドロップアウト率が各キャッシュエントリとともに保存され、学習中に適用されるため、テキストエンコーダー出力キャッシュと同時に使用できます。

**`caption_dropout_rate`の設定を変えた場合、キャッシュを削除し、再生成する必要があります。**

※`caption_dropout_rate`をテキストエンコーダー出力キャッシュと組み合わせられるのは、今のところAnimaのみです。

</details>

### 7.3. Text Encoder LoRA Support / Text Encoder LoRAのサポート

Anima LoRA training supports training Qwen3 text encoder LoRA:

- To train only DiT: specify `--network_train_unet_only`
- To train DiT and Qwen3: omit `--network_train_unet_only` and do NOT use `--cache_text_encoder_outputs`

You can specify a separate learning rate for Qwen3 with `--text_encoder_lr`. If not specified, the default `--learning_rate` is used.

Note: When `--cache_text_encoder_outputs` is used, text encoder outputs are pre-computed and the text encoder is removed from GPU, so text encoder LoRA cannot be trained.

<details>
<summary>日本語</summary>

Anima LoRA学習では、Qwen3テキストエンコーダーのLoRAもトレーニングできます。

- DiTのみ学習: `--network_train_unet_only`を指定
- DiTとQwen3を学習: `--network_train_unet_only`を省略し、`--cache_text_encoder_outputs`を使用しない

Qwen3に個別の学習率を指定するには`--text_encoder_lr`を使用します。未指定の場合は`--learning_rate`が使われます。

注意: `--cache_text_encoder_outputs`を使用する場合、テキストエンコーダーの出力が事前に計算されGPUから解放されるため、テキストエンコーダーLoRAは学習できません。

</details>

## 8. Other Training Options / その他の学習オプション

- **`--loss_type`**: Loss function for training. Default `l2`.
  - `l1`: L1 loss.
  - `l2`: L2 loss (mean squared error).
  - `huber`: Huber loss.
  - `smooth_l1`: Smooth L1 loss.
- **`--huber_schedule`**, **`--huber_c`**, **`--huber_scale`**: Parameters for Huber loss when `--loss_type` is `huber` or `smooth_l1`.
- **`--ip_noise_gamma`**, **`--ip_noise_gamma_random_strength`**: Input Perturbation noise gamma values.
- **`--fused_backward_pass`**: Fuses the backward pass and optimizer step to reduce VRAM usage. Only works with Adafactor. For details, see the [`sdxl_train_network.py` guide](sdxl_train_network.md).
- **`--weighting_scheme`**, **`--logit_mean`**, **`--logit_std`**, **`--mode_scale`**: Timestep loss weighting options. For details, refer to the [`sd3_train_network.md` guide](sd3_train_network.md).

<details>
<summary>日本語</summary>

- **`--loss_type`**: 学習に用いる損失関数。デフォルト`l2`。`l1`, `l2`, `huber`, `smooth_l1`から選択。
- **`--huber_schedule`**, **`--huber_c`**, **`--huber_scale`**: Huber損失のパラメータ。
- **`--ip_noise_gamma`**: Input Perturbationノイズガンマ値。
- **`--fused_backward_pass`**: バックワードパスとオプティマイザステップの融合。
- **`--weighting_scheme`** 等: タイムステップ損失の重み付け。詳細は[`sd3_train_network.md`](sd3_train_network.md)を参照。

</details>

## 9. Related Tools / 関連ツール

### `networks/convert_anima_lora_to_comfy.py`

A script to convert LoRA models to ComfyUI-compatible format. ComfyUI does not directly support sd-scripts format Qwen3 LoRA, so conversion is necessary (conversion may not be needed for DiT-only LoRA). You can convert from the sd-scripts format to ComfyUI format with:

```bash
python networks/convert_anima_lora_to_comfy.py path/to/source.safetensors path/to/destination.safetensors
```

Using the `--reverse` option allows conversion in the opposite direction (ComfyUI format to sd-scripts format). However, reverse conversion is only possible for LoRAs converted by this script. LoRAs created with other training tools cannot be converted.

<details>
<summary>日本語</summary>

**`networks/convert_anima_lora_to_comfy.py`**

LoRAモデルをComfyUI互換形式に変換するスクリプト。ComfyUIがsd-scripts形式のQwen3 LoRAを直接サポートしていないため、変換が必要です(DiTのみのLoRAの場合は変換不要のようです)。sd-scripts形式からComfyUI形式への変換は以下のコマンドで行います:

```bash
python networks/convert_anima_lora_to_comfy.py path/to/source.safetensors path/to/destination.safetensors
```

`--reverse`オプションを付けると、逆変換(ComfyUI形式からsd-scripts形式)も可能です。ただし、逆変換ができるのはこのスクリプトで変換したLoRAに限ります。他の学習ツールで作成したLoRAは変換できません。

</details>

## 10. Others / その他

`anima_train_network.py` also shares many features with `train_network.py`, such as sample image generation (`--sample_prompts`, etc.) and detailed optimizer settings. For these, refer to the [`train_network.py` guide](train_network.md#5-other-features--その他の機能) or the script's help (`python anima_train_network.py --help`).

### Metadata Saved in LoRA Models

The following metadata is saved in the LoRA model file:

* `ss_weighting_scheme`
* `ss_logit_mean`
* `ss_logit_std`
* `ss_mode_scale`
* `ss_timestep_sampling`
* `ss_sigmoid_scale`
* `ss_discrete_flow_shift`

<details>
<summary>日本語</summary>

`anima_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python anima_train_network.py --help`) を参照してください。

### LoRAモデルに保存されるメタデータ

以下のメタデータがLoRAモデルファイルに保存されます:

* `ss_weighting_scheme`
* `ss_logit_mean`
* `ss_logit_std`
* `ss_mode_scale`
* `ss_timestep_sampling`
* `ss_sigmoid_scale`
* `ss_discrete_flow_shift`

</details>

@@ -122,11 +122,15 @@ These are options related to the configuration of the data set. They cannot be d

| `max_bucket_reso` | `1024` | o | o |
| `min_bucket_reso` | `128` | o | o |
| `resolution` | `256`, `[512, 512]` | o | o |
| `skip_image_resolution` | `768`, `[512, 768]` | o | o |

* `batch_size`
  * This corresponds to the command-line argument `--train_batch_size`.
* `max_bucket_reso`, `min_bucket_reso`
  * Specify the maximum and minimum bucket resolutions. They must be divisible by `bucket_reso_steps`.
* `skip_image_resolution`
  * Images whose original resolution (area) is equal to or smaller than the specified resolution will be skipped. Specify as `'size'` or `[width, height]`. This corresponds to the command-line argument `--skip_image_resolution`.
  * Useful when sharing the same image directory across multiple datasets with different resolutions, to exclude low-resolution source images from higher-resolution datasets.

These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each.

@@ -254,6 +258,34 @@ resolution = 768
image_dir = 'C:\hoge'
```

When using multi-resolution datasets, you can use `skip_image_resolution` to exclude images whose original size is too small for higher-resolution datasets. This prevents overlapping of low-resolution images across datasets and improves training quality. This option can also be used to simply exclude low-resolution source images from datasets.

```toml
[general]
enable_bucket = true
bucket_no_upscale = true
max_bucket_reso = 1536

[[datasets]]
resolution = 768
[[datasets.subsets]]
image_dir = 'C:\hoge'

[[datasets]]
resolution = 1024
skip_image_resolution = 768
[[datasets.subsets]]
image_dir = 'C:\hoge'

[[datasets]]
resolution = 1280
skip_image_resolution = 1024
[[datasets.subsets]]
image_dir = 'C:\hoge'
```

In this example, the 1024-resolution dataset skips images whose original size is 768x768 or smaller, and the 1280-resolution dataset skips images whose original size is 1024x1024 or smaller.

## Command Line Argument and Configuration File

There are options in the configuration file that have overlapping roles with command line argument options.

@@ -284,6 +316,7 @@ For the command line options listed below, if an option is specified in both the

| `--random_crop` | |
| `--resolution` | |
| `--shuffle_caption` | |
| `--skip_image_resolution` | |
| `--train_batch_size` | `batch_size` |

## Error Guide

@@ -115,11 +115,15 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学

| `max_bucket_reso` | `1024` | o | o |
| `min_bucket_reso` | `128` | o | o |
| `resolution` | `256`, `[512, 512]` | o | o |
| `skip_image_resolution` | `768`, `[512, 768]` | o | o |

* `batch_size`
  * コマンドライン引数の `--train_batch_size` と同等です。
* `max_bucket_reso`, `min_bucket_reso`
  * bucketの最大、最小解像度を指定します。`bucket_reso_steps` で割り切れる必要があります。
* `skip_image_resolution`
  * 指定した解像度(面積)以下の画像をスキップします。`'サイズ'` または `[幅, 高さ]` で指定します。コマンドライン引数の `--skip_image_resolution` と同等です。
  * 同じ画像ディレクトリを異なる解像度の複数のデータセットで使い回す場合に、低解像度の元画像を高解像度のデータセットから除外するために使用します。

これらの設定はデータセットごとに固定です。
つまり、データセットに所属するサブセットはこれらの設定を共有することになります。

@@ -259,6 +263,34 @@ resolution = 768
image_dir = 'C:\hoge'
```

なお、マルチ解像度データセットでは `skip_image_resolution` を使用して、元の画像サイズが小さい画像を高解像度データセットから除外できます。これにより、低解像度画像のデータセット間での重複を防ぎ、学習品質を向上させることができます。また、小さい画像を除外するフィルターとしても機能します。

```toml
[general]
enable_bucket = true
bucket_no_upscale = true
max_bucket_reso = 1536

[[datasets]]
resolution = 768
[[datasets.subsets]]
image_dir = 'C:\hoge'

[[datasets]]
resolution = 1024
skip_image_resolution = 768
[[datasets.subsets]]
image_dir = 'C:\hoge'

[[datasets]]
resolution = 1280
skip_image_resolution = 1024
[[datasets.subsets]]
image_dir = 'C:\hoge'
```

この例では、1024 解像度のデータセットでは元の画像サイズが 768x768 以下の画像がスキップされ、1280 解像度のデータセットでは 1024x1024 以下の画像がスキップされます。

## コマンドライン引数との併用

設定ファイルのオプションの中には、コマンドライン引数のオプションと役割が重複しているものがあります。

@@ -289,6 +321,7 @@ resolution = 768

| `--random_crop` | |
| `--resolution` | |
| `--shuffle_caption` | |
| `--skip_image_resolution` | |
| `--train_batch_size` | `batch_size` |

## エラーの手引き

359
docs/loha_lokr.md
Normal file
@@ -0,0 +1,359 @@

> 📝 Click on the language section to expand / 言語をクリックして展開

# LoHa / LoKr (LyCORIS)

## Overview / 概要

In addition to standard LoRA, sd-scripts supports **LoHa** (Low-rank Hadamard Product) and **LoKr** (Low-rank Kronecker Product) as alternative parameter-efficient fine-tuning methods. These are based on techniques from the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) project.

- **LoHa**: Represents weight updates as a Hadamard (element-wise) product of two low-rank matrices. Reference: [FedPara (arXiv:2108.06098)](https://arxiv.org/abs/2108.06098)
- **LoKr**: Represents weight updates as a Kronecker product with optional low-rank decomposition. Reference: [LoKr (arXiv:2309.14859)](https://arxiv.org/abs/2309.14859)

The algorithms and recommended settings are described in the [LyCORIS documentation](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Algo-List.md) and [guidelines](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md).

Both methods target Linear and Conv2d layers. Conv2d 1x1 layers are treated similarly to Linear layers. For Conv2d 3x3+ layers, optional Tucker decomposition or flat (kernel-flattened) mode is available.

This feature is experimental.

<details>
<summary>日本語</summary>

sd-scriptsでは、標準的なLoRAに加え、代替のパラメータ効率の良いファインチューニング手法として **LoHa**(Low-rank Hadamard Product)と **LoKr**(Low-rank Kronecker Product)をサポートしています。これらは [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) プロジェクトの手法に基づいています。

- **LoHa**: 重みの更新を2つの低ランク行列のHadamard積(要素ごとの積)で表現します。参考文献: [FedPara (arXiv:2108.06098)](https://arxiv.org/abs/2108.06098)
- **LoKr**: 重みの更新をKronecker積と、オプションの低ランク分解で表現します。参考文献: [LoKr (arXiv:2309.14859)](https://arxiv.org/abs/2309.14859)

アルゴリズムと推奨設定は[LyCORISのアルゴリズム解説](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Algo-List.md)と[ガイドライン](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md)を参照してください。

LinearおよびConv2d層の両方を対象としています。Conv2d 1x1層はLinear層と同様に扱われます。Conv2d 3x3+層については、オプションのTucker分解またはflat(カーネル平坦化)モードが利用可能です。

この機能は実験的なものです。

</details>

## Acknowledgments / 謝辞

The LoHa and LoKr implementations in sd-scripts are based on the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) project by [KohakuBlueleaf](https://github.com/KohakuBlueleaf). We would like to express our sincere gratitude for the excellent research and open-source contributions that made this implementation possible.

<details>
<summary>日本語</summary>

sd-scriptsのLoHaおよびLoKrの実装は、[KohakuBlueleaf](https://github.com/KohakuBlueleaf)氏による[LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS)プロジェクトに基づいています。この実装を可能にしてくださった素晴らしい研究とオープンソースへの貢献に心から感謝いたします。

</details>

## Supported architectures / 対応アーキテクチャ

LoHa and LoKr automatically detect the model architecture and apply appropriate default settings. The following architectures are currently supported:

- **SDXL**: Targets `Transformer2DModel` for UNet and `CLIPAttention`/`CLIPMLP` for text encoders. Conv2d layers in `ResnetBlock2D`, `Downsample2D`, and `Upsample2D` are also supported when `conv_dim` is specified. No default `exclude_patterns`.
- **Anima**: Targets `Block`, `PatchEmbed`, `TimestepEmbedding`, and `FinalLayer` for DiT, and `Qwen3Attention`/`Qwen3MLP` for the text encoder. Default `exclude_patterns` automatically skips modulation, normalization, embedder, and final_layer modules.

<details>
<summary>日本語</summary>

LoHaとLoKrは、モデルのアーキテクチャを自動で検出し、適切なデフォルト設定を適用します。現在、以下のアーキテクチャに対応しています:

- **SDXL**: UNetの`Transformer2DModel`、テキストエンコーダの`CLIPAttention`/`CLIPMLP`を対象とします。`conv_dim`を指定した場合、`ResnetBlock2D`、`Downsample2D`、`Upsample2D`のConv2d層も対象になります。デフォルトの`exclude_patterns`はありません。
- **Anima**: DiTの`Block`、`PatchEmbed`、`TimestepEmbedding`、`FinalLayer`、テキストエンコーダの`Qwen3Attention`/`Qwen3MLP`を対象とします。デフォルトの`exclude_patterns`により、modulation、normalization、embedder、final_layerモジュールは自動的にスキップされます。

</details>

## Training / 学習

To use LoHa or LoKr, change the `--network_module` argument in your training command. All other training options (dataset config, optimizer, etc.) remain the same as LoRA.

<details>
<summary>日本語</summary>

LoHaまたはLoKrを使用するには、学習コマンドの `--network_module` 引数を変更します。その他の学習オプション(データセット設定、オプティマイザなど)はLoRAと同じです。

</details>

### LoHa (SDXL)

```bash
accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 sdxl_train_network.py \
  --pretrained_model_name_or_path path/to/sdxl.safetensors \
  --dataset_config path/to/toml \
  --mixed_precision bf16 --fp8_base \
  --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing \
  --network_module networks.loha --network_dim 32 --network_alpha 16 \
  --max_train_epochs 16 --save_every_n_epochs 1 \
  --output_dir path/to/output --output_name my-loha
```

### LoKr (SDXL)

```bash
accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 sdxl_train_network.py \
  --pretrained_model_name_or_path path/to/sdxl.safetensors \
  --dataset_config path/to/toml \
  --mixed_precision bf16 --fp8_base \
  --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing \
  --network_module networks.lokr --network_dim 32 --network_alpha 16 \
  --max_train_epochs 16 --save_every_n_epochs 1 \
  --output_dir path/to/output --output_name my-lokr
```

For Anima, replace `sdxl_train_network.py` with `anima_train_network.py` and use the appropriate model path and options.

<details>
<summary>日本語</summary>

Animaの場合は、`sdxl_train_network.py` を `anima_train_network.py` に置き換え、適切なモデルパスとオプションを使用してください。

</details>

### Common training options / 共通の学習オプション

The following `--network_args` options are available for both LoHa and LoKr, same as LoRA:

| Option | Description |
|---|---|
| `verbose=True` | Display detailed information about the network modules |
| `rank_dropout=0.1` | Apply dropout to the rank dimension during training |
| `module_dropout=0.1` | Randomly skip entire modules during training |
| `exclude_patterns=[r'...']` | Exclude modules matching the regex patterns (in addition to architecture defaults) |
| `include_patterns=[r'...']` | Override excludes: modules matching these regex patterns will be included even if they match `exclude_patterns` |
| `network_reg_lrs=regex1=lr1,regex2=lr2` | Set per-module learning rates using regex patterns |
| `network_reg_dims=regex1=dim1,regex2=dim2` | Set per-module dimensions (rank) using regex patterns |

<details>
<summary>日本語</summary>

以下の `--network_args` オプションは、LoRAと同様にLoHaとLoKrの両方で使用できます:

| オプション | 説明 |
|---|---|
| `verbose=True` | ネットワークモジュールの詳細情報を表示 |
| `rank_dropout=0.1` | 学習時にランク次元にドロップアウトを適用 |
| `module_dropout=0.1` | 学習時にモジュール全体をランダムにスキップ |
| `exclude_patterns=[r'...']` | 正規表現パターンに一致するモジュールを除外(アーキテクチャのデフォルトに追加) |
| `include_patterns=[r'...']` | 正規表現パターンに一致するモジュールを強制的に含める(除外を上書き) |
| `network_reg_lrs=regex1=lr1,regex2=lr2` | 正規表現パターンでモジュールごとの学習率を設定 |
| `network_reg_dims=regex1=dim1,regex2=dim2` | 正規表現パターンでモジュールごとの次元(ランク)を設定 |

</details>

### Conv2d support / Conv2dサポート

By default, LoHa and LoKr target Linear and Conv2d 1x1 layers. To also train Conv2d 3x3+ layers (e.g., in SDXL's ResNet blocks), use the `conv_dim` and `conv_alpha` options:

```bash
--network_args "conv_dim=16" "conv_alpha=8"
```

For Conv2d 3x3+ layers, you can enable Tucker decomposition for more efficient parameter representation:

```bash
--network_args "conv_dim=16" "conv_alpha=8" "use_tucker=True"
```

- Without `use_tucker`: The kernel dimensions are flattened into the input dimension (flat mode).
- With `use_tucker=True`: A separate Tucker tensor is used to handle the kernel dimensions, which can be more parameter-efficient.

<details>
<summary>日本語</summary>

デフォルトでは、LoHaとLoKrはLinearおよびConv2d 1x1層を対象とします。Conv2d 3x3+層(SDXLのResNetブロックなど)も学習するには、`conv_dim`と`conv_alpha`オプションを使用します:

```bash
--network_args "conv_dim=16" "conv_alpha=8"
```

Conv2d 3x3+層に対して、Tucker分解を有効にすることで、より効率的なパラメータ表現が可能です:

```bash
--network_args "conv_dim=16" "conv_alpha=8" "use_tucker=True"
```

- `use_tucker`なし: カーネル次元が入力次元に平坦化されます(flatモード)。
- `use_tucker=True`: カーネル次元を扱う別のTuckerテンソルが使用され、よりパラメータ効率が良くなる場合があります。

</details>

### LoKr-specific option: `factor` / LoKr固有のオプション: `factor`

LoKr decomposes weight dimensions using factorization. The `factor` option controls how dimensions are split:

- `factor=-1` (default): Automatically find balanced factors. For example, dimension 512 is split into (16, 32).
- `factor=N` (positive integer): Force factorization using the specified value. For example, `factor=4` splits dimension 512 into (4, 128).

```bash
--network_args "factor=4"
```

When `network_dim` (rank) is large enough relative to the factorized dimensions, LoKr uses a full matrix instead of a low-rank decomposition for the second factor. A warning will be logged in this case.

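The balanced factorization can be sketched as follows (an illustrative helper assuming the "divisor pair closest to the square root" behaviour; the actual LyCORIS implementation may differ in details such as pair ordering):

```python
def factorize(dim, factor=-1):
    if factor > 0:
        return factor, dim // factor  # forced split, e.g. factor=4: 512 -> (4, 128)
    # factor=-1: pick the divisor pair closest to sqrt(dim)
    best = (1, dim)
    for a in range(1, int(dim ** 0.5) + 1):
        if dim % a == 0:
            best = (a, dim // a)
    return best

print(factorize(512))     # (16, 32)
print(factorize(512, 4))  # (4, 128)
```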
<details>
<summary>日本語</summary>

LoKrは重みの次元を因数分解して分割します。`factor` オプションでその分割方法を制御します:

- `factor=-1`(デフォルト): バランスの良い因数を自動的に見つけます。例えば、次元512は(16, 32)に分割されます。
- `factor=N`(正の整数): 指定した値で因数分解します。例えば、`factor=4` は次元512を(4, 128)に分割します。

```bash
--network_args "factor=4"
```

`network_dim`(ランク)が因数分解された次元に対して十分に大きい場合、LoKrは第2因子に低ランク分解ではなくフル行列を使用します。その場合、警告がログに出力されます。

</details>

### Anima-specific option: `train_llm_adapter` / Anima固有のオプション: `train_llm_adapter`

For Anima, you can additionally train the LLM adapter modules by specifying:

```bash
--network_args "train_llm_adapter=True"
```

This includes `LLMAdapterTransformerBlock` modules as training targets.

<details>
<summary>日本語</summary>

Animaでは、以下を指定することでLLMアダプターモジュールも追加で学習できます:

```bash
--network_args "train_llm_adapter=True"
```

これにより、`LLMAdapterTransformerBlock` モジュールが学習対象に含まれます。

</details>

### LoRA+ / LoRA+

LoRA+ (`loraplus_lr_ratio` etc. in `--network_args`) is supported with LoHa/LoKr. For LoHa, the second pair of matrices (`hada_w2_a`) is treated as the "plus" (higher learning rate) parameter group. For LoKr, the scale factor (`lokr_w1`) is treated as the "plus" parameter group.

```bash
--network_args "loraplus_lr_ratio=4"
```

This feature has been confirmed to work in basic testing, but feedback is welcome. If you encounter any issues, please report them.

<details>
<summary>日本語</summary>

LoRA+(`--network_args` の `loraplus_lr_ratio` 等)はLoHa/LoKrでもサポートされています。LoHaでは第2ペアの行列(`hada_w2_a`)が「plus」(より高い学習率)パラメータグループとして扱われます。LoKrではスケール係数(`lokr_w1`)が「plus」パラメータグループとして扱われます。

```bash
--network_args "loraplus_lr_ratio=4"
```

この機能は基本的なテストでは動作確認されていますが、フィードバックをお待ちしています。問題が発生した場合はご報告ください。

</details>

## How LoHa and LoKr work / LoHaとLoKrの仕組み

### LoHa

LoHa represents the weight update as a Hadamard (element-wise) product of two low-rank matrices:

```
ΔW = (W1a × W1b) ⊙ (W2a × W2b)
```

where `W1a`, `W1b`, `W2a`, `W2b` are low-rank matrices with rank `network_dim`. This means LoHa has roughly **twice the number of trainable parameters** compared to LoRA at the same rank, but can capture more complex weight structures due to the element-wise product.

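The LoHa reconstruction above can be sketched in a few lines of NumPy (illustrative shapes only):

```python
import numpy as np

rng = np.random.default_rng(0)
out_dim, in_dim, rank = 8, 8, 2

# Two independent low-rank pairs, combined element-wise
w1a, w1b = rng.normal(size=(out_dim, rank)), rng.normal(size=(rank, in_dim))
w2a, w2b = rng.normal(size=(out_dim, rank)), rng.normal(size=(rank, in_dim))

delta_w = (w1a @ w1b) * (w2a @ w2b)  # ΔW = (W1a x W1b) ⊙ (W2a x W2b)
print(delta_w.shape)  # (8, 8)
# four factor matrices -> roughly twice the parameters of LoRA at the same rank
print(sum(m.size for m in (w1a, w1b, w2a, w2b)))  # 64 (LoRA: 32)
```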
For Conv2d 3x3+ layers with Tucker decomposition, each pair additionally has a Tucker tensor `T` and the reconstruction becomes: `einsum("i j ..., j r, i p -> p r ...", T, Wb, Wa)`.
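That einsum can be checked shape-wise with NumPy (illustrative dimensions):

```python
import numpy as np

rng = np.random.default_rng(0)
out_ch, in_ch, k, rank = 8, 4, 3, 2

T = rng.normal(size=(rank, rank, k, k))  # Tucker core over the kernel dims
Wa = rng.normal(size=(rank, out_ch))     # the "i p" factor
Wb = rng.normal(size=(rank, in_ch))      # the "j r" factor

w = np.einsum("i j ..., j r, i p -> p r ...", T, Wb, Wa)
print(w.shape)  # (8, 4, 3, 3) -- a full Conv2d weight
```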
### LoKr

LoKr represents the weight update using a Kronecker product:

```
ΔW = W1 ⊗ W2 (where W2 = W2a × W2b in low-rank mode)
```

The original weight dimensions are factorized (e.g., a 512×512 weight might be split so that W1 is 16×16 and W2 is 32×32). W1 is always a full matrix (small), while W2 can be either low-rank decomposed or a full matrix depending on the rank setting. LoKr tends to produce **smaller models** compared to LoRA at the same rank.

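A NumPy sketch of the reconstruction (illustrative sizes matching the example above):

```python
import numpy as np

rng = np.random.default_rng(0)
rank = 8

# 512x512 weight, factorized as 16 * 32 on each side
w1 = rng.normal(size=(16, 16))  # small full matrix
w2a, w2b = rng.normal(size=(32, rank)), rng.normal(size=(rank, 32))
w2 = w2a @ w2b                  # low-rank second factor

delta_w = np.kron(w1, w2)       # ΔW = W1 ⊗ W2
print(delta_w.shape)  # (512, 512)
```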
<details>
<summary>日本語</summary>

### LoHa

LoHaは重みの更新を2つの低ランク行列のHadamard積(要素ごとの積)で表現します:

```
ΔW = (W1a × W1b) ⊙ (W2a × W2b)
```

ここで `W1a`, `W1b`, `W2a`, `W2b` はランク `network_dim` の低ランク行列です。LoHaは同じランクのLoRAと比較して学習可能なパラメータ数が **約2倍** になりますが、要素ごとの積により、より複雑な重み構造を捉えることができます。

Conv2d 3x3+層でTucker分解を使用する場合、各ペアにはさらにTuckerテンソル `T` があり、再構成は `einsum("i j ..., j r, i p -> p r ...", T, Wb, Wa)` となります。

### LoKr

LoKrはKronecker積を使って重みの更新を表現します:

```
ΔW = W1 ⊗ W2 (低ランクモードでは W2 = W2a × W2b)
```

元の重みの次元が因数分解されます(例: 512×512の重みが、W1が16×16、W2が32×32に分割されます)。W1は常にフル行列(小さい)で、W2はランク設定に応じて低ランク分解またはフル行列になります。LoKrは同じランクのLoRAと比較して **より小さいモデル** を生成する傾向があります。

</details>
|
||||
|
||||
## Inference / 推論
|
||||
|
||||
Trained LoHa/LoKr weights are saved in safetensors format, just like LoRA.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
学習済みのLoHa/LoKrの重みは、LoRAと同様にsafetensors形式で保存されます。
|
||||
|
||||
</details>
|
||||
|
||||
### SDXL

For SDXL, use `gen_img.py` with `--network_module` and `--network_weights`, the same way as LoRA:

```bash
python gen_img.py --ckpt path/to/sdxl.safetensors \
  --network_module networks.loha --network_weights path/to/loha.safetensors \
  --prompt "your prompt" ...
```

Replace `networks.loha` with `networks.lokr` for LoKr weights.

<details>
<summary>日本語</summary>

SDXLでは、LoRAと同様に `gen_img.py` で `--network_module` と `--network_weights` を指定します:

```bash
python gen_img.py --ckpt path/to/sdxl.safetensors \
  --network_module networks.loha --network_weights path/to/loha.safetensors \
  --prompt "your prompt" ...
```

LoKrの重みを使用する場合は `networks.loha` を `networks.lokr` に置き換えてください。

</details>

### Anima

For Anima, use `anima_minimal_inference.py` with the `--lora_weight` argument. LoRA, LoHa, and LoKr weights are automatically detected and merged:

```bash
python anima_minimal_inference.py --dit path/to/dit --prompt "your prompt" \
  --lora_weight path/to/loha_or_lokr.safetensors ...
```

<details>
<summary>日本語</summary>

Animaでは、`anima_minimal_inference.py` に `--lora_weight` 引数を指定します。LoRA、LoHa、LoKrの重みは自動的に判定されてマージされます:

```bash
python anima_minimal_inference.py --dit path/to/dit --prompt "your prompt" \
  --lora_weight path/to/loha_or_lokr.safetensors ...
```

</details>

736
docs/train_leco.md
Normal file
@@ -0,0 +1,736 @@

# LECO Training Guide / LECO 学習ガイド

LECO (Low-rank adaptation for Erasing COncepts from diffusion models) is a technique for training LoRA models that modify or erase concepts from a diffusion model **without requiring any image dataset**. It works by training a LoRA against the model's own noise predictions using text prompts only.

This repository provides two LECO training scripts:

- `train_leco.py` for Stable Diffusion 1.x / 2.x
- `sdxl_train_leco.py` for SDXL

<details>
<summary>日本語</summary>

LECO (Low-rank adaptation for Erasing COncepts from diffusion models) は、**画像データセットを一切必要とせず**、テキストプロンプトのみを使用してモデル自身のノイズ予測に対して LoRA を学習させる手法です。拡散モデルから概念を変更・消去する LoRA モデルを作成できます。

このリポジトリでは以下の2つの LECO 学習スクリプトを提供しています:

- `train_leco.py` : Stable Diffusion 1.x / 2.x 用
- `sdxl_train_leco.py` : SDXL 用

</details>

## 1. Overview / 概要

### What LECO Can Do / LECO でできること

LECO can be used for:

- **Concept erasing**: Remove a specific style or concept (e.g., erase "van gogh" style from generated images)
- **Concept enhancing**: Strengthen a specific attribute (e.g., make "detailed" more pronounced)
- **Slider LoRA**: Create a LoRA that controls an attribute bidirectionally (e.g., a slider between "short hair" and "long hair")

Unlike standard LoRA training, LECO does not use any training images. All training signals come from the difference between the model's own noise predictions on different text prompts.

<details>
<summary>日本語</summary>

LECO は以下の用途に使用できます:

- **概念の消去**: 特定のスタイルや概念を除去する(例:生成画像から「van gogh」スタイルを消去)
- **概念の強化**: 特定の属性を強化する(例:「detailed」をより顕著にする)
- **スライダー LoRA**: 属性を双方向に制御する LoRA を作成する(例:「short hair」と「long hair」の間のスライダー)

通常の LoRA 学習とは異なり、LECO は学習画像を一切使用しません。学習のシグナルは全て、異なるテキストプロンプトに対するモデル自身のノイズ予測の差分から得られます。

</details>

### Key Differences from Standard LoRA Training / 通常の LoRA 学習との違い

| | Standard LoRA | LECO |
|---|---|---|
| Training data | Image dataset required | **No images needed** |
| Configuration | Dataset TOML | Prompt TOML |
| Training target | U-Net and/or Text Encoder | **U-Net only** |
| Training unit | Epochs and steps | **Steps only** |
| Saving | Per-epoch or per-step | **Per-step only** (`--save_every_n_steps`) |

<details>
<summary>日本語</summary>

| | 通常の LoRA | LECO |
|---|---|---|
| 学習データ | 画像データセットが必要 | **画像不要** |
| 設定ファイル | データセット TOML | プロンプト TOML |
| 学習対象 | U-Net と Text Encoder | **U-Net のみ** |
| 学習単位 | エポックとステップ | **ステップのみ** |
| 保存 | エポック毎またはステップ毎 | **ステップ毎のみ** (`--save_every_n_steps`) |

</details>

## 2. Prompt Configuration File / プロンプト設定ファイル

LECO uses a TOML file to define training prompts. Two formats are supported: the **original LECO format** and the **slider target format** (ai-toolkit style).

<details>
<summary>日本語</summary>

LECO は学習プロンプトの定義に TOML ファイルを使用します。**オリジナル LECO 形式**と**スライダーターゲット形式**(ai-toolkit スタイル)の2つの形式に対応しています。

</details>

### 2.1. Original LECO Format / オリジナル LECO 形式

Use `[[prompts]]` sections to define prompt pairs directly. This gives you full control over each training pair.

```toml
[[prompts]]
target = "van gogh"
positive = "van gogh"
unconditional = ""
neutral = ""
action = "erase"
guidance_scale = 1.0
resolution = 512
batch_size = 1
multiplier = 1.0
weight = 1.0
```

Each `[[prompts]]` entry defines one training pair with the following fields:

| Field | Required | Default | Description |
|-------|----------|---------|-------------|
| `target` | Yes | - | The concept to be modified by the LoRA |
| `positive` | No | same as `target` | The "positive direction" prompt for building the training target |
| `unconditional` | No | `""` | The unconditional/negative prompt |
| `neutral` | No | `""` | The neutral baseline prompt |
| `action` | No | `"erase"` | `"erase"` to remove the concept, `"enhance"` to strengthen it |
| `guidance_scale` | No | `1.0` | Scale factor for target construction (higher = stronger effect) |
| `resolution` | No | `512` | Training resolution (int or `[height, width]`) |
| `batch_size` | No | `1` | Number of latent samples per training step for this prompt |
| `multiplier` | No | `1.0` | LoRA strength multiplier during training |
| `weight` | No | `1.0` | Loss weight for this prompt pair |

<details>
<summary>日本語</summary>

`[[prompts]]` セクションを使用して、プロンプトペアを直接定義します。各学習ペアを細かく制御できます。

各 `[[prompts]]` エントリのフィールド:

| フィールド | 必須 | デフォルト | 説明 |
|-----------|------|-----------|------|
| `target` | はい | - | LoRA によって変更される概念 |
| `positive` | いいえ | `target` と同じ | 学習ターゲット構築時の「正方向」プロンプト |
| `unconditional` | いいえ | `""` | 無条件/ネガティブプロンプト |
| `neutral` | いいえ | `""` | ニュートラルベースラインプロンプト |
| `action` | いいえ | `"erase"` | `"erase"` で概念を除去、`"enhance"` で強化 |
| `guidance_scale` | いいえ | `1.0` | ターゲット構築時のスケール係数(大きいほど効果が強い) |
| `resolution` | いいえ | `512` | 学習解像度(整数または `[height, width]`) |
| `batch_size` | いいえ | `1` | このプロンプトの学習ステップごとの latent サンプル数 |
| `multiplier` | いいえ | `1.0` | 学習時の LoRA 強度乗数 |
| `weight` | いいえ | `1.0` | このプロンプトペアの loss 重み |

</details>

### 2.2. Slider Target Format / スライダーターゲット形式

Use `[[targets]]` sections to define slider-style LoRAs. Each target is automatically expanded into bidirectional training pairs (4 pairs when both `positive` and `negative` are provided, 2 pairs when only one is provided).

```toml
guidance_scale = 1.0
resolution = 1024
neutral = ""

[[targets]]
target_class = "1girl"
positive = "1girl, long hair"
negative = "1girl, short hair"
multiplier = 1.0
weight = 1.0
```

Top-level fields (`guidance_scale`, `resolution`, `neutral`, `batch_size`, etc.) serve as defaults for all targets.

Each `[[targets]]` entry supports the following fields:

| Field | Required | Default | Description |
|-------|----------|---------|-------------|
| `target_class` | Yes | - | The base class/subject prompt |
| `positive` | No* | `""` | Prompt for the positive direction of the slider |
| `negative` | No* | `""` | Prompt for the negative direction of the slider |
| `multiplier` | No | `1.0` | LoRA strength multiplier |
| `weight` | No | `1.0` | Loss weight |

\* At least one of `positive` or `negative` must be provided.

<details>
<summary>日本語</summary>

`[[targets]]` セクションを使用してスライダースタイルの LoRA を定義します。各ターゲットは自動的に双方向の学習ペアに展開されます(`positive` と `negative` の両方がある場合は4ペア、片方のみの場合は2ペア)。

トップレベルのフィールド(`guidance_scale`、`resolution`、`neutral`、`batch_size` など)は全ターゲットのデフォルト値として機能します。

各 `[[targets]]` エントリのフィールド:

| フィールド | 必須 | デフォルト | 説明 |
|-----------|------|-----------|------|
| `target_class` | はい | - | ベースとなるクラス/被写体プロンプト |
| `positive` | いいえ* | `""` | スライダーの正方向プロンプト |
| `negative` | いいえ* | `""` | スライダーの負方向プロンプト |
| `multiplier` | いいえ | `1.0` | LoRA 強度乗数 |
| `weight` | いいえ | `1.0` | loss 重み |

\* `positive` と `negative` のうち少なくとも一方を指定する必要があります。

</details>

### 2.3. Multiple Neutral Prompts / 複数のニュートラルプロンプト

You can provide multiple neutral prompts for slider targets. Each neutral prompt generates a separate set of training pairs, which can improve generalization.

```toml
guidance_scale = 1.5
resolution = 1024
neutrals = ["", "photo of a person", "cinematic portrait"]

[[targets]]
target_class = "person"
positive = "smiling person"
negative = "expressionless person"
```

You can also load neutral prompts from a text file (one prompt per line):

```toml
neutral_prompt_file = "neutrals.txt"

[[targets]]
target_class = ""
positive = "high detail"
negative = "low detail"
```

<details>
<summary>日本語</summary>

スライダーターゲットに対して複数のニュートラルプロンプトを指定できます。各ニュートラルプロンプトごとに個別の学習ペアが生成され、汎化性能の向上が期待できます。

ニュートラルプロンプトをテキストファイル(1行1プロンプト)から読み込むこともできます。

</details>
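The expansion rules above (4 pairs with both directions, 2 with one, times the number of neutrals) make it easy to estimate how many prompt pairs a config produces. A rough bookkeeping sketch, not the exact expansion logic of the training scripts:

```python
def count_training_pairs(targets, neutrals):
    """Estimate pair count: 4 per target with both directions, else 2,
    multiplied by the number of neutral prompts (at least one)."""
    total = 0
    for t in targets:
        per_neutral = 4 if t.get("positive") and t.get("negative") else 2
        total += per_neutral * max(len(neutrals), 1)
    return total

targets = [{"target_class": "person",
            "positive": "smiling person",
            "negative": "expressionless person"}]
neutrals = ["", "photo of a person", "cinematic portrait"]
print(count_training_pairs(targets, neutrals))  # 12
```

Since each pair contributes its own loss terms, more neutrals increase per-step work accordingly.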

### 2.4. Converting from ai-toolkit YAML / ai-toolkit の YAML からの変換

If you have an existing ai-toolkit style YAML config, convert it to TOML as follows:

<details>
<summary>日本語</summary>

既存の ai-toolkit スタイルの YAML 設定がある場合、以下のように TOML に変換してください。

</details>

**YAML:**
```yaml
targets:
  - target_class: ""
    positive: "high detail"
    negative: "low detail"
    multiplier: 1.0
guidance_scale: 1.0
resolution: 512
```

**TOML:**
```toml
guidance_scale = 1.0
resolution = 512

[[targets]]
target_class = ""
positive = "high detail"
negative = "low detail"
multiplier = 1.0
```

Key syntax differences:

- Use `=` instead of `:` for key-value pairs
- Use `[[targets]]` header instead of `targets:` with `- ` list items
- Arrays use `[brackets]` (e.g., `neutrals = ["a", "b"]`)

<details>
<summary>日本語</summary>

主な構文の違い:

- キーと値の区切りに `:` ではなく `=` を使用
- `targets:` と `- ` のリスト記法ではなく `[[targets]]` ヘッダを使用
- 配列は `[brackets]` で記述(例:`neutrals = ["a", "b"]`)

</details>

## 3. Running the Training / 学習の実行

Training is started by executing the script from the terminal. Below are basic command-line examples.

The commands are shown with line breaks for readability. Write them on a single line, or end each line with a line-continuation character: `\` on Linux/Mac, `^` on Windows.

<details>
<summary>日本語</summary>

学習はターミナルからスクリプトを実行して開始します。以下に基本的なコマンドライン例を示します。

コマンドは見やすさのために改行して示しています。実際には1行で書くか、各行末に行継続文字(Linux/Mac では `\`、Windows では `^`)を追加してください。

</details>

### SD 1.x / 2.x

```bash
accelerate launch --mixed_precision bf16 train_leco.py \
  --pretrained_model_name_or_path="model.safetensors" \
  --prompts_file="prompts.toml" \
  --output_dir="output" \
  --output_name="my_leco" \
  --network_dim=8 \
  --network_alpha=4 \
  --learning_rate=1e-4 \
  --optimizer_type="AdamW8bit" \
  --max_train_steps=500 \
  --max_denoising_steps=40 \
  --mixed_precision=bf16 \
  --sdpa \
  --gradient_checkpointing \
  --save_every_n_steps=100
```

### SDXL

```bash
accelerate launch --mixed_precision bf16 sdxl_train_leco.py \
  --pretrained_model_name_or_path="sdxl_model.safetensors" \
  --prompts_file="slider.toml" \
  --output_dir="output" \
  --output_name="my_sdxl_slider" \
  --network_dim=8 \
  --network_alpha=4 \
  --learning_rate=1e-4 \
  --optimizer_type="AdamW8bit" \
  --max_train_steps=1000 \
  --max_denoising_steps=40 \
  --mixed_precision=bf16 \
  --sdpa \
  --gradient_checkpointing \
  --save_every_n_steps=200
```

## 4. Command-Line Arguments / コマンドライン引数

### 4.1. LECO-Specific Arguments / LECO 固有の引数

These arguments are unique to LECO and not found in standard LoRA training scripts.

<details>
<summary>日本語</summary>

以下の引数は LECO 固有のもので、通常の LoRA 学習スクリプトにはありません。

</details>

* `--prompts_file="prompts.toml"` **[Required]**
  * Path to the LECO prompt configuration TOML file. See [Section 2](#2-prompt-configuration-file--プロンプト設定ファイル) for the file format.

* `--max_denoising_steps=40`
  * Number of partial denoising steps per training iteration. At each step, a random number of denoising steps (from 1 to this value) is performed. Default: `40`.

* `--leco_denoise_guidance_scale=3.0`
  * Guidance scale used during the partial denoising pass. This is separate from `guidance_scale` in the TOML file. Default: `3.0`.

<details>
<summary>日本語</summary>

* `--prompts_file="prompts.toml"` **[必須]**
  * LECO プロンプト設定 TOML ファイルのパス。ファイル形式については[セクション2](#2-prompt-configuration-file--プロンプト設定ファイル)を参照してください。

* `--max_denoising_steps=40`
  * 各学習イテレーションでの部分デノイズステップ数。各ステップで1からこの値の間のランダムなステップ数でデノイズが行われます。デフォルト: `40`。

* `--leco_denoise_guidance_scale=3.0`
  * 部分デノイズ時の guidance scale。TOML ファイル内の `guidance_scale` とは別のパラメータです。デフォルト: `3.0`。

</details>

#### Understanding the Two `guidance_scale` Parameters / 2つの `guidance_scale` の違い

There are two separate guidance scale parameters that control different aspects of LECO training:

1. **`--leco_denoise_guidance_scale` (command-line)**: Controls CFG strength during the partial denoising pass that generates intermediate latents. Higher values produce more prompt-adherent latents for the training signal.

2. **`guidance_scale` (in TOML file)**: Controls the magnitude of the concept offset when constructing the training target. Higher values produce a stronger erase/enhance effect. This can be set per-prompt or per-target.

If training results are too subtle, try increasing the TOML `guidance_scale` (e.g., `1.5` to `3.0`).

<details>
<summary>日本語</summary>

LECO の学習では、異なる役割を持つ2つの guidance scale パラメータがあります:

1. **`--leco_denoise_guidance_scale`(コマンドライン)**: 中間 latent を生成する部分デノイズパスの CFG 強度を制御します。大きな値にすると、プロンプトにより忠実な latent が学習シグナルとして生成されます。

2. **`guidance_scale`(TOML ファイル内)**: 学習ターゲット構築時の概念オフセットの大きさを制御します。大きな値にすると、消去/強化の効果が強くなります。プロンプトごと・ターゲットごとに設定可能です。

学習結果の効果が弱い場合は、TOML の `guidance_scale` を大きくしてみてください(例:`1.5` から `3.0`)。

</details>
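The role of the TOML `guidance_scale` can be sketched on scalar "noise predictions". The erase/enhance target below follows the original LECO implementation (neutral ∓ scale × (positive − unconditional)); whether this repository uses exactly this formula is an assumption:

```python
def leco_target(neutral, positive, unconditional, guidance_scale=1.0, action="erase"):
    """Build the training target from the model's own predictions (scalars here)."""
    offset = guidance_scale * (positive - unconditional)
    if action == "erase":
        return neutral - offset  # push predictions away from the concept
    return neutral + offset      # "enhance": push toward the concept

# a larger TOML guidance_scale means a larger offset, hence a stronger effect
print(leco_target(0.0, 1.0, 0.0, guidance_scale=1.0))  # -1.0
print(leco_target(0.0, 1.0, 0.0, guidance_scale=3.0))  # -3.0 (stronger erase)
```

The LoRA is then optimized so its output on the `target` prompt matches this constructed target.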

### 4.2. Model Arguments / モデル引数

* `--pretrained_model_name_or_path="model.safetensors"` **[Required]**
  * Path to the base Stable Diffusion model (`.ckpt`, `.safetensors`, Diffusers directory, or Hugging Face model ID).

* `--v2` (SD 1.x/2.x only)
  * Specify when using a Stable Diffusion v2.x model.

* `--v_parameterization` (SD 1.x/2.x only)
  * Specify when using a v-prediction model (e.g., SD 2.x 768px models).

<details>
<summary>日本語</summary>

* `--pretrained_model_name_or_path="model.safetensors"` **[必須]**
  * ベースとなる Stable Diffusion モデルのパス(`.ckpt`、`.safetensors`、Diffusers ディレクトリ、Hugging Face モデル ID)。

* `--v2`(SD 1.x/2.x のみ)
  * Stable Diffusion v2.x モデルを使用する場合に指定します。

* `--v_parameterization`(SD 1.x/2.x のみ)
  * v-prediction モデル(SD 2.x 768px モデルなど)を使用する場合に指定します。

</details>

### 4.3. LoRA Network Arguments / LoRA ネットワーク引数

* `--network_module=networks.lora`
  * Network module to train. Default: `networks.lora`.

* `--network_dim=8`
  * LoRA rank (dimension). Higher values increase expressiveness but also file size. Typical values: `4` to `16`. Default: `4`.

* `--network_alpha=4`
  * LoRA alpha for learning rate scaling. A common choice is to set this to half of `network_dim`. Default: `1.0`.

* `--network_dropout=0.1`
  * Dropout rate for LoRA layers. Optional.

* `--network_args "key=value" ...`
  * Additional network-specific arguments. For example, `--network_args "conv_dim=4"` to enable Conv2d LoRA.

* `--network_weights="path/to/weights.safetensors"`
  * Load pretrained LoRA weights to continue training.

* `--dim_from_weights`
  * Infer `network_dim` from the weights specified by `--network_weights`. Requires `--network_weights`.

<details>
<summary>日本語</summary>

* `--network_module=networks.lora`
  * 学習するネットワークモジュール。デフォルト: `networks.lora`。

* `--network_dim=8`
  * LoRA のランク(次元数)。大きいほど表現力が上がりますがファイルサイズも増加します。一般的な値: `4` から `16`。デフォルト: `4`。

* `--network_alpha=4`
  * 学習率スケーリング用の LoRA alpha。`network_dim` の半分程度に設定するのが一般的です。デフォルト: `1.0`。

* `--network_dropout=0.1`
  * LoRA レイヤーのドロップアウト率。省略可。

* `--network_args "key=value" ...`
  * ネットワーク固有の追加引数。例:`--network_args "conv_dim=4"` で Conv2d LoRA を有効にします。

* `--network_weights="path/to/weights.safetensors"`
  * 事前学習済み LoRA ウェイトを読み込んで学習を続行します。

* `--dim_from_weights`
  * `--network_weights` で指定したウェイトから `network_dim` を推定します。`--network_weights` の指定が必要です。

</details>

### 4.4. Training Parameters / 学習パラメータ

* `--max_train_steps=500`
  * Total number of training steps. Default: `1600`. Typical range for LECO: `300` to `2000`.
  * Note: `--max_train_epochs` is **not supported** for LECO (the training loop is step-based only).

* `--learning_rate=1e-4`
  * Learning rate. Typical range for LECO: `1e-4` to `1e-3`.

* `--unet_lr=1e-4`
  * Separate learning rate for U-Net LoRA modules. If not specified, `--learning_rate` is used.

* `--optimizer_type="AdamW8bit"`
  * Optimizer type. Options include `AdamW8bit` (requires `bitsandbytes`), `AdamW`, `Lion`, `Adafactor`, etc.

* `--lr_scheduler="constant"`
  * Learning rate scheduler. Options: `constant`, `cosine`, `linear`, `constant_with_warmup`, etc.

* `--lr_warmup_steps=0`
  * Number of warmup steps for the learning rate scheduler.

* `--gradient_accumulation_steps=1`
  * Number of steps to accumulate gradients before updating. Effectively multiplies the batch size.

* `--max_grad_norm=1.0`
  * Maximum gradient norm for gradient clipping. Set to `0` to disable.

* `--min_snr_gamma=5.0`
  * Min-SNR weighting gamma. Applies SNR-based loss weighting. Optional.

<details>
<summary>日本語</summary>

* `--max_train_steps=500`
  * 学習の総ステップ数。デフォルト: `1600`。LECO の一般的な範囲: `300` から `2000`。
  * 注意: `--max_train_epochs` は LECO では**サポートされていません**(学習ループはステップベースのみです)。

* `--learning_rate=1e-4`
  * 学習率。LECO の一般的な範囲: `1e-4` から `1e-3`。

* `--unet_lr=1e-4`
  * U-Net LoRA モジュール用の個別の学習率。指定しない場合は `--learning_rate` が使用されます。

* `--optimizer_type="AdamW8bit"`
  * オプティマイザの種類。`AdamW8bit`(要 `bitsandbytes`)、`AdamW`、`Lion`、`Adafactor` 等が選択可能です。

* `--lr_scheduler="constant"`
  * 学習率スケジューラ。`constant`、`cosine`、`linear`、`constant_with_warmup` 等が選択可能です。

* `--lr_warmup_steps=0`
  * 学習率スケジューラのウォームアップステップ数。

* `--gradient_accumulation_steps=1`
  * 勾配を累積するステップ数。実質的にバッチサイズを増加させます。

* `--max_grad_norm=1.0`
  * 勾配クリッピングの最大勾配ノルム。`0` で無効化。

* `--min_snr_gamma=5.0`
  * Min-SNR 重み付けのガンマ値。SNR ベースの loss 重み付けを適用します。省略可。

</details>

### 4.5. Output and Save Arguments / 出力・保存引数

* `--output_dir="output"` **[Required]**
  * Directory for saving trained LoRA models and logs.

* `--output_name="my_leco"` **[Required]**
  * Base filename for the trained LoRA (without extension).

* `--save_model_as="safetensors"`
  * Model save format. Options: `safetensors` (default, recommended), `ckpt`, `pt`.

* `--save_every_n_steps=100`
  * Save an intermediate checkpoint every N steps. If not specified, only the final model is saved.
  * Note: `--save_every_n_epochs` is **not supported** for LECO.

* `--save_precision="fp16"`
  * Precision for saving the model. Options: `float`, `fp16`, `bf16`. If not specified, the training precision is used.

* `--no_metadata`
  * Do not write metadata into the saved model file.

* `--training_comment="my comment"`
  * A comment string stored in the model metadata.

<details>
<summary>日本語</summary>

* `--output_dir="output"` **[必須]**
  * 学習済み LoRA モデルとログの保存先ディレクトリ。

* `--output_name="my_leco"` **[必須]**
  * 学習済み LoRA のベースファイル名(拡張子なし)。

* `--save_model_as="safetensors"`
  * モデルの保存形式。`safetensors`(デフォルト、推奨)、`ckpt`、`pt` から選択。

* `--save_every_n_steps=100`
  * N ステップごとに中間チェックポイントを保存。指定しない場合は最終モデルのみ保存されます。
  * 注意: `--save_every_n_epochs` は LECO では**サポートされていません**。

* `--save_precision="fp16"`
  * モデル保存時の精度。`float`、`fp16`、`bf16` から選択。省略時は学習時の精度が使用されます。

* `--no_metadata`
  * 保存するモデルファイルにメタデータを書き込みません。

* `--training_comment="my comment"`
  * モデルのメタデータに保存されるコメント文字列。

</details>

### 4.6. Memory and Performance Arguments / メモリ・パフォーマンス引数

* `--mixed_precision="bf16"`
  * Mixed precision training. Options: `no`, `fp16`, `bf16`. Using `bf16` or `fp16` is recommended.

* `--full_fp16`
  * Train entirely in fp16 precision including gradients.

* `--full_bf16`
  * Train entirely in bf16 precision including gradients.

* `--gradient_checkpointing`
  * Enable gradient checkpointing to reduce VRAM usage at the cost of slightly slower training. **Recommended for LECO**, especially with larger models or higher resolutions.

* `--sdpa`
  * Use Scaled Dot-Product Attention. Reduces memory usage and can improve speed. Recommended.

* `--xformers`
  * Use xformers for memory-efficient attention (requires `xformers` package). Alternative to `--sdpa`.

* `--mem_eff_attn`
  * Use memory-efficient attention implementation. Another alternative to `--sdpa`.

<details>
<summary>日本語</summary>

* `--mixed_precision="bf16"`
  * 混合精度学習。`no`、`fp16`、`bf16` から選択。`bf16` または `fp16` の使用を推奨します。

* `--full_fp16`
  * 勾配を含め全体を fp16 精度で学習します。

* `--full_bf16`
  * 勾配を含め全体を bf16 精度で学習します。

* `--gradient_checkpointing`
  * gradient checkpointing を有効にしてVRAM使用量を削減します(学習速度は若干低下)。特に大きなモデルや高解像度での LECO 学習時に**推奨**です。

* `--sdpa`
  * Scaled Dot-Product Attention を使用します。メモリ使用量を削減し速度向上が期待できます。推奨。

* `--xformers`
  * xformers を使用したメモリ効率の良い attention(`xformers` パッケージが必要)。`--sdpa` の代替。

* `--mem_eff_attn`
  * メモリ効率の良い attention 実装を使用。`--sdpa` の別の代替。

</details>

### 4.7. Other Useful Arguments / その他の便利な引数

* `--seed=42`
  * Random seed for reproducibility. If not specified, a random seed is automatically generated.

* `--noise_offset=0.05`
  * Enable noise offset. Small values like `0.02` to `0.1` can help with training stability.

* `--zero_terminal_snr`
  * Fix noise scheduler betas to enforce zero terminal SNR.

* `--clip_skip=2` (SD 1.x/2.x only)
  * Use the output from the Nth-to-last layer of the text encoder. Common values: `1` (no skip) or `2`.

* `--logging_dir="logs"`
  * Directory for TensorBoard logs. Enables logging when specified.

* `--log_with="tensorboard"`
  * Logging tool. Options: `tensorboard`, `wandb`, `all`.

<details>
<summary>日本語</summary>

* `--seed=42`
  * 再現性のための乱数シード。指定しない場合は自動生成されます。

* `--noise_offset=0.05`
  * ノイズオフセットを有効にします。`0.02` から `0.1` 程度の小さい値で学習の安定性が向上する場合があります。

* `--zero_terminal_snr`
  * noise scheduler の betas を修正してゼロ終端 SNR を強制します。

* `--clip_skip=2`(SD 1.x/2.x のみ)
  * text encoder の後ろから N 番目の層の出力を使用します。一般的な値: `1`(スキップなし)または `2`。

* `--logging_dir="logs"`
  * TensorBoard ログの出力ディレクトリ。指定時にログ出力が有効になります。

* `--log_with="tensorboard"`
  * ログツール。`tensorboard`、`wandb`、`all` から選択。

</details>

## 5. Tips / ヒント

### Tuning the Effect Strength / 効果の強さの調整

If the trained LoRA has a weak or unnoticeable effect:

1. **Increase `guidance_scale` in TOML** (e.g., `1.5` to `3.0`). This is the most direct way to strengthen the effect.
2. **Increase `multiplier` in TOML** (e.g., `1.5` to `2.0`).
3. **Increase `--max_denoising_steps`** for more refined intermediate latents.
4. **Increase `--max_train_steps`** to train longer.
5. **Apply the LoRA with a higher weight** at inference time.

<details>
<summary>日本語</summary>

学習した LoRA の効果が弱い、または認識できない場合:

1. **TOML の `guidance_scale` を上げる**(例:`1.5` から `3.0`)。効果を強める最も直接的な方法です。
2. **TOML の `multiplier` を上げる**(例:`1.5` から `2.0`)。
3. **`--max_denoising_steps` を増やす**。より精緻な中間 latent が生成されます。
4. **`--max_train_steps` を増やして**、より長く学習する。
5. **推論時に LoRA のウェイトを大きくして**適用する。

</details>

### Recommended Starting Settings / 推奨の開始設定

| Parameter | SD 1.x/2.x | SDXL |
|-----------|-------------|------|
| `--network_dim` | `4`-`8` | `8`-`16` |
| `--learning_rate` | `1e-4` | `1e-4` |
| `--max_train_steps` | `300`-`1000` | `500`-`2000` |
| `resolution` (in TOML) | `512` | `1024` |
| `guidance_scale` (in TOML) | `1.0`-`2.0` | `1.0`-`3.0` |
| `batch_size` (in TOML) | `1`-`4` | `1`-`4` |

<details>
<summary>日本語</summary>

| パラメータ | SD 1.x/2.x | SDXL |
|-----------|-------------|------|
| `--network_dim` | `4`-`8` | `8`-`16` |
| `--learning_rate` | `1e-4` | `1e-4` |
| `--max_train_steps` | `300`-`1000` | `500`-`2000` |
| `resolution`(TOML内) | `512` | `1024` |
| `guidance_scale`(TOML内) | `1.0`-`2.0` | `1.0`-`3.0` |
| `batch_size`(TOML内) | `1`-`4` | `1`-`4` |

</details>

### Dynamic Resolution and Crops (SDXL) / 動的解像度とクロップ(SDXL)

For SDXL slider targets, you can enable dynamic resolution and crops in the TOML file:

```toml
resolution = 1024
dynamic_resolution = true
dynamic_crops = true

[[targets]]
target_class = ""
positive = "high detail"
negative = "low detail"
```

- `dynamic_resolution`: Randomly varies the training resolution around the base value using aspect ratio buckets.
- `dynamic_crops`: Randomizes crop positions in the SDXL size conditioning embeddings.

These options can improve the LoRA's generalization across different aspect ratios.

<details>
<summary>日本語</summary>

SDXL のスライダーターゲットでは、TOML ファイルで動的解像度とクロップを有効にできます。

- `dynamic_resolution`: アスペクト比バケツを使用して、ベース値の周囲で学習解像度をランダムに変化させます。
- `dynamic_crops`: SDXL のサイズ条件付け埋め込みでクロップ位置をランダム化します。

これらのオプションにより、異なるアスペクト比に対する LoRA の汎化性能が向上する場合があります。

</details>

## 6. Using the Trained Model / 学習済みモデルの利用

The trained LoRA file (`.safetensors`) is saved in the `--output_dir` directory. It can be used with GUI tools such as AUTOMATIC1111/stable-diffusion-webui, ComfyUI, etc.

For slider LoRAs, apply positive weights (e.g., `0.5` to `1.5`) to move in the positive direction, and negative weights (e.g., `-0.5` to `-1.5`) to move in the negative direction.

<details>
<summary>日本語</summary>

学習済みの LoRA ファイル(`.safetensors`)は `--output_dir` ディレクトリに保存されます。AUTOMATIC1111/stable-diffusion-webui、ComfyUI 等の GUI ツールで使用できます。

スライダー LoRA の場合、正のウェイト(例:`0.5` から `1.5`)で正方向に、負のウェイト(例:`-0.5` から `-1.5`)で負方向に効果を適用できます。

</details>

@@ -404,7 +404,7 @@ def main(args):
             rating_tag = None
             quality_max_prob = -1
             quality_tag = None
-            character_tags = []
+            img_character_tags = []

             min_thres = min(
                 args.thresh,
@@ -449,7 +449,7 @@ def main(args):
                     tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
                     character_tag_text += caption_separator + tag_name
                     if args.character_tags_first:  # we separate character tags
-                        character_tags.append((tag_name, p))
+                        img_character_tags.append((tag_name, p))
                     else:
                         combined_tags.append((tag_name, p))
                 elif (
@@ -464,9 +464,9 @@ def main(args):

             # sort by probability
             combined_tags.sort(key=lambda x: x[1], reverse=True)
-            if character_tags:
-                character_tags.sort(key=lambda x: x[1], reverse=True)
-                combined_tags = character_tags + combined_tags
+            if img_character_tags:
+                img_character_tags.sort(key=lambda x: x[1], reverse=True)
+                combined_tags = img_character_tags + combined_tags
             combined_tags = [t[0] for t in combined_tags]  # remove probability

             if quality_tag is not None:
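The ordering implemented in the diff above can be illustrated with a standalone sketch (the tag names and probabilities below are made up for the example): character tags are sorted by probability and placed before the remaining tags, which are themselves probability-sorted.

```python
# Hedged sketch of the "character tags first" ordering from the diff above.
img_character_tags = [("character_b", 0.60), ("character_a", 0.95)]
combined_tags = [("outdoors", 0.70), ("smile", 0.90)]

# sort by probability
combined_tags.sort(key=lambda x: x[1], reverse=True)
if img_character_tags:
    img_character_tags.sort(key=lambda x: x[1], reverse=True)
    combined_tags = img_character_tags + combined_tags
combined_tags = [t[0] for t in combined_tags]  # remove probability

print(combined_tags)  # ['character_a', 'character_b', 'smile', 'outdoors']
```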
library/anima_models.py (new file, 1671 lines): diff suppressed because it is too large

library/anima_train_utils.py (new file, 615 lines):
@@ -0,0 +1,615 @@
# Anima Training Utilities

import argparse
import gc
import math
import os
import time
from typing import Optional

import numpy as np
import torch
from accelerate import Accelerator
from tqdm import tqdm
from PIL import Image

from library.device_utils import init_ipex, clean_memory_on_device, synchronize_device
from library import anima_models, anima_utils, train_util, qwen_image_autoencoder_kl

init_ipex()

from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


# Anima-specific training arguments


def add_anima_training_arguments(parser: argparse.ArgumentParser):
    """Add Anima-specific training arguments to the parser."""
    parser.add_argument(
        "--qwen3",
        type=str,
        default=None,
        help="Path to Qwen3-0.6B model (safetensors file or directory)",
    )
    parser.add_argument(
        "--llm_adapter_path",
        type=str,
        default=None,
        help="Path to separate LLM adapter weights. If None, adapter is loaded from DiT file if present",
    )
    parser.add_argument(
        "--llm_adapter_lr",
        type=float,
        default=None,
        help="Learning rate for LLM adapter. None=same as base LR, 0=freeze adapter",
    )
    parser.add_argument(
        "--self_attn_lr",
        type=float,
        default=None,
        help="Learning rate for self-attention layers. None=same as base LR, 0=freeze",
    )
    parser.add_argument(
        "--cross_attn_lr",
        type=float,
        default=None,
        help="Learning rate for cross-attention layers. None=same as base LR, 0=freeze",
    )
    parser.add_argument(
        "--mlp_lr",
        type=float,
        default=None,
        help="Learning rate for MLP layers. None=same as base LR, 0=freeze",
    )
    parser.add_argument(
        "--mod_lr",
        type=float,
        default=None,
        help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze. Note: mod layers are not included in LoRA by default.",
    )
    parser.add_argument(
        "--t5_tokenizer_path",
        type=str,
        default=None,
        help="Path to T5 tokenizer directory. If None, uses default configs/t5_old/",
    )
    parser.add_argument(
        "--qwen3_max_token_length",
        type=int,
        default=512,
        help="Maximum token length for Qwen3 tokenizer (default: 512)",
    )
    parser.add_argument(
        "--t5_max_token_length",
        type=int,
        default=512,
        help="Maximum token length for T5 tokenizer (default: 512)",
    )
    parser.add_argument(
        "--discrete_flow_shift",
        type=float,
        default=1.0,
        help="Timestep distribution shift for rectified flow training (default: 1.0)",
    )
    parser.add_argument(
        "--timestep_sampling",
        type=str,
        default="sigmoid",
        choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
        help="Timestep sampling method (default: sigmoid (logit normal))",
    )
    parser.add_argument(
        "--sigmoid_scale",
        type=float,
        default=1.0,
        help="Scale factor for sigmoid (logit_normal) timestep sampling (default: 1.0)",
    )
    parser.add_argument(
        "--attn_mode",
        choices=["torch", "xformers", "flash", "sageattn", "sdpa"],  # "sdpa" is for backward compatibility
        default=None,
        help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. sageattn does not support training (inference only). This option overrides --xformers or --sdpa."
        " / 使用するAttentionの実装。デフォルトはNone(torch)です。xformersは--split_attnの指定が必要です。sageattnはトレーニングをサポートしていません(推論のみ)。このオプションは--xformersまたは--sdpaを上書きします。",
    )
    parser.add_argument(
        "--split_attn",
        action="store_true",
        help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する",
    )
    parser.add_argument(
        "--vae_chunk_size",
        type=int,
        default=None,
        help="Spatial chunk size for VAE encoding/decoding to reduce memory usage. Must be even number. If not specified, chunking is disabled (official behavior)."
        + " / メモリ使用量を減らすためのVAEエンコード/デコードの空間チャンクサイズ。偶数である必要があります。未指定の場合、チャンク処理は無効になります(公式の動作)。",
    )
    parser.add_argument(
        "--vae_disable_cache",
        action="store_true",
        help="Disable internal VAE caching mechanism to reduce memory usage. Encoding / decoding will also be faster, but this differs from official behavior."
        + " / VAEのメモリ使用量を減らすために内部のキャッシュ機構を無効にします。エンコード/デコードも速くなりますが、公式の動作とは異なります。",
    )


# Loss weighting


def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
    """Compute loss weighting for Anima training.

    Same schemes as SD3 but can add Anima-specific ones if needed in future.
    """
    if weighting_scheme == "sigma_sqrt":
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    elif weighting_scheme == "none" or weighting_scheme is None:
        weighting = torch.ones_like(sigmas)
    else:
        weighting = torch.ones_like(sigmas)
    return weighting


# Parameter groups (6 groups with separate LRs)
def get_anima_param_groups(
    dit,
    base_lr: float,
    self_attn_lr: Optional[float] = None,
    cross_attn_lr: Optional[float] = None,
    mlp_lr: Optional[float] = None,
    mod_lr: Optional[float] = None,
    llm_adapter_lr: Optional[float] = None,
):
    """Create parameter groups for Anima training with separate learning rates.

    Args:
        dit: Anima model
        base_lr: Base learning rate
        self_attn_lr: LR for self-attention layers (None = base_lr, 0 = freeze)
        cross_attn_lr: LR for cross-attention layers
        mlp_lr: LR for MLP layers
        mod_lr: LR for AdaLN modulation layers
        llm_adapter_lr: LR for LLM adapter

    Returns:
        List of parameter group dicts for optimizer
    """
    if self_attn_lr is None:
        self_attn_lr = base_lr
    if cross_attn_lr is None:
        cross_attn_lr = base_lr
    if mlp_lr is None:
        mlp_lr = base_lr
    if mod_lr is None:
        mod_lr = base_lr
    if llm_adapter_lr is None:
        llm_adapter_lr = base_lr

    base_params = []
    self_attn_params = []
    cross_attn_params = []
    mlp_params = []
    mod_params = []
    llm_adapter_params = []

    for name, p in dit.named_parameters():
        # Store original name for debugging
        p.original_name = name

        if "llm_adapter" in name:
            llm_adapter_params.append(p)
        elif ".self_attn" in name:
            self_attn_params.append(p)
        elif ".cross_attn" in name:
            cross_attn_params.append(p)
        elif ".mlp" in name:
            mlp_params.append(p)
        elif ".adaln_modulation" in name:
            mod_params.append(p)
        else:
            base_params.append(p)

    logger.info(f"Parameter groups:")
    logger.info(f"  base_params: {len(base_params)} (lr={base_lr})")
    logger.info(f"  self_attn_params: {len(self_attn_params)} (lr={self_attn_lr})")
    logger.info(f"  cross_attn_params: {len(cross_attn_params)} (lr={cross_attn_lr})")
    logger.info(f"  mlp_params: {len(mlp_params)} (lr={mlp_lr})")
    logger.info(f"  mod_params: {len(mod_params)} (lr={mod_lr})")
    logger.info(f"  llm_adapter_params: {len(llm_adapter_params)} (lr={llm_adapter_lr})")

    param_groups = []
    for lr, params, name in [
        (base_lr, base_params, "base"),
        (self_attn_lr, self_attn_params, "self_attn"),
        (cross_attn_lr, cross_attn_params, "cross_attn"),
        (mlp_lr, mlp_params, "mlp"),
        (mod_lr, mod_params, "mod"),
        (llm_adapter_lr, llm_adapter_params, "llm_adapter"),
    ]:
        if lr == 0:
            for p in params:
                p.requires_grad_(False)
            logger.info(f"  Frozen {name} params ({len(params)} parameters)")
        elif len(params) > 0:
            param_groups.append({"params": params, "lr": lr})

    total_trainable = sum(p.numel() for group in param_groups for p in group["params"] if p.requires_grad)
    logger.info(f"Total trainable parameters: {total_trainable:,}")

    return param_groups


# Save functions
def save_anima_model_on_train_end(
    args: argparse.Namespace,
    save_dtype: torch.dtype,
    epoch: int,
    global_step: int,
    dit: anima_models.Anima,
):
    """Save Anima model at the end of training."""

    def sd_saver(ckpt_file, epoch_no, global_step):
        sai_metadata = train_util.get_sai_model_spec_dataclass(
            None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
        ).to_metadata_dict()
        dit_sd = dit.state_dict()
        # Save with 'net.' prefix for ComfyUI compatibility
        anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype)

    train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)


def save_anima_model_on_epoch_end_or_stepwise(
    args: argparse.Namespace,
    on_epoch_end: bool,
    accelerator: Accelerator,
    save_dtype: torch.dtype,
    epoch: int,
    num_train_epochs: int,
    global_step: int,
    dit: anima_models.Anima,
):
    """Save Anima model at epoch end or specific steps."""

    def sd_saver(ckpt_file, epoch_no, global_step):
        sai_metadata = train_util.get_sai_model_spec_dataclass(
            None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
        ).to_metadata_dict()
        dit_sd = dit.state_dict()
        anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype)

    train_util.save_sd_model_on_epoch_end_or_stepwise_common(
        args,
        on_epoch_end,
        accelerator,
        True,
        True,
        epoch,
        num_train_epochs,
        global_step,
        sd_saver,
        None,
    )


# Sampling (Euler discrete for rectified flow)
def do_sample(
    height: int,
    width: int,
    seed: Optional[int],
    dit: anima_models.Anima,
    crossattn_emb: torch.Tensor,
    steps: int,
    dtype: torch.dtype,
    device: torch.device,
    guidance_scale: float = 1.0,
    flow_shift: float = 3.0,
    neg_crossattn_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Generate a sample using Euler discrete sampling for rectified flow.

    Args:
        height, width: Output image dimensions
        seed: Random seed (None for random)
        dit: Anima model
        crossattn_emb: Cross-attention embeddings (B, N, D)
        steps: Number of sampling steps
        dtype: Compute dtype
        device: Compute device
        guidance_scale: CFG scale (1.0 = no guidance)
        flow_shift: Flow shift parameter for rectified flow
        neg_crossattn_emb: Negative cross-attention embeddings for CFG

    Returns:
        Denoised latents
    """
    # Latent shape: (1, 16, 1, H/8, W/8) for single image
    latent_h = height // 8
    latent_w = width // 8
    latent = torch.zeros(1, 16, 1, latent_h, latent_w, device=device, dtype=dtype)

    # Generate noise
    if seed is not None:
        generator = torch.manual_seed(seed)
    else:
        generator = None
    noise = torch.randn(latent.size(), dtype=torch.float32, generator=generator, device="cpu").to(dtype).to(device)

    # Timestep schedule: linear from 1.0 to 0.0
    sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)
    flow_shift = float(flow_shift)
    if flow_shift != 1.0:
        sigmas = (sigmas * flow_shift) / (1 + (flow_shift - 1) * sigmas)

    # Start from pure noise
    x = noise.clone()

    # Padding mask (zeros = no padding) — resized in prepare_embedded_sequence to match latent dims
    padding_mask = torch.zeros(1, 1, latent_h, latent_w, dtype=dtype, device=device)

    use_cfg = guidance_scale > 1.0 and neg_crossattn_emb is not None

    for i in tqdm(range(steps), desc="Sampling"):
        sigma = sigmas[i]
        t = sigma.unsqueeze(0)  # (1,)

        if use_cfg:
            # CFG: two separate passes to reduce memory usage
            pos_out = dit(x, t, crossattn_emb, padding_mask=padding_mask)
            pos_out = pos_out.float()
            neg_out = dit(x, t, neg_crossattn_emb, padding_mask=padding_mask)
            neg_out = neg_out.float()

            model_output = neg_out + guidance_scale * (pos_out - neg_out)
        else:
            model_output = dit(x, t, crossattn_emb, padding_mask=padding_mask)
            model_output = model_output.float()

        # Euler step: x_{t-1} = x_t - (sigma_t - sigma_{t-1}) * model_output
        dt = sigmas[i + 1] - sigma
        x = x + model_output * dt
        x = x.to(dtype)

    return x


def sample_images(
    accelerator: Accelerator,
    args: argparse.Namespace,
    epoch,
    steps,
    dit: anima_models.Anima,
    vae,
    text_encoder,
    tokenize_strategy,
    text_encoding_strategy,
    sample_prompts_te_outputs=None,
    prompt_replacement=None,
):
    """Generate sample images during training.

    This is a simplified sampler for Anima - it generates images using the current model state.
    """
    if steps == 0:
        if not args.sample_at_first:
            return
    else:
        if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
            return
        if args.sample_every_n_epochs is not None:
            if epoch is None or epoch % args.sample_every_n_epochs != 0:
                return
        else:
            if steps % args.sample_every_n_steps != 0 or epoch is not None:
                return

    logger.info(f"Generating sample images at step {steps}")
    if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
        logger.error(f"No prompt file: {args.sample_prompts}")
        return

    # Unwrap models
    dit = accelerator.unwrap_model(dit)
    if text_encoder is not None:
        text_encoder = accelerator.unwrap_model(text_encoder)

    dit.switch_block_swap_for_inference()

    prompts = train_util.load_prompts(args.sample_prompts)
    save_dir = os.path.join(args.output_dir, "sample")
    os.makedirs(save_dir, exist_ok=True)

    # Save RNG state
    rng_state = torch.get_rng_state()
    cuda_rng_state = None
    try:
        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    except Exception:
        pass

    with torch.no_grad(), accelerator.autocast():
        for prompt_dict in prompts:
            dit.prepare_block_swap_before_forward()
            _sample_image_inference(
                accelerator,
                args,
                dit,
                text_encoder,
                vae,
                tokenize_strategy,
                text_encoding_strategy,
                save_dir,
                prompt_dict,
                epoch,
                steps,
                sample_prompts_te_outputs,
                prompt_replacement,
            )

    # Restore RNG state
    torch.set_rng_state(rng_state)
    if cuda_rng_state is not None:
        torch.cuda.set_rng_state(cuda_rng_state)

    dit.switch_block_swap_for_training()
    clean_memory_on_device(accelerator.device)


def _sample_image_inference(
    accelerator,
    args,
    dit,
    text_encoder,
    vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage,
    tokenize_strategy,
    text_encoding_strategy,
    save_dir,
    prompt_dict,
    epoch,
    steps,
    sample_prompts_te_outputs,
    prompt_replacement,
):
    """Generate a single sample image."""
    prompt = prompt_dict.get("prompt", "")
    negative_prompt = prompt_dict.get("negative_prompt", "")
    sample_steps = prompt_dict.get("sample_steps", 30)
    width = prompt_dict.get("width", 512)
    height = prompt_dict.get("height", 512)
    scale = prompt_dict.get("scale", 7.5)
    seed = prompt_dict.get("seed")
    flow_shift = prompt_dict.get("flow_shift", 3.0)

    if prompt_replacement is not None:
        prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
        if negative_prompt:
            negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # seed all CUDA devices for multi-GPU

    height = max(64, height - height % 16)
    width = max(64, width - width % 16)

    logger.info(
        f"  prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}, flow_shift: {flow_shift}, seed: {seed}"
    )

    # Encode prompt
    def encode_prompt(prpt):
        if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
            return sample_prompts_te_outputs[prpt]
        if text_encoder is not None:
            tokens = tokenize_strategy.tokenize(prpt)
            encoded = text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
            return encoded
        return None

    encoded = encode_prompt(prompt)
    if encoded is None:
        logger.warning("Cannot encode prompt, skipping sample")
        return

    prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = encoded

    # Convert to tensors if numpy
    if isinstance(prompt_embeds, np.ndarray):
        prompt_embeds = torch.from_numpy(prompt_embeds).unsqueeze(0)
        attn_mask = torch.from_numpy(attn_mask).unsqueeze(0)
        t5_input_ids = torch.from_numpy(t5_input_ids).unsqueeze(0)
        t5_attn_mask = torch.from_numpy(t5_attn_mask).unsqueeze(0)

    prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.dtype)
    attn_mask = attn_mask.to(accelerator.device)
    t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
    t5_attn_mask = t5_attn_mask.to(accelerator.device)

    # Process through LLM adapter if available
    if dit.use_llm_adapter:
        crossattn_emb = dit.llm_adapter(
            source_hidden_states=prompt_embeds,
            target_input_ids=t5_input_ids,
            target_attention_mask=t5_attn_mask,
            source_attention_mask=attn_mask,
        )
        crossattn_emb[~t5_attn_mask.bool()] = 0
    else:
        crossattn_emb = prompt_embeds

    # Encode negative prompt for CFG
    neg_crossattn_emb = None
    if scale > 1.0 and negative_prompt is not None:
        neg_encoded = encode_prompt(negative_prompt)
        if neg_encoded is not None:
            neg_pe, neg_am, neg_t5_ids, neg_t5_am = neg_encoded
            if isinstance(neg_pe, np.ndarray):
                neg_pe = torch.from_numpy(neg_pe).unsqueeze(0)
                neg_am = torch.from_numpy(neg_am).unsqueeze(0)
                neg_t5_ids = torch.from_numpy(neg_t5_ids).unsqueeze(0)
                neg_t5_am = torch.from_numpy(neg_t5_am).unsqueeze(0)

            neg_pe = neg_pe.to(accelerator.device, dtype=dit.dtype)
            neg_am = neg_am.to(accelerator.device)
            neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
            neg_t5_am = neg_t5_am.to(accelerator.device)

            if dit.use_llm_adapter:
                neg_crossattn_emb = dit.llm_adapter(
                    source_hidden_states=neg_pe,
                    target_input_ids=neg_t5_ids,
                    target_attention_mask=neg_t5_am,
                    source_attention_mask=neg_am,
                )
                neg_crossattn_emb[~neg_t5_am.bool()] = 0
            else:
                neg_crossattn_emb = neg_pe

    # Generate sample
    clean_memory_on_device(accelerator.device)
    latents = do_sample(
        height, width, seed, dit, crossattn_emb, sample_steps, dit.dtype, accelerator.device, scale, flow_shift, neg_crossattn_emb
    )

    # Decode latents
    gc.collect()
    synchronize_device(accelerator.device)
    clean_memory_on_device(accelerator.device)
    org_vae_device = vae.device
    vae.to(accelerator.device)
    decoded = vae.decode_to_pixels(latents)
    vae.to(org_vae_device)
    clean_memory_on_device(accelerator.device)

    # Convert to image
    image = decoded.float()
    image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
    # Remove temporal dim if present
    if image.ndim == 4:
        image = image[:, 0, :, :]
    decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
    decoded_np = decoded_np.astype(np.uint8)

    image = Image.fromarray(decoded_np)

    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
    num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
    seed_suffix = "" if seed is None else f"_{seed}"
    i = prompt_dict.get("enum", 0)
    img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
    image.save(os.path.join(save_dir, img_filename))

    # Log to wandb if enabled
    if "wandb" in [tracker.name for tracker in accelerator.trackers]:
        wandb_tracker = accelerator.get_tracker("wandb")
        import wandb

        wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)
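The Euler sampler above warps its linear sigma schedule by the flow shift before stepping. A minimal standalone sketch of that schedule (plain Python instead of torch; the function name is illustrative, but the warp formula is the one used in `do_sample`):

```python
def shifted_sigmas(steps, flow_shift=3.0):
    """Sigma schedule for the Euler sampler: linear from 1.0 to 0.0,
    warped by sigma' = (shift * sigma) / (1 + (shift - 1) * sigma)."""
    sigmas = [1.0 - i / steps for i in range(steps + 1)]
    if flow_shift != 1.0:
        sigmas = [flow_shift * s / (1 + (flow_shift - 1) * s) for s in sigmas]
    return sigmas

sched = shifted_sigmas(4, flow_shift=3.0)
# endpoints 1.0 and 0.0 are preserved; for shift > 1 the interior
# sigmas are pushed toward 1.0, spending more steps at high noise
```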
library/anima_utils.py (new file, 309 lines):
@@ -0,0 +1,309 @@
|
||||
# Anima model loading/saving utilities
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from accelerate.utils import set_module_tensor_to_device # kept for potential future use
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from library.fp8_optimization_utils import apply_fp8_monkey_patch
|
||||
from library.lora_utils import load_safetensors_with_lora_and_fp8
|
||||
from library import anima_models
|
||||
from library.safetensors_utils import WeightTransformHooks
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Original Anima high-precision keys. Kept for reference, but not used currently.
|
||||
# # Keys that should stay in high precision (float32/bfloat16, not quantized)
|
||||
# KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"]
|
||||
|
||||
|
||||
FP8_OPTIMIZATION_TARGET_KEYS = ["blocks", ""]
|
||||
# ".embed." excludes Embedding in LLMAdapter
|
||||
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_embedder", "norm", "adaln", "final_layer", ".embed."]
|
||||
|
||||
|
||||
def load_anima_model(
|
||||
device: Union[str, torch.device],
|
||||
dit_path: str,
|
||||
attn_mode: str,
|
||||
split_attn: bool,
|
||||
loading_device: Union[str, torch.device],
|
||||
dit_weight_dtype: Optional[torch.dtype],
|
||||
fp8_scaled: bool = False,
|
||||
lora_weights_list: Optional[List[Dict[str, torch.Tensor]]] = None,
|
||||
lora_multipliers: Optional[list[float]] = None,
|
||||
) -> anima_models.Anima:
|
||||
"""
|
||||
Load Anima model from the specified checkpoint.
|
||||
|
||||
Args:
|
||||
device (Union[str, torch.device]): Device for optimization or merging
|
||||
dit_path (str): Path to the DiT model checkpoint.
|
||||
attn_mode (str): Attention mode to use, e.g., "torch", "flash", etc.
|
||||
split_attn (bool): Whether to use split attention.
|
||||
loading_device (Union[str, torch.device]): Device to load the model weights on.
|
||||
dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights.
|
||||
If None, it will be loaded as is (same as the state_dict) or scaled for fp8. if not None, model weights will be casted to this dtype.
|
||||
fp8_scaled (bool): Whether to use fp8 scaling for the model weights.
|
||||
lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): LoRA weights to apply, if any.
|
||||
lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
|
||||
"""
|
||||
# dit_weight_dtype is None for fp8_scaled
|
||||
assert (
|
||||
not fp8_scaled and dit_weight_dtype is not None
|
||||
) or dit_weight_dtype is None, "dit_weight_dtype should be None when fp8_scaled is True"
|
||||
|
||||
device = torch.device(device)
|
||||
loading_device = torch.device(loading_device)
|
||||
|
||||
# We currently support fixed DiT config for Anima models
|
||||
dit_config = {
|
||||
"max_img_h": 512,
|
||||
"max_img_w": 512,
|
||||
"max_frames": 128,
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"patch_spatial": 2,
|
||||
"patch_temporal": 1,
|
||||
"model_channels": 2048,
|
||||
"concat_padding_mask": True,
|
||||
"crossattn_emb_channels": 1024,
|
||||
"pos_emb_cls": "rope3d",
|
||||
"pos_emb_learnable": True,
|
||||
"pos_emb_interpolation": "crop",
|
||||
"min_fps": 1,
|
||||
"max_fps": 30,
|
||||
"use_adaln_lora": True,
|
||||
"adaln_lora_dim": 256,
|
||||
"num_blocks": 28,
|
||||
"num_heads": 16,
|
||||
"extra_per_block_abs_pos_emb": False,
|
||||
"rope_h_extrapolation_ratio": 4.0,
|
||||
"rope_w_extrapolation_ratio": 4.0,
|
||||
"rope_t_extrapolation_ratio": 1.0,
|
||||
"extra_h_extrapolation_ratio": 1.0,
|
||||
"extra_w_extrapolation_ratio": 1.0,
|
||||
"extra_t_extrapolation_ratio": 1.0,
|
||||
"rope_enable_fps_modulation": False,
|
||||
"use_llm_adapter": True,
|
||||
"attn_mode": attn_mode,
|
||||
"split_attn": split_attn,
|
||||
}
|
||||
with init_empty_weights():
|
||||
model = anima_models.Anima(**dit_config)
|
||||
if dit_weight_dtype is not None:
|
||||
model.to(dit_weight_dtype)
|
||||
|
||||
# load model weights with dynamic fp8 optimization and LoRA merging if needed
|
||||
logger.info(f"Loading DiT model from {dit_path}, device={loading_device}")
|
||||
rename_hooks = WeightTransformHooks(rename_hook=lambda k: k[len("net.") :] if k.startswith("net.") else k)
|
||||
sd = load_safetensors_with_lora_and_fp8(
|
||||
model_files=dit_path,
|
||||
lora_weights_list=lora_weights_list,
|
||||
lora_multipliers=lora_multipliers,
|
||||
fp8_optimization=fp8_scaled,
|
||||
calc_device=device,
|
||||
move_to_device=(loading_device == device),
|
||||
dit_weight_dtype=dit_weight_dtype,
|
||||
target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
|
||||
exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
|
||||
weight_transform_hooks=rename_hooks,
|
||||
)
|
||||
|
||||
if fp8_scaled:
|
||||
apply_fp8_monkey_patch(model, sd, use_scaled_mm=False)
|
||||
|
||||
if loading_device.type != "cpu":
|
||||
# make sure all the model weights are on the loading_device
|
||||
logger.info(f"Moving weights to {loading_device}")
|
||||
for key in sd.keys():
|
||||
sd[key] = sd[key].to(loading_device)
|
||||
|
||||
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
|
||||
if missing:
|
||||
# Filter out expected missing buffers (initialized in __init__, not saved in checkpoint)
|
||||
unexpected_missing = [
|
||||
k
|
||||
for k in missing
|
||||
if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq"))
|
||||
]
|
||||
if unexpected_missing:
|
||||
# Raise error to avoid silent failures
|
||||
raise RuntimeError(
|
||||
f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}"
|
||||
)
|
||||
missing = {} # all missing keys were expected
|
||||
if unexpected:
|
||||
# Raise error to avoid silent failures
|
||||
raise RuntimeError(f"Unexpected keys in checkpoint: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
|
||||
logger.info(f"Loaded DiT model from {dit_path}, unexpected missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_qwen3_tokenizer(qwen3_path: str):
|
||||
"""Load Qwen3 tokenizer only (without the text encoder model).
|
||||
|
||||
Args:
|
||||
qwen3_path: Path to either a directory with model files or a safetensors file.
|
||||
If a directory, loads tokenizer from it directly.
|
||||
If a file, uses configs/qwen3_06b/ for tokenizer config.
|
||||
Returns:
|
||||
tokenizer
|
||||
"""
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
if os.path.isdir(qwen3_path):
|
||||
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
|
||||
else:
|
||||
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
|
||||
if not os.path.exists(config_dir):
|
||||
raise FileNotFoundError(
|
||||
f"Qwen3 config directory not found at {config_dir}. "
|
||||
"Expected configs/qwen3_06b/ with config.json, tokenizer.json, etc. "
|
||||
"You can download these from the Qwen3-0.6B HuggingFace repository."
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(config_dir, local_files_only=True)
|
||||
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def load_qwen3_text_encoder(
|
||||
qwen3_path: str,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
device: str = "cpu",
|
||||
    lora_weights: Optional[List[Dict[str, torch.Tensor]]] = None,
    lora_multipliers: Optional[List[float]] = None,
):
    """Load Qwen3-0.6B text encoder.

    Args:
        qwen3_path: Path to either a directory with model files or a safetensors file
        dtype: Model dtype
        device: Device to load to
        lora_weights: Optional list of LoRA state dicts to merge into the checkpoint
        lora_multipliers: Optional list of multipliers for the LoRA weights

    Returns:
        (text_encoder_model, tokenizer)
    """
    import transformers
    from transformers import AutoTokenizer

    logger.info(f"Loading Qwen3 text encoder from {qwen3_path}")

    if os.path.isdir(qwen3_path):
        # Directory with full model
        tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
        model = transformers.AutoModelForCausalLM.from_pretrained(qwen3_path, torch_dtype=dtype, local_files_only=True).model
    else:
        # Single safetensors file - use configs/qwen3_06b/ for config
        config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
        if not os.path.exists(config_dir):
            raise FileNotFoundError(
                f"Qwen3 config directory not found at {config_dir}. "
                "Expected configs/qwen3_06b/ with config.json, tokenizer.json, etc. "
                "You can download these from the Qwen3-0.6B HuggingFace repository."
            )

        tokenizer = AutoTokenizer.from_pretrained(config_dir, local_files_only=True)
        qwen3_config = transformers.Qwen3Config.from_pretrained(config_dir, local_files_only=True)
        model = transformers.Qwen3ForCausalLM(qwen3_config).model

        # Load weights
        if qwen3_path.endswith(".safetensors"):
            if lora_weights is None:
                state_dict = load_file(qwen3_path, device="cpu")
            else:
                state_dict = load_safetensors_with_lora_and_fp8(
                    model_files=qwen3_path,
                    lora_weights_list=lora_weights,
                    lora_multipliers=lora_multipliers,
                    fp8_optimization=False,
                    calc_device=device,
                    move_to_device=True,
                    dit_weight_dtype=None,
                )
        else:
            assert lora_weights is None, "LoRA weights merging is only supported for safetensors checkpoints"
            state_dict = torch.load(qwen3_path, map_location="cpu", weights_only=True)

        # Remove 'model.' prefix if present
        new_sd = {}
        for k, v in state_dict.items():
            if k.startswith("model."):
                new_sd[k[len("model.") :]] = v
            else:
                new_sd[k] = v

        info = model.load_state_dict(new_sd, strict=False)
        logger.info(f"Loaded Qwen3 state dict: {info}")

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model.config.use_cache = False
    model = model.requires_grad_(False).to(device, dtype=dtype)

    logger.info(f"Loaded Qwen3 text encoder. Parameters: {sum(p.numel() for p in model.parameters()):,}")
    return model, tokenizer
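The `model.` prefix handling in the loader above is plain dict rewriting. A minimal standalone sketch, with plain values standing in for tensors (the keys below are hypothetical examples):

```python
# Strip a "model." prefix from state-dict keys, leaving other keys untouched.
def strip_model_prefix(state_dict: dict) -> dict:
    new_sd = {}
    for k, v in state_dict.items():
        if k.startswith("model."):
            new_sd[k[len("model."):]] = v  # drop the prefix
        else:
            new_sd[k] = v
    return new_sd

sd = {"model.embed_tokens.weight": 1, "lm_head.weight": 2}
print(strip_model_prefix(sd))  # {'embed_tokens.weight': 1, 'lm_head.weight': 2}
```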


def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None):
    """Load T5 tokenizer for LLM Adapter target tokens.

    Args:
        t5_tokenizer_path: Optional path to T5 tokenizer directory. If None, uses default configs.
    """
    from transformers import T5TokenizerFast

    if t5_tokenizer_path is not None:
        return T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)

    # Use bundled config
    config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "t5_old")
    if os.path.exists(config_dir):
        return T5TokenizerFast(
            vocab_file=os.path.join(config_dir, "spiece.model"),
            tokenizer_file=os.path.join(config_dir, "tokenizer.json"),
        )

    raise FileNotFoundError(
        f"T5 tokenizer config directory not found at {config_dir}. "
        "Expected configs/t5_old/ with spiece.model and tokenizer.json. "
        "You can download these from the google/t5-v1_1-xxl HuggingFace repository."
    )


def save_anima_model(
    save_path: str, dit_state_dict: Dict[str, torch.Tensor], metadata: Dict[str, any], dtype: Optional[torch.dtype] = None
):
    """Save Anima DiT model with 'net.' prefix for ComfyUI compatibility.

    Args:
        save_path: Output path (.safetensors)
        dit_state_dict: State dict from dit.state_dict()
        metadata: Metadata dict to include in the safetensors file
        dtype: Optional dtype to cast to before saving
    """
    prefixed_sd = {}
    for k, v in dit_state_dict.items():
        if dtype is not None:
            # v = v.to(dtype)
            v = v.detach().clone().to("cpu").to(dtype)  # Reduce GPU memory usage during save
        prefixed_sd["net." + k] = v.contiguous()

    if metadata is None:
        metadata = {}
    metadata["format"] = "pt"  # For compatibility with the official .safetensors file

    save_file(prefixed_sd, save_path, metadata=metadata)  # safetensors.save_file consumes a lot of memory, but Anima is small enough
    logger.info(f"Saved Anima model to {save_path}")
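The `net.` prefixing above is a one-line dict comprehension in isolation; a sketch with plain values in place of tensors (keys are hypothetical):

```python
# Add the "net." prefix save_anima_model applies for ComfyUI compatibility.
def add_net_prefix(state_dict: dict) -> dict:
    return {"net." + k: v for k, v in state_dict.items()}

sd = {"blocks.0.attn.qkv.weight": 1, "final_layer.linear.weight": 2}
print(add_net_prefix(sd))
# {'net.blocks.0.attn.qkv.weight': 1, 'net.final_layer.linear.weight': 2}
```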
@@ -37,6 +37,14 @@ class AttentionParams:
    cu_seqlens: Optional[torch.Tensor] = None
    max_seqlen: Optional[int] = None

    @property
    def supports_fp32(self) -> bool:
        return self.attn_mode not in ["flash"]

    @property
    def requires_same_dtype(self) -> bool:
        return self.attn_mode in ["xformers"]

    @staticmethod
    def create_attention_params(attn_mode: Optional[str], split_attn: bool) -> "AttentionParams":
        return AttentionParams(attn_mode, split_attn)
@@ -95,7 +103,7 @@ def attention(
        qkv_or_q: Query tensor [B, L, H, D], or list of such tensors.
        k: Key tensor [B, L, H, D].
        v: Value tensor [B, L, H, D].
        attn_params: Attention parameters including mask and sequence lengths.
        drop_rate: Attention dropout rate.

    Returns:

@@ -108,6 +108,7 @@ class BaseDatasetParams:
    validation_seed: Optional[int] = None
    validation_split: float = 0.0
    resize_interpolation: Optional[str] = None
    skip_image_resolution: Optional[Tuple[int, int]] = None

@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
@@ -118,7 +119,7 @@ class DreamBoothDatasetParams(BaseDatasetParams):
    bucket_reso_steps: int = 64
    bucket_no_upscale: bool = False
    prior_loss_weight: float = 1.0


@dataclass
class FineTuningDatasetParams(BaseDatasetParams):
    batch_size: int = 1
@@ -244,6 +245,7 @@ class ConfigSanitizer:
        "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
        "network_multiplier": float,
        "resize_interpolation": str,
        "skip_image_resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
    }

    # options handled by argparse but not handled by user config
@@ -256,6 +258,7 @@ class ConfigSanitizer:
    ARGPARSE_NULLABLE_OPTNAMES = [
        "face_crop_aug_range",
        "resolution",
        "skip_image_resolution",
    ]
    # prepare map because option name may differ among argparse and user config
    ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
@@ -528,6 +531,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
      [{dataset_type} {i}]
      batch_size: {dataset.batch_size}
      resolution: {(dataset.width, dataset.height)}
      skip_image_resolution: {dataset.skip_image_resolution}
      resize_interpolation: {dataset.resize_interpolation}
      enable_bucket: {dataset.enable_bucket}
    """)

@@ -195,6 +195,9 @@ class ModelOffloader(Offloader):
            self.remove_handles.append(handle)

    def set_forward_only(self, forward_only: bool):
        # switching must wait for all pending transfers
        for block_idx in list(self.futures.keys()):
            self._wait_blocks_move(block_idx)
        self.forward_only = forward_only

    def __del__(self):
@@ -237,6 +240,10 @@
        if self.debug:
            print(f"Prepare block devices before forward")

        # wait for all pending transfers
        for block_idx in list(self.futures.keys()):
            self._wait_blocks_move(block_idx)

        for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
            b.to(self.device)
            weighs_to_device(b, self.device)  # make sure weights are on device

@@ -96,7 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
    deepspeed_plugin.deepspeed_config["train_batch_size"] = (
        args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
    )

    deepspeed_plugin.set_mixed_precision(args.mixed_precision)
    if args.mixed_precision.lower() == "fp16":
        deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0  # preventing overflow.
@@ -125,18 +125,18 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
    class DeepSpeedWrapper(torch.nn.Module):
        def __init__(self, **kw_models) -> None:
            super().__init__()

            self.models = torch.nn.ModuleDict()

            wrap_model_forward_with_torch_autocast = args.mixed_precision != "no"

            for key, model in kw_models.items():
                if isinstance(model, list):
                    model = torch.nn.ModuleList(model)

                if wrap_model_forward_with_torch_autocast:
                    model = self.__wrap_model_forward_with_torch_autocast(model)

                assert isinstance(
                    model, torch.nn.Module
                ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
@@ -151,7 +151,7 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
            return model

        def __wrap_model_forward_with_torch_autocast(self, model):
            assert hasattr(model, "forward"), "model must have a forward method."

            forward_fn = model.forward
@@ -161,20 +161,19 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
                    device_type = model.device.type
                except AttributeError:
                    logger.warning(
                        "[DeepSpeed] model.device is not available. Using get_preferred_device() "
                        "to determine the device_type for torch.autocast()."
                    )
                    device_type = get_preferred_device().type

                with torch.autocast(device_type=device_type):
                    return forward_fn(*args, **kwargs)

            model.forward = forward
            return model

        def get_models(self):
            return self.models

    ds_model = DeepSpeedWrapper(**models)
    return ds_model

@@ -471,7 +471,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def get_noisy_model_input_and_timesteps(
    args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    bsz, h, w = latents.shape[0], latents.shape[-2], latents.shape[-1]
    assert bsz > 0, "Batch size not large enough"
    num_timesteps = noise_scheduler.config.num_train_timesteps
    if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
@@ -512,7 +512,7 @@ def get_noisy_model_input_and_timesteps(
    sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)

    # Broadcast sigmas to latent shape: 4D for images, 5D for video latents
    sigmas = sigmas.view(-1, 1, 1, 1) if latents.ndim == 4 else sigmas.view(-1, 1, 1, 1, 1)

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)

@@ -9,7 +9,7 @@ import logging
from tqdm import tqdm

from library.device_utils import clean_memory_on_device
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks
from library.utils import setup_logging

setup_logging()
@@ -220,6 +220,8 @@ def quantize_weight(
        tensor_max = torch.max(torch.abs(tensor).view(-1))
        scale = tensor_max / max_value

    # print(f"Optimizing {key} with scale: {scale}")

    # numerical safety
    scale = torch.clamp(scale, min=1e-8)
    scale = scale.to(torch.float32)  # ensure scale is in float32 for division
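The per-tensor branch above divides the tensor's absolute maximum by the FP8 format's maximum representable value, then clamps the result so later division by the scale stays finite. A plain-Python sketch (448.0 is the FP8 E4M3 maximum; the input values are toy examples):

```python
# Per-tensor quantization scale with the numerical-safety clamp.
def compute_scale(values, max_value=448.0, min_scale=1e-8):
    tensor_max = max(abs(v) for v in values)
    scale = tensor_max / max_value
    return max(scale, min_scale)  # clamp so division by scale stays finite

print(compute_scale([0.5, -2.24, 1.0]))  # 2.24 / 448 = 0.005 (up to float rounding)
print(compute_scale([0.0, 0.0]))         # all-zero tensor clamps to 1e-8
```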
@@ -245,6 +247,8 @@ def load_safetensors_with_fp8_optimization(
    weight_hook=None,
    quantization_mode: str = "block",
    block_size: Optional[int] = 64,
    disable_numpy_memmap: bool = False,
    weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict:
    """
    Load weight tensors from safetensors files and merge LoRA weights into the state dict with explicit FP8 optimization.
@@ -260,6 +264,8 @@
        weight_hook (callable, optional): Function to apply to each weight tensor before optimization
        quantization_mode (str): Quantization mode, "tensor", "channel", or "block"
        block_size (int, optional): Block size for block-wise quantization (used if quantization_mode is "block")
        disable_numpy_memmap (bool): Disable numpy memmap when loading safetensors
        weight_transform_hooks (WeightTransformHooks, optional): Hooks for weight transformation during loading

    Returns:
        dict: FP8 optimized state dict
@@ -288,7 +294,9 @@
    # Process each file
    state_dict = {}
    for model_file in model_files:
        with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
            f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f

            keys = f.keys()
            for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"):
                value = f.get_tensor(key)
@@ -311,6 +319,11 @@
                value = value.to(calc_device)

                original_dtype = value.dtype
                if original_dtype.itemsize == 1:
                    raise ValueError(
                        f"Layer {key} is already in {original_dtype} format. `--fp8_scaled` optimization should not be applied. Please use fp16/bf16/float32 model weights."
                        + f" / レイヤー {key} は既に{original_dtype}形式です。`--fp8_scaled` 最適化は適用できません。FP16/BF16/Float32のモデル重みを使用してください。"
                    )
                quantized_weight, scale_tensor = quantize_weight(
                    key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
                )
@@ -387,7 +400,7 @@ def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=
        else:
            o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight)

        o = o.reshape(original_shape[0], original_shape[1], -1) if len(original_shape) == 3 else o.reshape(original_shape[0], -1)
        return o.to(input_dtype)

    else:

library/leco_train_util.py (new file, 522 lines)
@@ -0,0 +1,522 @@
import argparse
import json
import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import torch
import toml
from torch.utils.checkpoint import checkpoint

from library import train_util

import logging

logger = logging.getLogger(__name__)


def build_network_kwargs(args: argparse.Namespace) -> Dict[str, str]:
    kwargs = {}
    if args.network_args:
        for net_arg in args.network_args:
            key, value = net_arg.split("=", 1)
            kwargs[key] = value
    if "dropout" not in kwargs:
        kwargs["dropout"] = args.network_dropout
    return kwargs
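`build_network_kwargs` splits each `key=value` string on the first `=` only, so values may themselves contain `=`. A standalone sketch (the namespace fields below are hypothetical example values):

```python
# Parse "key=value" network args into a kwargs dict, falling back to
# args.network_dropout when no explicit dropout is given.
from types import SimpleNamespace

def build_network_kwargs(args):
    kwargs = {}
    if args.network_args:
        for net_arg in args.network_args:
            key, value = net_arg.split("=", 1)  # split only on the first "="
            kwargs[key] = value
    if "dropout" not in kwargs:
        kwargs["dropout"] = args.network_dropout
    return kwargs

args = SimpleNamespace(network_args=["conv_dim=4", "algo=lora"], network_dropout=0.1)
print(build_network_kwargs(args))  # {'conv_dim': '4', 'algo': 'lora', 'dropout': 0.1}
```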


def get_save_extension(args: argparse.Namespace) -> str:
    if args.save_model_as == "ckpt":
        return ".ckpt"
    if args.save_model_as == "pt":
        return ".pt"
    return ".safetensors"


def save_weights(
    accelerator,
    network,
    args: argparse.Namespace,
    save_dtype,
    prompt_settings,
    global_step: int,
    last: bool = False,
    extra_metadata: Optional[Dict[str, str]] = None,
) -> None:
    os.makedirs(args.output_dir, exist_ok=True)
    ext = get_save_extension(args)
    ckpt_name = train_util.get_last_ckpt_name(args, ext) if last else train_util.get_step_ckpt_name(args, ext, global_step)
    ckpt_file = os.path.join(args.output_dir, ckpt_name)

    metadata = None
    if not args.no_metadata:
        metadata = {
            "ss_network_module": args.network_module,
            "ss_network_dim": str(args.network_dim),
            "ss_network_alpha": str(args.network_alpha),
            "ss_leco_prompt_count": str(len(prompt_settings)),
            "ss_leco_prompts_file": os.path.basename(args.prompts_file),
        }
        if extra_metadata:
            metadata.update(extra_metadata)
        if args.training_comment:
            metadata["ss_training_comment"] = args.training_comment
        metadata["ss_leco_preview"] = json.dumps(
            [
                {
                    "target": p.target,
                    "positive": p.positive,
                    "unconditional": p.unconditional,
                    "neutral": p.neutral,
                    "action": p.action,
                    "multiplier": p.multiplier,
                    "weight": p.weight,
                }
                for p in prompt_settings[:16]
            ],
            ensure_ascii=False,
        )

    unwrapped = accelerator.unwrap_model(network)
    unwrapped.save_weights(ckpt_file, save_dtype, metadata)
    logger.info(f"saved model to: {ckpt_file}")


ResolutionValue = Union[int, Tuple[int, int]]


@dataclass
class PromptEmbedsXL:
    text_embeds: torch.Tensor
    pooled_embeds: torch.Tensor


class PromptEmbedsCache:
    def __init__(self):
        self.prompts: dict[str, Any] = {}

    def __setitem__(self, name: str, value: Any) -> None:
        self.prompts[name] = value

    def __getitem__(self, name: str) -> Any:
        return self.prompts[name]


@dataclass
class PromptSettings:
    target: str
    positive: Optional[str] = None
    unconditional: str = ""
    neutral: Optional[str] = None
    action: str = "erase"
    guidance_scale: float = 1.0
    resolution: ResolutionValue = 512
    dynamic_resolution: bool = False
    batch_size: int = 1
    dynamic_crops: bool = False
    multiplier: float = 1.0
    weight: float = 1.0

    def __post_init__(self):
        if self.positive is None:
            self.positive = self.target
        if self.neutral is None:
            self.neutral = self.unconditional
        if self.action not in ("erase", "enhance"):
            raise ValueError(f"Invalid action: {self.action}")

        self.guidance_scale = float(self.guidance_scale)
        self.batch_size = int(self.batch_size)
        self.multiplier = float(self.multiplier)
        self.weight = float(self.weight)
        self.dynamic_resolution = bool(self.dynamic_resolution)
        self.dynamic_crops = bool(self.dynamic_crops)
        self.resolution = normalize_resolution(self.resolution)

    def get_resolution(self) -> Tuple[int, int]:
        if isinstance(self.resolution, tuple):
            return self.resolution
        return (self.resolution, self.resolution)

    def build_target(self, positive_latents, neutral_latents, unconditional_latents):
        offset = self.guidance_scale * (positive_latents - unconditional_latents)
        if self.action == "erase":
            return neutral_latents - offset
        return neutral_latents + offset
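`build_target` shifts the neutral prediction away from (erase) or toward (enhance) the guided positive direction. The arithmetic is easy to check with scalar stand-ins for the latent tensors:

```python
# The erase/enhance target arithmetic from PromptSettings.build_target,
# with floats standing in for latent tensors.
def build_target(action, guidance_scale, positive, neutral, unconditional):
    offset = guidance_scale * (positive - unconditional)
    if action == "erase":
        return neutral - offset  # push away from the positive concept
    return neutral + offset      # pull toward the positive concept

print(build_target("erase", 1.0, positive=2.0, neutral=0.5, unconditional=1.0))    # -0.5
print(build_target("enhance", 1.0, positive=2.0, neutral=0.5, unconditional=1.0))  # 1.5
```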


def normalize_resolution(value: Any) -> ResolutionValue:
    if isinstance(value, tuple):
        if len(value) != 2:
            raise ValueError(f"resolution tuple must have 2 items: {value}")
        return (int(value[0]), int(value[1]))
    if isinstance(value, list):
        if len(value) == 2 and all(isinstance(v, (int, float)) for v in value):
            return (int(value[0]), int(value[1]))
        raise ValueError(f"resolution list must have 2 numeric items: {value}")
    return int(value)
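`normalize_resolution` collapses the three accepted input shapes (int, 2-tuple, 2-item numeric list) into either an int or an `(int, int)` tuple. A self-contained restatement for a quick check:

```python
# Normalize a resolution value: int stays int, 2-tuples and 2-item numeric
# lists become (int, int); anything else raises.
def normalize_resolution(value):
    if isinstance(value, tuple):
        if len(value) != 2:
            raise ValueError(f"resolution tuple must have 2 items: {value}")
        return (int(value[0]), int(value[1]))
    if isinstance(value, list):
        if len(value) == 2 and all(isinstance(v, (int, float)) for v in value):
            return (int(value[0]), int(value[1]))
        raise ValueError(f"resolution list must have 2 numeric items: {value}")
    return int(value)

print(normalize_resolution(512))           # 512
print(normalize_resolution([768, 512.0]))  # (768, 512)
```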


def _read_non_empty_lines(path: Union[str, Path]) -> List[str]:
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f.readlines() if line.strip()]


def _recognized_prompt_keys() -> set[str]:
    return {
        "target",
        "positive",
        "unconditional",
        "neutral",
        "action",
        "guidance_scale",
        "resolution",
        "dynamic_resolution",
        "batch_size",
        "dynamic_crops",
        "multiplier",
        "weight",
    }


def _recognized_slider_keys() -> set[str]:
    return {
        "target_class",
        "positive",
        "negative",
        "neutral",
        "guidance_scale",
        "resolution",
        "resolutions",
        "dynamic_resolution",
        "batch_size",
        "dynamic_crops",
        "multiplier",
        "weight",
    }


def _merge_known_defaults(defaults: dict[str, Any], item: dict[str, Any], known_keys: Iterable[str]) -> dict[str, Any]:
    merged = {k: v for k, v in defaults.items() if k in known_keys}
    merged.update(item)
    return merged
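`_merge_known_defaults` first filters the defaults down to recognized keys, then lets per-item values override them. A sketch with illustrative keys:

```python
# Merge defaults (restricted to known keys) with per-item overrides;
# the item's values win on conflict.
def merge_known_defaults(defaults, item, known_keys):
    merged = {k: v for k, v in defaults.items() if k in known_keys}
    merged.update(item)
    return merged

defaults = {"guidance_scale": 1.0, "batch_size": 1, "unrelated": True}
item = {"target": "cat", "guidance_scale": 3.0}
print(merge_known_defaults(defaults, item, {"guidance_scale", "batch_size", "target"}))
# {'guidance_scale': 3.0, 'batch_size': 1, 'target': 'cat'}
```

Filtering before merging means an unrecognized default key (like `unrelated` above) can never leak into the dataclass constructor call.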


def _normalize_resolution_values(value: Any) -> List[ResolutionValue]:
    if value is None:
        return [512]
    if isinstance(value, list) and value and isinstance(value[0], (list, tuple)):
        return [normalize_resolution(v) for v in value]
    return [normalize_resolution(value)]


def _expand_slider_target(target: dict[str, Any], neutral: str) -> List[PromptSettings]:
    target_class = str(target.get("target_class", ""))
    positive = str(target.get("positive", "") or "")
    negative = str(target.get("negative", "") or "")
    multiplier = target.get("multiplier", 1.0)
    resolutions = _normalize_resolution_values(target.get("resolutions", target.get("resolution", 512)))

    if not positive.strip() and not negative.strip():
        raise ValueError("slider target requires either positive or negative prompt")

    base = dict(
        target=target_class,
        neutral=neutral,
        guidance_scale=target.get("guidance_scale", 1.0),
        dynamic_resolution=target.get("dynamic_resolution", False),
        batch_size=target.get("batch_size", 1),
        dynamic_crops=target.get("dynamic_crops", False),
        weight=target.get("weight", 1.0),
    )

    # Build bidirectional (positive_prompt, unconditional_prompt, action, multiplier_sign) pairs.
    # With both positive and negative: 4 pairs; with only one: 2 pairs.
    pairs: list[tuple[str, str, str, float]] = []
    if positive.strip() and negative.strip():
        pairs = [
            (negative, positive, "erase", multiplier),
            (positive, negative, "enhance", multiplier),
            (positive, negative, "erase", -multiplier),
            (negative, positive, "enhance", -multiplier),
        ]
    elif negative.strip():
        pairs = [
            (negative, "", "erase", multiplier),
            (negative, "", "enhance", -multiplier),
        ]
    else:
        pairs = [
            (positive, "", "enhance", multiplier),
            (positive, "", "erase", -multiplier),
        ]

    prompt_settings: List[PromptSettings] = []
    for resolution in resolutions:
        for pos, uncond, action, mult in pairs:
            prompt_settings.append(
                PromptSettings(**base, positive=pos, unconditional=uncond, action=action, resolution=resolution, multiplier=mult)
            )

    return prompt_settings
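The pair construction above is the core of the slider expansion: with both prompts present it emits four directed pairs (both actions, both multiplier signs), otherwise two. Isolated as a pure function:

```python
# Slider pair expansion: 4 (prompt, uncond, action, multiplier) tuples when
# both positive and negative prompts exist, 2 when only one does.
def slider_pairs(positive, negative, multiplier=1.0):
    if positive.strip() and negative.strip():
        return [
            (negative, positive, "erase", multiplier),
            (positive, negative, "enhance", multiplier),
            (positive, negative, "erase", -multiplier),
            (negative, positive, "enhance", -multiplier),
        ]
    if negative.strip():
        return [(negative, "", "erase", multiplier), (negative, "", "enhance", -multiplier)]
    return [(positive, "", "enhance", multiplier), (positive, "", "erase", -multiplier)]

print(len(slider_pairs("smile", "frown")))  # 4
print(len(slider_pairs("smile", "")))       # 2
```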


def load_prompt_settings(path: Union[str, Path]) -> List[PromptSettings]:
    path = Path(path)
    with open(path, "r", encoding="utf-8") as f:
        data = toml.load(f)

    if not data:
        raise ValueError("prompt file is empty")

    default_prompt_values = {
        "guidance_scale": 1.0,
        "resolution": 512,
        "dynamic_resolution": False,
        "batch_size": 1,
        "dynamic_crops": False,
        "multiplier": 1.0,
        "weight": 1.0,
    }

    prompt_settings: List[PromptSettings] = []

    def append_prompt_item(item: dict[str, Any], defaults: dict[str, Any]) -> None:
        merged = _merge_known_defaults(defaults, item, _recognized_prompt_keys())
        prompt_settings.append(PromptSettings(**merged))

    def append_slider_item(item: dict[str, Any], defaults: dict[str, Any], neutral_values: Sequence[str]) -> None:
        merged = _merge_known_defaults(defaults, item, _recognized_slider_keys())
        if not neutral_values:
            neutral_values = [str(merged.get("neutral", "") or "")]
        for neutral in neutral_values:
            prompt_settings.extend(_expand_slider_target(merged, neutral))

    if "prompts" in data:
        defaults = {**default_prompt_values, **{k: v for k, v in data.items() if k in _recognized_prompt_keys()}}
        for item in data["prompts"]:
            if "target_class" in item:
                append_slider_item(item, defaults, [str(item.get("neutral", "") or "")])
            else:
                append_prompt_item(item, defaults)
    else:
        slider_config = data.get("slider", data)
        targets = slider_config.get("targets")
        if targets is None:
            if "target_class" in slider_config:
                targets = [slider_config]
            elif "target" in slider_config:
                targets = [slider_config]
            else:
                raise ValueError("prompt file does not contain prompts or slider targets")
        if len(targets) == 0:
            raise ValueError("prompt file contains an empty targets list")

        if "target" in targets[0]:
            defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_prompt_keys()}}
            for item in targets:
                append_prompt_item(item, defaults)
        else:
            defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_slider_keys()}}
            neutral_values: List[str] = []
            if "neutrals" in slider_config:
                neutral_values.extend(str(v) for v in slider_config["neutrals"])
            if "neutral_prompt_file" in slider_config:
                neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["neutral_prompt_file"]))
            if "prompt_file" in slider_config:
                neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["prompt_file"]))
            if not neutral_values:
                neutral_values = [str(slider_config.get("neutral", "") or "")]

            for item in targets:
                item_neutrals = neutral_values
                if "neutrals" in item:
                    item_neutrals = [str(v) for v in item["neutrals"]]
                elif "neutral_prompt_file" in item:
                    item_neutrals = _read_non_empty_lines(path.parent / item["neutral_prompt_file"])
                elif "prompt_file" in item:
                    item_neutrals = _read_non_empty_lines(path.parent / item["prompt_file"])
                elif "neutral" in item:
                    item_neutrals = [str(item["neutral"] or "")]

                append_slider_item(item, defaults, item_neutrals)

    if not prompt_settings:
        raise ValueError("no prompt settings found")

    return prompt_settings


def encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt: str) -> torch.Tensor:
    tokens = tokenize_strategy.tokenize(prompt)
    return text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)[0]


def encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt: str) -> PromptEmbedsXL:
    tokens = tokenize_strategy.tokenize(prompt)
    hidden1, hidden2, pool2 = text_encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens)
    return PromptEmbedsXL(torch.cat([hidden1, hidden2], dim=2), pool2)


def apply_noise_offset(latents: torch.Tensor, noise_offset: Optional[float]) -> torch.Tensor:
    if noise_offset is None:
        return latents
    noise = torch.randn((latents.shape[0], latents.shape[1], 1, 1), dtype=torch.float32, device="cpu")
    noise = noise.to(dtype=latents.dtype, device=latents.device)
    return latents + noise_offset * noise


def get_initial_latents(scheduler, batch_size: int, height: int, width: int, n_prompts: int = 1) -> torch.Tensor:
    noise = torch.randn(
        (batch_size, 4, height // 8, width // 8),
        device="cpu",
    ).repeat(n_prompts, 1, 1, 1)
    return noise * scheduler.init_noise_sigma


def concat_embeddings(unconditional: torch.Tensor, conditional: torch.Tensor, batch_size: int) -> torch.Tensor:
    return torch.cat([unconditional, conditional], dim=0).repeat_interleave(batch_size, dim=0)


def concat_embeddings_xl(unconditional: PromptEmbedsXL, conditional: PromptEmbedsXL, batch_size: int) -> PromptEmbedsXL:
    text_embeds = torch.cat([unconditional.text_embeds, conditional.text_embeds], dim=0).repeat_interleave(batch_size, dim=0)
    pooled_embeds = torch.cat([unconditional.pooled_embeds, conditional.pooled_embeds], dim=0).repeat_interleave(batch_size, dim=0)
    return PromptEmbedsXL(text_embeds=text_embeds, pooled_embeds=pooled_embeds)


def batch_add_time_ids(add_time_ids: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Duplicate add_time_ids for CFG (unconditional + conditional) and repeat for the batch."""
    return torch.cat([add_time_ids, add_time_ids], dim=0).repeat_interleave(batch_size, dim=0)
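The `cat` + `repeat_interleave` pattern above stacks the unconditional and conditional rows and then repeats each row `batch_size` times, keeping each row's copies adjacent. Emulated with plain lists (no torch needed) to make the ordering visible:

```python
# Emulate torch.cat(..., dim=0) followed by repeat_interleave(batch_size, dim=0)
# with Python lists of rows.
def concat_embeddings(unconditional, conditional, batch_size):
    stacked = unconditional + conditional  # cat along dim 0
    return [row for row in stacked for _ in range(batch_size)]  # repeat_interleave

uncond = [["u"]]
cond = [["c"]]
print(concat_embeddings(uncond, cond, batch_size=2))  # [['u'], ['u'], ['c'], ['c']]
```

Note the difference from `repeat`: `repeat_interleave` keeps copies of the same row adjacent (`u u c c`), not interleaved blocks (`u c u c`).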


def _run_with_checkpoint(function, *args):
    if torch.is_grad_enabled():
        return checkpoint(function, *args, use_reentrant=False)
    return function(*args)


def predict_noise(unet, scheduler, timestep, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float = 1.0):
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)

    def run_unet(model_input, encoder_hidden_states):
        return unet(model_input, timestep, encoder_hidden_states=encoder_hidden_states).sample

    noise_pred = _run_with_checkpoint(run_unet, latent_model_input, text_embeddings)
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)


def diffusion(
    unet,
    scheduler,
    latents: torch.Tensor,
    text_embeddings: torch.Tensor,
    total_timesteps: int,
    start_timesteps: int = 0,
    guidance_scale: float = 3.0,
):
    for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
        noise_pred = predict_noise(unet, scheduler, timestep, latents, text_embeddings, guidance_scale=guidance_scale)
        latents = scheduler.step(noise_pred, timestep, latents).prev_sample
    return latents


def get_add_time_ids(
    height: int,
    width: int,
    dynamic_crops: bool = False,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    if dynamic_crops:
        random_scale = torch.rand(1).item() * 2 + 1
        original_size = (int(height * random_scale), int(width * random_scale))
        crops_coords_top_left = (
            torch.randint(0, max(original_size[0] - height, 1), (1,)).item(),
            torch.randint(0, max(original_size[1] - width, 1), (1,)).item(),
        )
        target_size = (height, width)
    else:
        original_size = (height, width)
        crops_coords_top_left = (0, 0)
        target_size = (height, width)

    add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=dtype)
    if device is not None:
        add_time_ids = add_time_ids.to(device)
    return add_time_ids
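`get_add_time_ids` packs SDXL's size conditioning into one six-value row: original size, crop top-left, target size. The static (`dynamic_crops=False`) branch in plain Python:

```python
# SDXL size conditioning row: (orig_h, orig_w, crop_top, crop_left, tgt_h, tgt_w).
# Static branch only; the dynamic_crops branch randomizes orig size and crop.
def add_time_ids_static(height, width):
    original_size = (height, width)
    crops_coords_top_left = (0, 0)
    target_size = (height, width)
    return list(original_size + crops_coords_top_left + target_size)

print(add_time_ids_static(1024, 1024))  # [1024, 1024, 0, 0, 1024, 1024]
```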
|
||||
|
||||
def predict_noise_xl(
|
||||
unet,
|
||||
scheduler,
|
||||
timestep,
|
||||
latents: torch.Tensor,
|
||||
prompt_embeds: PromptEmbedsXL,
|
||||
add_time_ids: torch.Tensor,
|
||||
guidance_scale: float = 1.0,
|
||||
):
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
|
||||
|
||||
orig_size = add_time_ids[:, :2]
|
||||
crop_size = add_time_ids[:, 2:4]
|
||||
target_size = add_time_ids[:, 4:6]
|
||||
from library import sdxl_train_util
|
||||
|
||||
size_embeddings = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, latent_model_input.device)
|
||||
vector_embedding = torch.cat([prompt_embeds.pooled_embeds, size_embeddings.to(prompt_embeds.pooled_embeds.dtype)], dim=1)
|
||||
|
||||
def run_unet(model_input, text_embeds, vector_embeds):
|
||||
return unet(model_input, timestep, text_embeds, vector_embeds)
|
||||
|
||||
noise_pred = _run_with_checkpoint(run_unet, latent_model_input, prompt_embeds.text_embeds, vector_embedding)
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
|
||||
def diffusion_xl(
|
||||
unet,
|
||||
scheduler,
|
||||
latents: torch.Tensor,
|
||||
prompt_embeds: PromptEmbedsXL,
|
||||
add_time_ids: torch.Tensor,
|
||||
total_timesteps: int,
|
||||
start_timesteps: int = 0,
|
||||
guidance_scale: float = 3.0,
|
||||
):
|
||||
for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
|
||||
noise_pred = predict_noise_xl(
|
||||
unet,
|
||||
scheduler,
|
||||
timestep,
|
||||
latents,
|
||||
prompt_embeds=prompt_embeds,
|
||||
add_time_ids=add_time_ids,
|
||||
guidance_scale=guidance_scale,
|
||||
)
|
||||
latents = scheduler.step(noise_pred, timestep, latents).prev_sample
|
||||
return latents
|
||||
|
||||
|
||||
def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> Tuple[int, int]:
|
||||
max_resolution = bucket_resolution
|
||||
min_resolution = bucket_resolution // 2
|
||||
step = 64
|
||||
min_step = min_resolution // step
|
||||
max_step = max_resolution // step
|
||||
height = torch.randint(min_step, max_step + 1, (1,)).item() * step
|
||||
width = torch.randint(min_step, max_step + 1, (1,)).item() * step
|
||||
return height, width
|
||||
|
||||
|
||||
def get_random_resolution(prompt: PromptSettings) -> Tuple[int, int]:
|
||||
height, width = prompt.get_resolution()
|
||||
if prompt.dynamic_resolution and height == width:
|
||||
return get_random_resolution_in_bucket(height)
|
||||
return height, width
|
||||
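The `add_time_ids` tensor assembled above is just the six numbers (original_size + crops_coords_top_left + target_size) laid out in one row. A minimal pure-Python sketch of the non-dynamic branch, with plain lists standing in for tensors and the function name hypothetical:

```python
def build_add_time_ids(height, width):
    # mirrors get_add_time_ids with dynamic_crops=False:
    # original size, no crop offset, target size equal to original
    original_size = (height, width)
    crops_coords_top_left = (0, 0)
    target_size = (height, width)
    return [list(original_size + crops_coords_top_left + target_size)]

print(build_add_time_ids(1024, 768))  # [[1024, 768, 0, 0, 1024, 768]]
```

The real function only adds dtype/device handling and the randomized `dynamic_crops` variant on top of this concatenation.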
@@ -1,246 +1,287 @@
import os
import re
from typing import Dict, List, Optional, Union
import torch
from tqdm import tqdm
from library.device_utils import synchronize_device
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def filter_lora_state_dict(
    weights_sd: Dict[str, torch.Tensor],
    include_pattern: Optional[str] = None,
    exclude_pattern: Optional[str] = None,
) -> Dict[str, torch.Tensor]:
    # apply include/exclude patterns
    original_key_count = len(weights_sd.keys())
    if include_pattern is not None:
        regex_include = re.compile(include_pattern)
        weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)}
        logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}")

    if exclude_pattern is not None:
        original_key_count_ex = len(weights_sd.keys())
        regex_exclude = re.compile(exclude_pattern)
        weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)}
        logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}")

    if len(weights_sd) != original_key_count:
        remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()]))
        remaining_keys.sort()
        logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
        if len(weights_sd) == 0:
            logger.warning("No keys left after filtering.")

    return weights_sd


def load_safetensors_with_lora_and_fp8(
    model_files: Union[str, List[str]],
    lora_weights_list: Optional[Dict[str, torch.Tensor]],
    lora_multipliers: Optional[List[float]],
    fp8_optimization: bool,
    calc_device: torch.device,
    move_to_device: bool = False,
    dit_weight_dtype: Optional[torch.dtype] = None,
    target_keys: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
) -> dict[str, torch.Tensor]:
    """
    Merge LoRA weights into the state dict of a model with fp8 optimization if needed.

    Args:
        model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
        lora_weights_list (Optional[Dict[str, torch.Tensor]]): Dictionary of LoRA weight tensors to load.
        lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
        fp8_optimization (bool): Whether to apply FP8 optimization.
        calc_device (torch.device): Device to calculate on.
        move_to_device (bool): Whether to move tensors to the calculation device after loading.
        target_keys (Optional[List[str]]): Keys to target for optimization.
        exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
    """

    # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
    if isinstance(model_files, str):
        model_files = [model_files]

    extended_model_files = []
    for model_file in model_files:
        basename = os.path.basename(model_file)
        match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
        if match:
            prefix = basename[: match.start(2)]
            count = int(match.group(3))
            state_dict = {}
            for i in range(count):
                filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
                filepath = os.path.join(os.path.dirname(model_file), filename)
                if os.path.exists(filepath):
                    extended_model_files.append(filepath)
                else:
                    raise FileNotFoundError(f"File {filepath} not found")
        else:
            extended_model_files.append(model_file)
    model_files = extended_model_files
    logger.info(f"Loading model files: {model_files}")

    # load LoRA weights
    weight_hook = None
    if lora_weights_list is None or len(lora_weights_list) == 0:
        lora_weights_list = []
        lora_multipliers = []
        list_of_lora_weight_keys = []
    else:
        list_of_lora_weight_keys = []
        for lora_sd in lora_weights_list:
            lora_weight_keys = set(lora_sd.keys())
            list_of_lora_weight_keys.append(lora_weight_keys)

        if lora_multipliers is None:
            lora_multipliers = [1.0] * len(lora_weights_list)
        while len(lora_multipliers) < len(lora_weights_list):
            lora_multipliers.append(1.0)
        if len(lora_multipliers) > len(lora_weights_list):
            lora_multipliers = lora_multipliers[: len(lora_weights_list)]

        # Merge LoRA weights into the state dict
        logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")

        # make hook for LoRA merging
        def weight_hook_func(model_weight_key, model_weight, keep_on_calc_device=False):
            nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device

            if not model_weight_key.endswith(".weight"):
                return model_weight

            original_device = model_weight.device
            if original_device != calc_device:
                model_weight = model_weight.to(calc_device)  # to make calculation faster

            for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
                # check if this weight has LoRA weights
                lora_name = model_weight_key.rsplit(".", 1)[0]  # remove trailing ".weight"
                lora_name = "lora_unet_" + lora_name.replace(".", "_")
                down_key = lora_name + ".lora_down.weight"
                up_key = lora_name + ".lora_up.weight"
                alpha_key = lora_name + ".alpha"
                if down_key not in lora_weight_keys or up_key not in lora_weight_keys:
                    continue

                # get LoRA weights
                down_weight = lora_sd[down_key]
                up_weight = lora_sd[up_key]

                dim = down_weight.size()[0]
                alpha = lora_sd.get(alpha_key, dim)
                scale = alpha / dim

                down_weight = down_weight.to(calc_device)
                up_weight = up_weight.to(calc_device)

                # W <- W + U * D
                if len(model_weight.size()) == 2:
                    # linear
                    if len(up_weight.size()) == 4:  # use linear projection mismatch
                        up_weight = up_weight.squeeze(3).squeeze(2)
                        down_weight = down_weight.squeeze(3).squeeze(2)
                    model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale
                elif down_weight.size()[2:4] == (1, 1):
                    # conv2d 1x1
                    model_weight = (
                        model_weight
                        + multiplier
                        * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                        * scale
                    )
                else:
                    # conv2d 3x3
                    conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                    # logger.info(conved.size(), weight.size(), module.stride, module.padding)
                    model_weight = model_weight + multiplier * conved * scale

                # remove LoRA keys from set
                lora_weight_keys.remove(down_key)
                lora_weight_keys.remove(up_key)
                if alpha_key in lora_weight_keys:
                    lora_weight_keys.remove(alpha_key)

            if not keep_on_calc_device and original_device != calc_device:
                model_weight = model_weight.to(original_device)  # move back to original device
            return model_weight

        weight_hook = weight_hook_func

    state_dict = load_safetensors_with_fp8_optimization_and_hook(
        model_files,
        fp8_optimization,
        calc_device,
        move_to_device,
        dit_weight_dtype,
        target_keys,
        exclude_keys,
        weight_hook=weight_hook,
    )

    for lora_weight_keys in list_of_lora_weight_keys:
        # check if all LoRA keys are used
        if len(lora_weight_keys) > 0:
            # if there are still LoRA keys left, it means they are not used in the model
            # this is a warning, not an error
            logger.warning(f"Warning: not all LoRA keys are used: {', '.join(lora_weight_keys)}")

    return state_dict


def load_safetensors_with_fp8_optimization_and_hook(
    model_files: list[str],
    fp8_optimization: bool,
    calc_device: torch.device,
    move_to_device: bool = False,
    dit_weight_dtype: Optional[torch.dtype] = None,
    target_keys: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
    weight_hook: callable = None,
) -> dict[str, torch.Tensor]:
    """
    Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.
    """
    if fp8_optimization:
        logger.info(
            f"Loading state dict with FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
        )
        # dit_weight_dtype is not used because we use fp8 optimization
        state_dict = load_safetensors_with_fp8_optimization(
            model_files, calc_device, target_keys, exclude_keys, move_to_device=move_to_device, weight_hook=weight_hook
        )
    else:
        logger.info(
            f"Loading state dict without FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
        )
        state_dict = {}
        for model_file in model_files:
            with MemoryEfficientSafeOpen(model_file) as f:
                for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
                    if weight_hook is None and move_to_device:
                        value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)
                    else:
                        value = f.get_tensor(key)  # we cannot directly load to device because get_tensor does non-blocking transfer
                        if weight_hook is not None:
                            value = weight_hook(key, value, keep_on_calc_device=move_to_device)
                        if move_to_device:
                            value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True)
                        elif dit_weight_dtype is not None:
                            value = value.to(dit_weight_dtype)

                    state_dict[key] = value
        if move_to_device:
            synchronize_device(calc_device)

    return state_dict
import os
import re
from typing import Dict, List, Optional, Union
import torch
from tqdm import tqdm
from library.device_utils import synchronize_device
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
from networks.loha import merge_weights_to_tensor as loha_merge
from networks.lokr import merge_weights_to_tensor as lokr_merge

from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def filter_lora_state_dict(
    weights_sd: Dict[str, torch.Tensor],
    include_pattern: Optional[str] = None,
    exclude_pattern: Optional[str] = None,
) -> Dict[str, torch.Tensor]:
    # apply include/exclude patterns
    original_key_count = len(weights_sd.keys())
    if include_pattern is not None:
        regex_include = re.compile(include_pattern)
        weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)}
        logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}")

    if exclude_pattern is not None:
        original_key_count_ex = len(weights_sd.keys())
        regex_exclude = re.compile(exclude_pattern)
        weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)}
        logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}")

    if len(weights_sd) != original_key_count:
        remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()]))
        remaining_keys.sort()
        logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
        if len(weights_sd) == 0:
            logger.warning("No keys left after filtering.")

    return weights_sd


def load_safetensors_with_lora_and_fp8(
    model_files: Union[str, List[str]],
    lora_weights_list: Optional[List[Dict[str, torch.Tensor]]],
    lora_multipliers: Optional[List[float]],
    fp8_optimization: bool,
    calc_device: torch.device,
    move_to_device: bool = False,
    dit_weight_dtype: Optional[torch.dtype] = None,
    target_keys: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
    disable_numpy_memmap: bool = False,
    weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
    """
    Merge LoRA weights into the state dict of a model with fp8 optimization if needed.

    Args:
        model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
        lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load.
        lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
        fp8_optimization (bool): Whether to apply FP8 optimization.
        calc_device (torch.device): Device to calculate on.
        move_to_device (bool): Whether to move tensors to the calculation device after loading.
        dit_weight_dtype (Optional[torch.dtype]): Dtype to load weights in when not using FP8 optimization.
        target_keys (Optional[List[str]]): Keys to target for optimization.
        exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
        disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors.
        weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading.
    """

    # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
    if isinstance(model_files, str):
        model_files = [model_files]

    extended_model_files = []
    for model_file in model_files:
        split_filenames = get_split_weight_filenames(model_file)
        if split_filenames is not None:
            extended_model_files.extend(split_filenames)
        else:
            extended_model_files.append(model_file)
    model_files = extended_model_files
    logger.info(f"Loading model files: {model_files}")

    # load LoRA weights
    weight_hook = None
    if lora_weights_list is None or len(lora_weights_list) == 0:
        lora_weights_list = []
        lora_multipliers = []
        list_of_lora_weight_keys = []
    else:
        list_of_lora_weight_keys = []
        for lora_sd in lora_weights_list:
            lora_weight_keys = set(lora_sd.keys())
            list_of_lora_weight_keys.append(lora_weight_keys)

        if lora_multipliers is None:
            lora_multipliers = [1.0] * len(lora_weights_list)
        while len(lora_multipliers) < len(lora_weights_list):
            lora_multipliers.append(1.0)
        if len(lora_multipliers) > len(lora_weights_list):
            lora_multipliers = lora_multipliers[: len(lora_weights_list)]

        # Merge LoRA weights into the state dict
        logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")

        # make hook for LoRA merging
        def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False):
            nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device

            if not model_weight_key.endswith(".weight"):
                return model_weight

            original_device = model_weight.device
            if original_device != calc_device:
                model_weight = model_weight.to(calc_device)  # to make calculation faster

            for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
                # check if this weight has LoRA weights
                lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0]  # remove trailing ".weight"
                found = False
                for prefix in ["lora_unet_", ""]:
                    lora_name = prefix + lora_name_without_prefix.replace(".", "_")
                    down_key = lora_name + ".lora_down.weight"
                    up_key = lora_name + ".lora_up.weight"
                    alpha_key = lora_name + ".alpha"
                    if down_key in lora_weight_keys and up_key in lora_weight_keys:
                        found = True
                        break

                if found:
                    # Standard LoRA merge
                    # get LoRA weights
                    down_weight = lora_sd[down_key]
                    up_weight = lora_sd[up_key]

                    dim = down_weight.size()[0]
                    alpha = lora_sd.get(alpha_key, dim)
                    scale = alpha / dim

                    down_weight = down_weight.to(calc_device)
                    up_weight = up_weight.to(calc_device)

                    original_dtype = model_weight.dtype
                    if original_dtype.itemsize == 1:  # fp8
                        # temporarily convert to float16 for calculation
                        model_weight = model_weight.to(torch.float16)
                        down_weight = down_weight.to(torch.float16)
                        up_weight = up_weight.to(torch.float16)

                    # W <- W + U * D
                    if len(model_weight.size()) == 2:
                        # linear
                        if len(up_weight.size()) == 4:  # use linear projection mismatch
                            up_weight = up_weight.squeeze(3).squeeze(2)
                            down_weight = down_weight.squeeze(3).squeeze(2)
                        model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale
                    elif down_weight.size()[2:4] == (1, 1):
                        # conv2d 1x1
                        model_weight = (
                            model_weight
                            + multiplier
                            * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                            * scale
                        )
                    else:
                        # conv2d 3x3
                        conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                        # logger.info(conved.size(), weight.size(), module.stride, module.padding)
                        model_weight = model_weight + multiplier * conved * scale

                    if original_dtype.itemsize == 1:  # fp8
                        model_weight = model_weight.to(original_dtype)  # convert back to original dtype

                    # remove LoRA keys from set
                    lora_weight_keys.remove(down_key)
                    lora_weight_keys.remove(up_key)
                    if alpha_key in lora_weight_keys:
                        lora_weight_keys.remove(alpha_key)
                    continue

                # Check for LoHa/LoKr weights with same prefix search
                for prefix in ["lora_unet_", ""]:
                    lora_name = prefix + lora_name_without_prefix.replace(".", "_")
                    hada_key = lora_name + ".hada_w1_a"
                    lokr_key = lora_name + ".lokr_w1"

                    if hada_key in lora_weight_keys:
                        # LoHa merge
                        model_weight = loha_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device)
                        break
                    elif lokr_key in lora_weight_keys:
                        # LoKr merge
                        model_weight = lokr_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device)
                        break

            if not keep_on_calc_device and original_device != calc_device:
                model_weight = model_weight.to(original_device)  # move back to original device
            return model_weight

        weight_hook = weight_hook_func

    state_dict = load_safetensors_with_fp8_optimization_and_hook(
        model_files,
        fp8_optimization,
        calc_device,
        move_to_device,
        dit_weight_dtype,
        target_keys,
        exclude_keys,
        weight_hook=weight_hook,
        disable_numpy_memmap=disable_numpy_memmap,
        weight_transform_hooks=weight_transform_hooks,
    )

    for lora_weight_keys in list_of_lora_weight_keys:
        # check if all LoRA keys are used
        if len(lora_weight_keys) > 0:
            # if there are still LoRA keys left, it means they are not used in the model
            # this is a warning, not an error
            logger.warning(f"Warning: not all LoRA keys are used: {', '.join(lora_weight_keys)}")

    return state_dict


def load_safetensors_with_fp8_optimization_and_hook(
    model_files: list[str],
    fp8_optimization: bool,
    calc_device: torch.device,
    move_to_device: bool = False,
    dit_weight_dtype: Optional[torch.dtype] = None,
    target_keys: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
    weight_hook: callable = None,
    disable_numpy_memmap: bool = False,
    weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
    """
    Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.
    """
    if fp8_optimization:
        logger.info(
            f"Loading state dict with FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
        )
        # dit_weight_dtype is not used because we use fp8 optimization
        state_dict = load_safetensors_with_fp8_optimization(
            model_files,
            calc_device,
            target_keys,
            exclude_keys,
            move_to_device=move_to_device,
            weight_hook=weight_hook,
            disable_numpy_memmap=disable_numpy_memmap,
            weight_transform_hooks=weight_transform_hooks,
        )
    else:
        logger.info(
            f"Loading state dict without FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
        )
        state_dict = {}
        for model_file in model_files:
            with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
                f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
                for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
                    if weight_hook is None and move_to_device:
                        value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)
                    else:
                        value = f.get_tensor(key)  # we cannot directly load to device because get_tensor does non-blocking transfer
                        if weight_hook is not None:
                            value = weight_hook(key, value, keep_on_calc_device=move_to_device)
                        if move_to_device:
                            value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True)
                        elif dit_weight_dtype is not None:
                            value = value.to(dit_weight_dtype)

                    state_dict[key] = value
        if move_to_device:
            synchronize_device(calc_device)

    return state_dict
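The standard-LoRA branch above merges as W <- W + multiplier * (up @ down) * (alpha / dim). A toy pure-Python sketch of that update for a 2x2 linear weight with rank-1 factors (no torch; plain lists stand in for tensors, function name hypothetical):

```python
def merge_lora(weight, up, down, alpha, multiplier=1.0):
    # weight: out x in, up: out x r, down: r x in
    r = len(down)          # LoRA rank = number of rows in the down matrix
    scale = alpha / r      # same alpha/dim scaling as the loader above
    out = [row[:] for row in weight]
    for i in range(len(up)):
        for j in range(len(down[0])):
            delta = sum(up[i][k] * down[k][j] for k in range(r))  # (up @ down)[i][j]
            out[i][j] += multiplier * delta * scale
    return out

w = [[1.0, 0.0], [0.0, 1.0]]   # identity base weight
up = [[1.0], [2.0]]            # out x r, rank 1
down = [[3.0, 4.0]]            # r x in
print(merge_lora(w, up, down, alpha=1.0))  # [[4.0, 4.0], [6.0, 9.0]]
```

The conv2d 1x1 and 3x3 branches are the same rank-decomposition update expressed through squeezes and a grouped convolution.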
@@ -34,18 +34,18 @@ from library import custom_offloading_utils
 try:
     from flash_attn import flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-except:
+except ImportError:
     # flash_attn may not be available but it is not required
     pass
 
 try:
     from sageattention import sageattn
-except:
+except ImportError:
     pass
 
 try:
     from apex.normalization import FusedRMSNorm as RMSNorm
-except:
+except ImportError:
     import warnings
 
     warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
@@ -98,7 +98,7 @@ except:
         x_dtype = x.dtype
         # To handle float8 we need to convert the tensor to float
         x = x.float()
-        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
+        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
         return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)
 
 
@@ -370,7 +370,7 @@ class JointAttention(nn.Module):
         if self.use_sage_attn:
             # Handle GQA (Grouped Query Attention) if needed
             n_rep = self.n_local_heads // self.n_local_kv_heads
-            if n_rep >= 1:
+            if n_rep > 1:
                 xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
                 xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
 
@@ -379,7 +379,7 @@ class JointAttention(nn.Module):
             output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
         else:
             n_rep = self.n_local_heads // self.n_local_kv_heads
-            if n_rep >= 1:
+            if n_rep > 1:
                 xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
                 xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
 
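The `n_rep = n_local_heads // n_local_kv_heads` expansion in the hunk above duplicates each key/value head so that every group of query heads sees its shared KV head, which is what grouped-query attention requires before a standard attention kernel. A toy sketch of that expansion over plain lists (function name hypothetical; the real code does it with unsqueeze/repeat/flatten on the head dimension):

```python
def repeat_kv_heads(kv_heads, n_rep):
    # each kv head is shared by n_rep query heads, so repeat it n_rep times
    # in place, mirroring xk.unsqueeze(3).repeat(...).flatten(2, 3) above
    out = []
    for head in kv_heads:
        out.extend([head] * n_rep)
    return out

print(repeat_kv_heads(["k0", "k1"], 2))  # ['k0', 'k0', 'k1', 'k1']
```

With `n_rep == 1` (standard multi-head attention) the repeat is a no-op, which is why the fix above skips it with `n_rep > 1`.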
@@ -456,51 +456,47 @@ class JointAttention(nn.Module):
             bsz = q.shape[0]
             seqlen = q.shape[1]
 
-            # Transpose tensors to match SageAttention's expected format (HND layout)
-            q_transposed = q.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]
-            k_transposed = k.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]
-            v_transposed = v.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]
-
-            # Handle masking for SageAttention
-            # We need to filter out masked positions - this approach handles variable sequence lengths
-            outputs = []
-            for b in range(bsz):
-                # Find valid token positions from the mask
-                valid_indices = torch.nonzero(x_mask[b], as_tuple=False).squeeze(-1)
-                if valid_indices.numel() == 0:
-                    # If all tokens are masked, create a zero output
-                    batch_output = torch.zeros(
-                        seqlen, self.n_local_heads, self.head_dim,
-                        device=q.device, dtype=q.dtype
-                    )
-                else:
-                    # Extract only valid tokens for this batch
-                    batch_q = q_transposed[b, :, valid_indices, :]
-                    batch_k = k_transposed[b, :, valid_indices, :]
-                    batch_v = v_transposed[b, :, valid_indices, :]
-
-                    # Run SageAttention on valid tokens only
+            # Transpose to SageAttention's expected HND layout: [batch, heads, seq_len, head_dim]
+            q_transposed = q.permute(0, 2, 1, 3)
+            k_transposed = k.permute(0, 2, 1, 3)
+            v_transposed = v.permute(0, 2, 1, 3)
+
+            # Fast path: if all tokens are valid, run batched SageAttention directly
+            if x_mask.all():
+                output = sageattn(
+                    q_transposed, k_transposed, v_transposed,
+                    tensor_layout="HND", is_causal=False, sm_scale=softmax_scale,
+                )
+                # output: [batch, heads, seq_len, head_dim] -> [batch, seq_len, heads, head_dim]
+                output = output.permute(0, 2, 1, 3)
+            else:
+                # Slow path: per-batch loop to handle variable-length masking
+                # SageAttention does not support attention masks natively
+                outputs = []
+                for b in range(bsz):
+                    valid_indices = x_mask[b].nonzero(as_tuple=True)[0]
+                    if valid_indices.numel() == 0:
+                        outputs.append(torch.zeros(
+                            seqlen, self.n_local_heads, self.head_dim,
+                            device=q.device, dtype=q.dtype,
+                        ))
+                        continue
+
                     batch_output_valid = sageattn(
-                        batch_q.unsqueeze(0),  # Add batch dimension back
-                        batch_k.unsqueeze(0),
-                        batch_v.unsqueeze(0),
-                        tensor_layout="HND",
-                        is_causal=False,
-                        sm_scale=softmax_scale
+                        q_transposed[b:b+1, :, valid_indices, :],
+                        k_transposed[b:b+1, :, valid_indices, :],
+                        v_transposed[b:b+1, :, valid_indices, :],
+                        tensor_layout="HND", is_causal=False, sm_scale=softmax_scale,
                     )
 
-                    # Create output tensor with zeros for masked positions
                     batch_output = torch.zeros(
-                        seqlen, self.n_local_heads, self.head_dim,
-                        device=q.device, dtype=q.dtype
+                        seqlen, self.n_local_heads, self.head_dim,
+                        device=q.device, dtype=q.dtype,
                     )
                     # Place valid outputs back in the right positions
                     batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2)
 
-                outputs.append(batch_output)
-
-            # Stack batch outputs and reshape to expected format
-            output = torch.stack(outputs, dim=0)  # [batch, seq_len, heads, head_dim]
+                    outputs.append(batch_output)
+
+                output = torch.stack(outputs, dim=0)
         except NameError as e:
             raise RuntimeError(
                 f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. / Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}"
@@ -1113,10 +1109,9 @@ class NextDiT(nn.Module):
 
         x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
 
-        x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device)
-        for i in range(bsz):
-            x[i, :image_seq_len] = x[i]
-            x_mask[i, :image_seq_len] = True
+        # x.shape[1] == image_seq_len after patchify, so this was assigning to itself.
+        # The mask can be set without a loop since all samples have the same image_seq_len.
+        x_mask = torch.ones(bsz, image_seq_len, dtype=torch.bool, device=device)
 
         x = self.x_embedder(x)
 
@@ -1389,4 +1384,4 @@ def NextDiT_7B_GQA_patch2_Adaln_Refiner(**kwargs):
         axes_dims=[40, 40, 40],
         axes_lens=[300, 512, 512],
         **kwargs,
-    )
+    )
@@ -334,32 +334,35 @@ def sample_image_inference(
 
     # No need to add system prompt here, as it has been handled in the tokenize_strategy
 
-    # Get sample prompts from cache
+    # Get sample prompts from cache, fallback to live encoding
     gemma2_conds = None
     neg_gemma2_conds = None
 
     if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
         gemma2_conds = sample_prompts_gemma2_outputs[prompt]
         logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
 
-    if (
-        sample_prompts_gemma2_outputs
-        and negative_prompt in sample_prompts_gemma2_outputs
-    ):
+    if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs:
         neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
-        logger.info(
-            f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}"
-        )
+        logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}")
 
-    # Load sample prompts from Gemma 2
-    if gemma2_model is not None:
+    # Only encode if not found in cache
+    if gemma2_conds is None and gemma2_model is not None:
         tokens_and_masks = tokenize_strategy.tokenize(prompt)
         gemma2_conds = encoding_strategy.encode_tokens(
             tokenize_strategy, gemma2_model, tokens_and_masks
         )
 
+    if neg_gemma2_conds is None and gemma2_model is not None:
         tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True)
         neg_gemma2_conds = encoding_strategy.encode_tokens(
             tokenize_strategy, gemma2_model, tokens_and_masks
         )
 
+    if gemma2_conds is None or neg_gemma2_conds is None:
+        logger.error(f"Cannot generate sample: no cached outputs and no text encoder available for prompt: {prompt}")
+        continue
+
     # Unpack Gemma2 outputs
     gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds
     neg_gemma2_hidden_states, _, neg_gemma2_attn_mask = neg_gemma2_conds
@@ -475,6 +478,7 @@ def sample_image_inference(


def time_shift(mu: float, sigma: float, t: torch.Tensor):
"""Apply time shifting to timesteps."""
t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
return t
@@ -483,7 +487,7 @@ def get_lin_function(
x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15
) -> Callable[[float], float]:
"""
Get linear function
Get linear function for resolution-dependent shifting.

Args:
image_seq_len,
@@ -528,6 +532,7 @@ def get_schedule(
mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(
image_seq_len
)
timesteps = torch.clamp(timesteps, min=1e-7).to(timesteps.device)
timesteps = time_shift(mu, 1.0, timesteps)

return timesteps.tolist()
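The resolution-dependent shift in this hunk can be sketched numerically: `mu` is interpolated linearly between (256, 0.5) and (4096, 1.15) over the packed sequence length, then pushed through the `time_shift` formula. A minimal pure-Python sketch (constants from the hunk; `image_seq_len=1024` is just an illustrative value):

```python
import math

def get_lin_function(x1=256.0, x2=4096.0, y1=0.5, y2=1.15):
    # Linear interpolation of the shift parameter mu over sequence length
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b

def time_shift(mu, sigma, t):
    # Same formula as in the hunk: exp(mu) / (exp(mu) + (1/t - 1)**sigma)
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

mu = get_lin_function()(1024)       # packed seq len 1024 -> mu == 0.63
shifted = time_shift(mu, 1.0, 0.5)  # t = 0.5 is shifted toward 1 (~0.652)
```

Larger sequence lengths give larger `mu`, so more of the schedule is spent at high-noise timesteps, which is the point of the resolution-dependent shift.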
@@ -689,15 +694,15 @@ def denoise(

img_dtype = img.dtype

if img.dtype != img_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
img = img.to(img_dtype)

# compute the previous noisy sample x_t -> x_t-1
noise_pred = -noise_pred
img = scheduler.step(noise_pred, t, img, return_dict=False)[0]

# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
if img.dtype != img_dtype:
if torch.backends.mps.is_available():
img = img.to(img_dtype)

model.prepare_block_swap_before_forward()
return img

@@ -823,6 +828,7 @@ def get_noisy_model_input_and_timesteps(
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "nextdit_shift":
sigmas = torch.rand((bsz,), device=device)
sigmas = torch.clamp(sigmas, min=1e-7).to(device)
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
sigmas = time_shift(mu, 1.0, sigmas)

@@ -831,6 +837,7 @@ def get_noisy_model_input_and_timesteps(
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale  # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
sigmas = torch.clamp(sigmas, min=1e-7).to(device)
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))  # we are pre-packed so must adjust for packed size
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = sigmas * num_timesteps

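The sigmoid branch above draws a normal sample, scales it, squashes it into (0, 1), then clamps it, and only afterwards applies the resolution shift. The squashing step can be sketched without torch (`sigmoid_scale=1.0` is an illustrative default; the clamp constant matches the hunk):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sample_sigma(gauss_sample, sigmoid_scale=1.0):
    # Larger sigmoid_scale pushes samples toward 0 and 1 (more uniform coverage)
    s = sigmoid(gauss_sample * sigmoid_scale)
    return max(s, 1e-7)  # clamp as in the hunk

sigma_mid = sample_sigma(0.0)  # center of the Gaussian maps to 0.5
sigma_hi = sample_sigma(3.0)   # tail sample maps near 1 (~0.9526)
```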
1735
library/qwen_image_autoencoder_kl.py
Normal file
File diff suppressed because it is too large
@@ -1,3 +1,4 @@
from dataclasses import dataclass
import os
import re
import numpy as np
@@ -44,6 +45,7 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
validated[key] = value
return validated

# print(f"Using memory efficient save file: (unknown)")

header = {}
offset = 0
@@ -88,15 +90,17 @@ class MemoryEfficientSafeOpen:
by using memory mapping for large tensors and avoiding unnecessary copies.
"""

def __init__(self, filename):
def __init__(self, filename, disable_numpy_memmap=False):
"""Initialize the SafeTensor reader.

Args:
filename (str): Path to the safetensors file to read.
disable_numpy_memmap (bool): If True, disable numpy memory mapping for large tensors, using standard file read instead.
"""
self.filename = filename
self.file = open(filename, "rb")
self.header, self.header_size = self._read_header()
self.disable_numpy_memmap = disable_numpy_memmap

def __enter__(self):
"""Enter context manager."""
@@ -178,7 +182,8 @@ class MemoryEfficientSafeOpen:
# Use memmap for large tensors to avoid intermediate copies.
# If device is cpu, tensor is not copied to gpu, so using memmap locks the file, which is not desired.
# So we only use memmap if device is not cpu.
if num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
# If disable_numpy_memmap is True, skip numpy memory mapping to load with standard file read.
if not self.disable_numpy_memmap and num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
# Create memory map for zero-copy reading
mm = np.memmap(self.filename, mode="c", dtype=np.uint8, offset=tensor_offset, shape=(num_bytes,))
byte_tensor = torch.from_numpy(mm)  # zero copy
@@ -285,7 +290,11 @@ class MemoryEfficientSafeOpen:


def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
path: str,
device: Union[str, torch.device],
disable_mmap: bool = False,
dtype: Optional[torch.dtype] = None,
disable_numpy_memmap: bool = False,
) -> dict[str, torch.Tensor]:
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
@@ -293,7 +302,7 @@ def load_safetensors(
# logger.info(f"Loading without mmap (experimental)")
state_dict = {}
device = torch.device(device) if device is not None else None
with MemoryEfficientSafeOpen(path) as f:
with MemoryEfficientSafeOpen(path, disable_numpy_memmap=disable_numpy_memmap) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key, device=device, dtype=dtype)
synchronize_device(device)
@@ -309,6 +318,29 @@ def load_safetensors(
return state_dict


def get_split_weight_filenames(file_path: str) -> Optional[list[str]]:
"""
Get the list of split weight filenames (full paths) if the file name ends with 00001-of-00004 etc.
Returns None if the file is not split.
"""
basename = os.path.basename(file_path)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
if match:
prefix = basename[: match.start(2)]
count = int(match.group(3))
filenames = []
for i in range(count):
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(file_path), filename)
if os.path.exists(filepath):
filenames.append(filepath)
else:
raise FileNotFoundError(f"File {filepath} not found")
return filenames
else:
return None

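The naming scheme that `get_split_weight_filenames` parses can be illustrated without touching the filesystem (the real helper additionally verifies that each part exists with `os.path.exists`; this sketch only reproduces the name generation):

```python
import re

def expected_split_names(basename: str):
    # Same regex as get_split_weight_filenames: non-greedy prefix, index, total count
    match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
    if match is None:
        return None  # not a split checkpoint
    prefix = basename[: match.start(2)]
    count = int(match.group(3))
    # Indices are 1-based and zero-padded to 5 digits
    return [f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors" for i in range(count)]

names = expected_split_names("model-00001-of-00003.safetensors")
single = expected_split_names("model.safetensors")  # None: not a split file
```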
def load_split_weights(
file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
) -> Dict[str, torch.Tensor]:
@@ -319,19 +351,11 @@ def load_split_weights(
device = torch.device(device)

# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
basename = os.path.basename(file_path)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
if match:
prefix = basename[: match.start(2)]
count = int(match.group(3))
split_filenames = get_split_weight_filenames(file_path)
if split_filenames is not None:
state_dict = {}
for i in range(count):
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(file_path), filename)
if os.path.exists(filepath):
state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap, dtype=dtype))
else:
raise FileNotFoundError(f"File {filepath} not found")
for filename in split_filenames:
state_dict.update(load_safetensors(filename, device=device, disable_mmap=disable_mmap, dtype=dtype))
else:
state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
return state_dict
@@ -349,3 +373,106 @@ def find_key(safetensors_file: str, starts_with: Optional[str] = None, ends_with
if (starts_with is None or key.startswith(starts_with)) and (ends_with is None or key.endswith(ends_with)):
return key
return None


@dataclass
class WeightTransformHooks:
split_hook: Optional[callable] = None
concat_hook: Optional[callable] = None
rename_hook: Optional[callable] = None


class TensorWeightAdapter:
"""
A wrapper for weight conversion hooks (split and concat) to be used with MemoryEfficientSafeOpen.
This wrapper adapts the original MemoryEfficientSafeOpen to apply the provided split and concat hooks
when loading tensors.

split_hook: A callable that takes (original_key: str, original_tensor: torch.Tensor) and returns (new_keys: list[str], new_tensors: list[torch.Tensor]).
concat_hook: A callable that takes (original_key: str, tensors: dict[str, torch.Tensor]) and returns (new_key: str, concatenated_tensor: torch.Tensor).
rename_hook: A callable that takes (original_key: str) and returns (new_key: str).

If tensors is None, the hook should return only the new keys (for split) or new key (for concat), without tensors.

No need to implement __enter__ and __exit__ methods, as they are handled by the original MemoryEfficientSafeOpen.
Do not use this wrapper as a context manager directly, like `with WeightConvertHookWrapper(...) as f:`.

**concat_hook is not tested yet.**
"""

def __init__(self, weight_convert_hook: WeightTransformHooks, original_f: MemoryEfficientSafeOpen):
self.original_f = original_f
self.new_key_to_original_key_map: dict[str, Union[str, list[str]]] = (
{}
)  # for split: new_key -> original_key; for concat: new_key -> list of original_keys; for direct mapping: new_key -> original_key
self.concat_key_set = set()  # set of concatenated keys
self.split_key_set = set()  # set of split keys
self.new_keys = []
self.tensor_cache = {}  # cache for split tensors
self.split_hook = weight_convert_hook.split_hook
self.concat_hook = weight_convert_hook.concat_hook
self.rename_hook = weight_convert_hook.rename_hook

for key in self.original_f.keys():
if self.split_hook is not None:
converted_keys, _ = self.split_hook(key, None)  # get new keys only
if converted_keys is not None:
for converted_key in converted_keys:
self.new_key_to_original_key_map[converted_key] = key
self.split_key_set.add(converted_key)
self.new_keys.extend(converted_keys)
continue  # skip concat_hook if split_hook is applied

if self.concat_hook is not None:
converted_key, _ = self.concat_hook(key, None)  # get new key only
if converted_key is not None:
if converted_key not in self.concat_key_set:  # first time seeing this concatenated key
self.concat_key_set.add(converted_key)
self.new_key_to_original_key_map[converted_key] = []
self.new_keys.append(converted_key)

# multiple original keys map to the same concatenated key
self.new_key_to_original_key_map[converted_key].append(key)
continue  # skip to next key

# direct mapping
if self.rename_hook is not None:
new_key = self.rename_hook(key)
self.new_key_to_original_key_map[new_key] = key
else:
new_key = key

self.new_keys.append(new_key)

def keys(self) -> list[str]:
return self.new_keys

def get_tensor(self, new_key: str, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
# load tensor by new_key, applying split or concat hooks as needed
if new_key not in self.new_key_to_original_key_map:
# direct mapping
return self.original_f.get_tensor(new_key, device=device, dtype=dtype)

elif new_key in self.split_key_set:
# split hook: split key is requested multiple times, so we cache the result
original_key = self.new_key_to_original_key_map[new_key]
if original_key not in self.tensor_cache:  # not yet split
original_tensor = self.original_f.get_tensor(original_key, device=device, dtype=dtype)
new_keys, new_tensors = self.split_hook(original_key, original_tensor)  # apply split hook
for k, t in zip(new_keys, new_tensors):
self.tensor_cache[k] = t
return self.tensor_cache.pop(new_key)  # return and remove from cache

elif new_key in self.concat_key_set:
# concat hook: concatenated key is requested only once, so we do not cache the result
tensors = {}
for original_key in self.new_key_to_original_key_map[new_key]:
tensor = self.original_f.get_tensor(original_key, device=device, dtype=dtype)
tensors[original_key] = tensor
_, concatenated_tensors = self.concat_hook(self.new_key_to_original_key_map[new_key][0], tensors)  # apply concat hook
return concatenated_tensors

else:
# direct mapping
original_key = self.new_key_to_original_key_map[new_key]
return self.original_f.get_tensor(original_key, device=device, dtype=dtype)

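The split-hook contract above (return only keys when the tensor is `None`, return keys plus tensor pieces otherwise) can be sketched with a plain list standing in for a tensor. The hook name, key names, and 3-way split below are hypothetical, chosen only to show the key remapping a fused-qkv split would perform:

```python
def qkv_split_hook(key, tensor):
    # Map one fused key to per-projection keys; with tensor=None only report the keys
    if not key.endswith("qkv.weight"):
        return None, None  # hook does not apply to this key
    new_keys = [key.replace("qkv.weight", p + ".weight") for p in ("q", "k", "v")]
    if tensor is None:
        return new_keys, None
    n = len(tensor) // 3
    return new_keys, [tensor[:n], tensor[n : 2 * n], tensor[2 * n :]]

# A toy "checkpoint" entry: 6 rows of a fused qkv weight, as a plain list
fused = [1, 2, 3, 4, 5, 6]
keys, _ = qkv_split_hook("blocks.0.qkv.weight", None)   # key discovery pass
_, parts = qkv_split_hook("blocks.0.qkv.weight", fused) # actual split on load
mapped = dict(zip(keys, parts))
```

This mirrors how `TensorWeightAdapter.__init__` first calls the hook with `tensors=None` to build the key map, and `get_tensor` later calls it again with the real tensor.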
@@ -81,6 +81,8 @@ ARCH_LUMINA_2 = "lumina-2"
ARCH_LUMINA_UNKNOWN = "lumina"
ARCH_HUNYUAN_IMAGE_2_1 = "hunyuan-image-2.1"
ARCH_HUNYUAN_IMAGE_UNKNOWN = "hunyuan-image"
ARCH_ANIMA_PREVIEW = "anima-preview"
ARCH_ANIMA_UNKNOWN = "anima-unknown"

ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
@@ -92,6 +94,7 @@ IMPL_FLUX = "https://github.com/black-forest-labs/flux"
IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma"
IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0"
IMPL_HUNYUAN_IMAGE = "https://github.com/Tencent-Hunyuan/HunyuanImage-2.1"
IMPL_ANIMA = "https://huggingface.co/circlestone-labs/Anima"

PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@@ -220,6 +223,12 @@ def determine_architecture(
arch = ARCH_HUNYUAN_IMAGE_2_1
else:
arch = ARCH_HUNYUAN_IMAGE_UNKNOWN
elif "anima" in model_config:
anima_type = model_config["anima"]
if anima_type == "preview":
arch = ARCH_ANIMA_PREVIEW
else:
arch = ARCH_ANIMA_UNKNOWN
elif v2:
arch = ARCH_SD_V2_768_V if v_parameterization else ARCH_SD_V2_512
else:
@@ -252,6 +261,8 @@ def determine_implementation(
return IMPL_FLUX
elif "lumina" in model_config:
return IMPL_LUMINA
elif "anima" in model_config:
return IMPL_ANIMA
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
return IMPL_STABILITY_AI
else:
@@ -325,7 +336,7 @@ def determine_resolution(
reso = (reso[0], reso[0])
else:
# Determine default resolution based on model type
if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config:
if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config or "anima" in model_config:
reso = (1024, 1024)
elif v2 and v_parameterization:
reso = (768, 768)

302
library/strategy_anima.py
Normal file
@@ -0,0 +1,302 @@
# Anima Strategy Classes

import os
import random
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import torch

from library import anima_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library import qwen_image_autoencoder_kl

from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


class AnimaTokenizeStrategy(TokenizeStrategy):
    """Tokenize strategy for Anima: dual tokenization with Qwen3 + T5.

    Qwen3 tokens are used for the text encoder.
    T5 tokens are used as target input IDs for the LLM Adapter (NOT encoded by T5).

    Can be initialized with either pre-loaded tokenizer objects or paths to load from.
    """

    def __init__(
        self,
        qwen3_tokenizer=None,
        t5_tokenizer=None,
        qwen3_max_length: int = 512,
        t5_max_length: int = 512,
        qwen3_path: Optional[str] = None,
        t5_tokenizer_path: Optional[str] = None,
    ) -> None:
        # Load tokenizers from paths if not provided directly
        if qwen3_tokenizer is None:
            if qwen3_path is None:
                raise ValueError("Either qwen3_tokenizer or qwen3_path must be provided")
            qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(qwen3_path)
        if t5_tokenizer is None:
            t5_tokenizer = anima_utils.load_t5_tokenizer(t5_tokenizer_path)

        self.qwen3_tokenizer = qwen3_tokenizer
        self.qwen3_max_length = qwen3_max_length
        self.t5_tokenizer = t5_tokenizer
        self.t5_max_length = t5_max_length

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        text = [text] if isinstance(text, str) else text

        # Tokenize with Qwen3
        qwen3_encoding = self.qwen3_tokenizer.batch_encode_plus(
            text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.qwen3_max_length
        )
        qwen3_input_ids = qwen3_encoding["input_ids"]
        qwen3_attn_mask = qwen3_encoding["attention_mask"]

        # Tokenize with T5 (for LLM Adapter target tokens)
        t5_encoding = self.t5_tokenizer.batch_encode_plus(
            text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.t5_max_length
        )
        t5_input_ids = t5_encoding["input_ids"]
        t5_attn_mask = t5_encoding["attention_mask"]
        return [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask]

class AnimaTextEncodingStrategy(TextEncodingStrategy):
    """Text encoding strategy for Anima.

    Encodes Qwen3 tokens through the Qwen3 text encoder to get hidden states.
    T5 tokens are passed through unchanged (only used by LLM Adapter).
    """

    def __init__(self) -> None:
        super().__init__()

    def encode_tokens(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """Encode Qwen3 tokens and return embeddings + T5 token IDs.

        Args:
            models: [qwen3_text_encoder]
            tokens: [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask]

        Returns:
            [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
        """
        # Do not handle dropout here; handled dataset-side or in drop_cached_text_encoder_outputs()

        qwen3_text_encoder = models[0]
        qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens

        encoder_device = qwen3_text_encoder.device

        qwen3_input_ids = qwen3_input_ids.to(encoder_device)
        qwen3_attn_mask = qwen3_attn_mask.to(encoder_device)
        outputs = qwen3_text_encoder(input_ids=qwen3_input_ids, attention_mask=qwen3_attn_mask)
        prompt_embeds = outputs.last_hidden_state
        prompt_embeds[~qwen3_attn_mask.bool()] = 0

        return [prompt_embeds, qwen3_attn_mask, t5_input_ids, t5_attn_mask]

    def drop_cached_text_encoder_outputs(
        self,
        prompt_embeds: torch.Tensor,
        attn_mask: torch.Tensor,
        t5_input_ids: torch.Tensor,
        t5_attn_mask: torch.Tensor,
        caption_dropout_rates: Optional[torch.Tensor] = None,
    ) -> List[torch.Tensor]:
        """Apply dropout to cached text encoder outputs.

        Called during training when using cached outputs.
        Replaces dropped items with pre-cached unconditional embeddings (from encoding "")
        to match diffusion-pipe-main behavior.
        """
        if caption_dropout_rates is None or torch.all(caption_dropout_rates == 0.0).item():
            return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]

        # Clone to avoid in-place modification of cached tensors
        prompt_embeds = prompt_embeds.clone()
        if attn_mask is not None:
            attn_mask = attn_mask.clone()
        if t5_input_ids is not None:
            t5_input_ids = t5_input_ids.clone()
        if t5_attn_mask is not None:
            t5_attn_mask = t5_attn_mask.clone()

        for i in range(prompt_embeds.shape[0]):
            if random.random() < caption_dropout_rates[i].item():
                # Use pre-cached unconditional embeddings
                prompt_embeds[i] = 0
                if attn_mask is not None:
                    attn_mask[i] = 0
                if t5_input_ids is not None:
                    t5_input_ids[i, 0] = 1  # Set to </s> token ID
                    t5_input_ids[i, 1:] = 0
                if t5_attn_mask is not None:
                    t5_attn_mask[i, 0] = 1
                    t5_attn_mask[i, 1:] = 0

        return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]

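The per-sample dropout loop in `drop_cached_text_encoder_outputs` reduces to: for each cached item, with probability `caption_dropout_rate`, replace the conditioning with an "empty" one. A torch-free sketch of that control flow (lists stand in for tensors; the zeroed-row replacement mirrors the embedding branch above):

```python
import random

def drop_captions(embeds, rates, rng):
    # embeds: per-sample embedding rows; rates: per-sample dropout probabilities
    out = []
    for row, rate in zip(embeds, rates):
        if rng.random() < rate:
            out.append([0.0] * len(row))  # unconditional: zeroed embedding row
        else:
            out.append(list(row))
    return out

rng = random.Random(0)
# rate 1.0 always drops (random() is in [0, 1)), rate 0.0 never drops
result = drop_captions([[1.0, 2.0], [3.0, 4.0]], [1.0, 0.0], rng)
```

Using the rates `1.0` and `0.0` makes the outcome deterministic regardless of seed, which is also a handy way to unit-test this kind of stochastic path.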
class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    """Caching strategy for Anima text encoder outputs.

    Caches: prompt_embeds (float), attn_mask (int), t5_input_ids (int), t5_attn_mask (int)
    """

    ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_anima_te.npz"

    def __init__(
        self,
        cache_to_disk: bool,
        batch_size: int,
        skip_disk_cache_validity_check: bool,
        is_partial: bool = False,
    ) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        return os.path.splitext(image_abs_path)[0] + self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX

    def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        try:
            npz = np.load(npz_path)
            if "prompt_embeds" not in npz:
                return False
            if "attn_mask" not in npz:
                return False
            if "t5_input_ids" not in npz:
                return False
            if "t5_attn_mask" not in npz:
                return False
            if "caption_dropout_rate" not in npz:
                return False
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            raise e

        return True

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        data = np.load(npz_path)
        prompt_embeds = data["prompt_embeds"]
        attn_mask = data["attn_mask"]
        t5_input_ids = data["t5_input_ids"]
        t5_attn_mask = data["t5_attn_mask"]
        caption_dropout_rate = data["caption_dropout_rate"]
        return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate]

    def cache_batch_outputs(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        text_encoding_strategy: TextEncodingStrategy,
        infos: List,
    ):
        anima_text_encoding_strategy: AnimaTextEncodingStrategy = text_encoding_strategy
        captions = [info.caption for info in infos]

        tokens_and_masks = tokenize_strategy.tokenize(captions)
        with torch.no_grad():
            prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = anima_text_encoding_strategy.encode_tokens(
                tokenize_strategy, models, tokens_and_masks
            )

        # Convert to numpy for caching
        if prompt_embeds.dtype == torch.bfloat16:
            prompt_embeds = prompt_embeds.float()
        prompt_embeds = prompt_embeds.cpu().numpy()
        attn_mask = attn_mask.cpu().numpy()
        t5_input_ids = t5_input_ids.cpu().numpy().astype(np.int32)
        t5_attn_mask = t5_attn_mask.cpu().numpy().astype(np.int32)

        for i, info in enumerate(infos):
            prompt_embeds_i = prompt_embeds[i]
            attn_mask_i = attn_mask[i]
            t5_input_ids_i = t5_input_ids[i]
            t5_attn_mask_i = t5_attn_mask[i]
            caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)

            if self.cache_to_disk:
                np.savez(
                    info.text_encoder_outputs_npz,
                    prompt_embeds=prompt_embeds_i,
                    attn_mask=attn_mask_i,
                    t5_input_ids=t5_input_ids_i,
                    t5_attn_mask=t5_attn_mask_i,
                    caption_dropout_rate=caption_dropout_rate,
                )
            else:
                info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i, caption_dropout_rate)

class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
    """Latent caching strategy for Anima using WanVAE.

    WanVAE produces 16-channel latents with spatial downscale 8x.
    Latent shape for images: (B, 16, 1, H/8, W/8)
    """

    ANIMA_LATENTS_NPZ_SUFFIX = "_anima.npz"

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)

    @property
    def cache_suffix(self) -> str:
        return self.ANIMA_LATENTS_NPZ_SUFFIX

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.ANIMA_LATENTS_NPZ_SUFFIX

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

    def load_latents_from_disk(
        self, npz_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        return self._default_load_latents_from_disk(8, npz_path, bucket_reso)

    def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
        """Cache batch of latents using Qwen Image VAE.

        vae is expected to be the Qwen Image VAE (AutoencoderKLQwenImage).
        The encoding function handles the mean/std normalization.
        """
        vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage = vae
        vae_device = vae.device
        vae_dtype = vae.dtype

        def encode_by_vae(img_tensor):
            """Encode image tensor to latents.

            img_tensor: (B, C, H, W) in [-1, 1] range (already normalized by IMAGE_TRANSFORMS)
            Qwen Image VAE accepts inputs in (B, C, H, W) or (B, C, 1, H, W) shape.
            Returns latents in (B, 16, 1, H/8, W/8) shape on CPU.
            """
            latents = vae.encode_pixels_to_latents(img_tensor)  # Keep 4D for input/output
            return latents.to("cpu")

        self._default_cache_batch_latents(
            encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
        )

        if not train_util.HIGH_VRAM:
            train_util.clean_memory_on_device(vae_device)
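The cache-file naming used by `get_latents_npz_path` above embeds the image size in the filename, zero-padded to four digits. A minimal reproduction of just that string construction:

```python
import os

ANIMA_LATENTS_NPZ_SUFFIX = "_anima.npz"

def latents_npz_path(absolute_path, image_size):
    # image_size is (W, H); each dimension is zero-padded to 4 digits
    return (
        os.path.splitext(absolute_path)[0]
        + f"_{image_size[0]:04d}x{image_size[1]:04d}"
        + ANIMA_LATENTS_NPZ_SUFFIX
    )

p = latents_npz_path("/data/img001.png", (1024, 768))  # -> /data/img001_1024x0768_anima.npz
```

Encoding the size in the name is what lets the multi-resolution cache keep several latents files per source image without collisions.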
@@ -382,6 +382,8 @@ class LatentsCachingStrategy:

_strategy = None  # strategy instance: actual strategy class

_warned_fallback_to_old_npz = False  # to avoid spamming logs about fallback

def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
@@ -459,11 +461,14 @@ class LatentsCachingStrategy:

try:
npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:

# In old SD/SDXL npz files, if the actual latents shape does not match the expected shape, it doesn't raise an error as long as "latents" key exists (backward compatibility)
# In non-SD/SDXL npz files (multi-resolution support), the latents key always has the resolution suffix, and no latents key without suffix exists, so it raises an error if the expected resolution suffix key is not found (this doesn't change the behavior for non-SD/SDXL npz files).
if "latents" + key_reso_suffix not in npz and "latents" not in npz:
return False
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
if flip_aug and ("latents_flipped" + key_reso_suffix not in npz and "latents_flipped" not in npz):
return False
if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
if apply_alpha_mask and ("alpha_mask" + key_reso_suffix not in npz and "alpha_mask" not in npz):
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
@@ -495,8 +500,8 @@ class LatentsCachingStrategy:
apply_alpha_mask: whether to apply alpha mask
random_crop: whether to random crop images
multi_resolution: whether to use multi-resolution latents

Returns:

Returns:
None
"""
from library import train_util  # import here to avoid circular import
@@ -524,7 +529,7 @@ class LatentsCachingStrategy:
original_size = original_sizes[i]
crop_ltrb = crop_ltrbs[i]

latents_size = latents.shape[1:3]  # H, W
latents_size = latents.shape[-2:]  # H, W (supports both 4D and 5D latents)
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else ""  # e.g. "_32x64", HxW

if self.cache_to_disk:
@@ -543,18 +548,18 @@ class LatentsCachingStrategy:
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
for SD/SDXL
For single resolution architectures (currently no architecture is single resolution specific). Kept for reference.

Args:
npz_path (str): Path to the npz file.
bucket_reso (Tuple[int, int]): The resolution of the bucket.

Returns:
Tuple[
Optional[np.ndarray],
Optional[List[int]],
Optional[List[int]],
Optional[np.ndarray],
Optional[np.ndarray],
Optional[List[int]],
Optional[List[int]],
Optional[np.ndarray],
Optional[np.ndarray]
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
"""
@@ -568,25 +573,34 @@ class LatentsCachingStrategy:
latents_stride (Optional[int]): Stride for latents. If None, load all latents.
npz_path (str): Path to the npz file.
bucket_reso (Tuple[int, int]): The resolution of the bucket.

Returns:
Tuple[
Optional[np.ndarray],
Optional[List[int]],
Optional[List[int]],
Optional[np.ndarray],
Optional[np.ndarray],
Optional[List[int]],
Optional[List[int]],
Optional[np.ndarray],
Optional[np.ndarray]
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
"""
if latents_stride is None:
key_reso_suffix = ""
else:
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)  # bucket_reso is (W, H)
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}"  # e.g. "_32x64", HxW
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)  # bucket_reso is (W, H)
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}"  # e.g. "_32x64", HxW

npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
# raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
# Fallback to old npz without resolution suffix
|
||||
if "latents" not in npz:
|
||||
raise ValueError(f"latents not found in {npz_path} (either with or without resolution suffix: {key_reso_suffix})")
|
||||
if not self._warned_fallback_to_old_npz:
|
||||
logger.warning(
|
||||
f"latents{key_reso_suffix} not found in {npz_path}. Falling back to latents without resolution suffix (old npz). This warning will only be shown once. To avoid this warning, please re-cache the latents with the latest version."
|
||||
)
|
||||
self._warned_fallback_to_old_npz = True
|
||||
key_reso_suffix = ""
|
||||
|
||||
latents = npz["latents" + key_reso_suffix]
|
||||
original_size = npz["original_size" + key_reso_suffix].tolist()
|
||||
|
||||
@@ -2,6 +2,7 @@ import glob
import os
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import CLIPTokenizer
from library import train_util
@@ -144,7 +145,7 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
        self.suffix = (
            SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
        )


    @property
    def cache_suffix(self) -> str:
        return self.suffix
@@ -157,7 +158,12 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
        return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
        return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

    def load_latents_from_disk(
        self, npz_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        return self._default_load_latents_from_disk(8, npz_path, bucket_reso)

    # TODO remove circular dependency for ImageInfo
    def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
@@ -165,7 +171,9 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
        vae_device = vae.device
        vae_dtype = vae.dtype

        self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
        self._default_cache_batch_latents(
            encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
        )

        if not train_util.HIGH_VRAM:
            train_util.clean_memory_on_device(vae.device)
@@ -179,12 +179,15 @@ def split_train_val(


class ImageInfo:
    def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
    def __init__(
        self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str, caption_dropout_rate: float = 0.0
    ) -> None:
        self.image_key: str = image_key
        self.num_repeats: int = num_repeats
        self.caption: str = caption
        self.is_reg: bool = is_reg
        self.absolute_path: str = absolute_path
        self.caption_dropout_rate: float = caption_dropout_rate
        self.image_size: Tuple[int, int] = None
        self.resized_size: Tuple[int, int] = None
        self.bucket_reso: Tuple[int, int] = None
@@ -197,7 +200,7 @@ class ImageInfo:
        )
        self.cond_img_path: Optional[str] = None
        self.image: Optional[Image.Image] = None  # optional, original PIL Image
        self.text_encoder_outputs_npz: Optional[str] = None  # set in cache_text_encoder_outputs
        self.text_encoder_outputs_npz: Optional[str] = None  # filename. set in cache_text_encoder_outputs

        # new
        self.text_encoder_outputs: Optional[List[torch.Tensor]] = None
@@ -684,6 +687,7 @@ class BaseDataset(torch.utils.data.Dataset):
        network_multiplier: float,
        debug_dataset: bool,
        resize_interpolation: Optional[str] = None,
        skip_image_resolution: Optional[Tuple[int, int]] = None,
    ) -> None:
        super().__init__()

@@ -724,6 +728,8 @@ class BaseDataset(torch.utils.data.Dataset):
        ), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation'
        self.resize_interpolation = resize_interpolation

        self.skip_image_resolution = skip_image_resolution

        self.image_data: Dict[str, ImageInfo] = {}
        self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}

@@ -1096,11 +1102,12 @@ class BaseDataset(torch.utils.data.Dataset):
    def is_latent_cacheable(self):
        return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])

    def is_text_encoder_output_cacheable(self):
    def is_text_encoder_output_cacheable(self, cache_supports_dropout: bool = False):
        return all(
            [
                not (
                    subset.caption_dropout_rate > 0
                    and not cache_supports_dropout
                    or subset.shuffle_caption
                    or subset.token_warmup_step > 0
                    or subset.caption_tag_dropout_rate > 0
@@ -1912,8 +1919,15 @@ class DreamBoothDataset(BaseDataset):
        validation_split: float,
        validation_seed: Optional[int],
        resize_interpolation: Optional[str],
        skip_image_resolution: Optional[Tuple[int, int]] = None,
    ) -> None:
        super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
        super().__init__(
            resolution,
            network_multiplier,
            debug_dataset,
            resize_interpolation,
            skip_image_resolution,
        )

        assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"

@@ -2031,6 +2045,24 @@ class DreamBoothDataset(BaseDataset):
                size_set_count += 1
            logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")

            if self.skip_image_resolution is not None:
                filtered_img_paths = []
                filtered_sizes = []
                skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1]
                for img_path, size in zip(img_paths, sizes):
                    if size is None:  # no latents cache file, get image size by reading image file (slow)
                        size = self.get_image_size(img_path)
                    if size[0] * size[1] <= skip_image_area:
                        continue
                    filtered_img_paths.append(img_path)
                    filtered_sizes.append(size)
                if len(filtered_img_paths) < len(img_paths):
                    logger.info(
                        f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}"
                    )
                img_paths = filtered_img_paths
                sizes = filtered_sizes

            # We want to create a training and validation split. This should be improved in the future
            # to allow a clearer distinction between training and validation. This can be seen as a
            # short-term solution to limit what is necessary to implement validation datasets
@@ -2056,7 +2088,7 @@ class DreamBoothDataset(BaseDataset):
            logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")

            if use_cached_info_for_subset:
                captions = [meta["caption"] for meta in metas.values()]
                captions = [metas[img_path]["caption"] for img_path in img_paths]
                missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""]
            else:
                # load the prompt for each image file and use it if present
@@ -2137,7 +2169,7 @@ class DreamBoothDataset(BaseDataset):
            num_train_images += num_repeats * len(img_paths)

            for img_path, caption, size in zip(img_paths, captions, sizes):
                info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
                info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path, subset.caption_dropout_rate)
                info.resize_interpolation = (
                    subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
                )
@@ -2197,8 +2229,15 @@ class FineTuningDataset(BaseDataset):
        validation_seed: int,
        validation_split: float,
        resize_interpolation: Optional[str],
        skip_image_resolution: Optional[Tuple[int, int]] = None,
    ) -> None:
        super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
        super().__init__(
            resolution,
            network_multiplier,
            debug_dataset,
            resize_interpolation,
            skip_image_resolution,
        )

        self.batch_size = batch_size
        self.size = min(self.width, self.height)  # the shorter side
@@ -2294,6 +2333,7 @@ class FineTuningDataset(BaseDataset):
        tags_list = []
        size_set_from_metadata = 0
        size_set_from_cache_filename = 0
        num_filtered = 0
        for image_key in image_keys_sorted_by_length_desc:
            img_md = metadata[image_key]
            caption = img_md.get("caption")
@@ -2338,7 +2378,7 @@ class FineTuningDataset(BaseDataset):
            if caption is None:
                caption = ""

            image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
            image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path, subset.caption_dropout_rate)
            image_info.resize_interpolation = (
                subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
            )
@@ -2352,6 +2392,16 @@ class FineTuningDataset(BaseDataset):
                image_info.image_size = (w, h)
                size_set_from_cache_filename += 1

            if self.skip_image_resolution is not None:
                size = image_info.image_size
                if size is None:  # no image size in metadata or latents cache file, get image size by reading image file (slow)
                    size = self.get_image_size(abs_path)
                    image_info.image_size = size
                skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1]
                if size[0] * size[1] <= skip_image_area:
                    num_filtered += 1
                    continue

            self.register_image(image_info, subset)

        if size_set_from_cache_filename > 0:
@@ -2360,6 +2410,8 @@ class FineTuningDataset(BaseDataset):
            )
        if size_set_from_metadata > 0:
            logger.info(f"set image size from metadata: {size_set_from_metadata}/{len(image_keys_sorted_by_length_desc)}")
        if num_filtered > 0:
            logger.info(f"filtered {num_filtered} images by original resolution from {subset.metadata_file}")
        self.num_train_images += len(metadata) * subset.num_repeats

        # TODO do not record tag freq when no tag
@@ -2384,8 +2436,15 @@ class ControlNetDataset(BaseDataset):
        validation_split: float,
        validation_seed: Optional[int],
        resize_interpolation: Optional[str] = None,
        skip_image_resolution: Optional[Tuple[int, int]] = None,
    ) -> None:
        super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
        super().__init__(
            resolution,
            network_multiplier,
            debug_dataset,
            resize_interpolation,
            skip_image_resolution,
        )

        db_subsets = []
        for subset in subsets:
@@ -2437,6 +2496,7 @@ class ControlNetDataset(BaseDataset):
            validation_split,
            validation_seed,
            resize_interpolation,
            skip_image_resolution,
        )

        # store values referenced from config_util etc. (slightly awkward; should be improved)
@@ -2484,9 +2544,8 @@ class ControlNetDataset(BaseDataset):
        assert (
            len(missing_imgs) == 0
        ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
        assert (
            len(extra_imgs) == 0
        ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
        if len(extra_imgs) > 0:
            logger.warning(f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}")

        self.conditioning_image_transforms = IMAGE_TRANSFORMS

@@ -2661,8 +2720,8 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
    def is_latent_cacheable(self) -> bool:
        return all([dataset.is_latent_cacheable() for dataset in self.datasets])

    def is_text_encoder_output_cacheable(self) -> bool:
        return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets])
    def is_text_encoder_output_cacheable(self, cache_supports_dropout: bool = False) -> bool:
        return all([dataset.is_text_encoder_output_cacheable(cache_supports_dropout) for dataset in self.datasets])

    def set_current_strategies(self):
        for dataset in self.datasets:
@@ -3578,6 +3637,7 @@ def get_sai_model_spec_dataclass(
    flux: str = None,
    lumina: str = None,
    hunyuan_image: str = None,
    anima: str = None,
    optional_metadata: dict[str, str] | None = None,
) -> sai_model_spec.ModelSpecMetadata:
    """
@@ -3609,7 +3669,8 @@ def get_sai_model_spec_dataclass(
        model_config["lumina"] = lumina
    if hunyuan_image is not None:
        model_config["hunyuan_image"] = hunyuan_image

    if anima is not None:
        model_config["anima"] = anima
    # Use the dataclass function directly
    return sai_model_spec.build_metadata_dataclass(
        state_dict,
@@ -4596,6 +4657,13 @@ def add_dataset_arguments(
        help="maximum resolution for buckets, must be divisible by bucket_reso_steps "
        " / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります",
    )
    parser.add_argument(
        "--skip_image_resolution",
        type=str,
        default=None,
        help="images not larger than this resolution will be skipped ('size' or 'width,height')"
        " / この解像度以下の画像はスキップされます('サイズ'指定、または'幅,高さ'指定)",
    )
    parser.add_argument(
        "--bucket_reso_steps",
        type=int,
@@ -5409,6 +5477,14 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
        len(args.resolution) == 2
    ), f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}"

    if args.skip_image_resolution is not None:
        args.skip_image_resolution = tuple([int(r) for r in args.skip_image_resolution.split(",")])
        if len(args.skip_image_resolution) == 1:
            args.skip_image_resolution = (args.skip_image_resolution[0], args.skip_image_resolution[0])
        assert (
            len(args.skip_image_resolution) == 2
        ), f"skip_image_resolution must be 'size' or 'width,height' / skip_image_resolutionは'サイズ'または'幅','高さ'で指定してください: {args.skip_image_resolution}"

    if args.face_crop_aug_range is not None:
        args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(",")])
        assert (
@@ -6138,7 +6214,8 @@ def conditional_loss(
    elif loss_type == "huber":
        if huber_c is None:
            raise NotImplementedError("huber_c not implemented correctly")
        huber_c = huber_c.view(-1, 1, 1, 1)
        # Reshape huber_c to broadcast with model_pred (supports 4D and 5D tensors)
        huber_c = huber_c.view(-1, *([1] * (model_pred.ndim - 1)))
        loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
        if reduction == "mean":
            loss = torch.mean(loss)
@@ -6147,7 +6224,8 @@ def conditional_loss(
    elif loss_type == "smooth_l1":
        if huber_c is None:
            raise NotImplementedError("huber_c not implemented correctly")
        huber_c = huber_c.view(-1, 1, 1, 1)
        # Reshape huber_c to broadcast with model_pred (supports 4D and 5D tensors)
        huber_c = huber_c.view(-1, *([1] * (model_pred.ndim - 1)))
        loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
        if reduction == "mean":
            loss = torch.mean(loss)
@@ -6176,10 +6254,14 @@ def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names):
        name = names[lr_index]
        logs["lr/" + name] = float(lrs[lr_index])

        if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
        if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower().startswith("Prodigy".lower()):
            logs["lr/d*lr/" + name] = (
                lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
            )
            if "effective_lr" in lr_scheduler.optimizers[-1].param_groups[lr_index]:
                logs["lr/d*eff_lr/" + name] = (
                    lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["effective_lr"]
                )


# scheduler:
@@ -6210,6 +6292,32 @@ def get_my_scheduler(
    elif sample_sampler == "dpmsolver" or sample_sampler == "dpmsolver++":
        scheduler_cls = DPMSolverMultistepScheduler
        sched_init_args["algorithm_type"] = sample_sampler
    elif sample_sampler == "dpmsolver++_2m":
        scheduler_cls = DPMSolverMultistepScheduler
    elif sample_sampler == "dpmsolver++_2m_lu":
        scheduler_cls = DPMSolverMultistepScheduler
        sched_init_args["use_lu_lambdas"] = True
    elif sample_sampler == "dpmsolver++_2m_k":
        scheduler_cls = DPMSolverMultistepScheduler
        sched_init_args["use_karras_sigmas"] = True
    elif sample_sampler == "dpmsolver++_2m_stable":
        scheduler_cls = DPMSolverMultistepScheduler
        sched_init_args["euler_at_final"] = True
    elif sample_sampler == "dpmsolver++_2m_sde":
        scheduler_cls = DPMSolverMultistepScheduler
        sched_init_args["algorithm_type"] = "sde-dpmsolver++"
    elif sample_sampler == "dpmsolver++_2m_sde_k":
        scheduler_cls = DPMSolverMultistepScheduler
        sched_init_args["algorithm_type"] = "sde-dpmsolver++"
        sched_init_args["use_karras_sigmas"] = True
    elif sample_sampler == "dpmsolver++_2m_sde_lu":
        scheduler_cls = DPMSolverMultistepScheduler
        sched_init_args["algorithm_type"] = "sde-dpmsolver++"
        sched_init_args["use_lu_lambdas"] = True
    elif sample_sampler == "dpmsolver++_2m_sde_stable":
        scheduler_cls = DPMSolverMultistepScheduler
        sched_init_args["algorithm_type"] = "sde-dpmsolver++"
        sched_init_args["euler_at_final"] = True
    elif sample_sampler == "dpmsingle":
        scheduler_cls = DPMSolverSinglestepScheduler
    elif sample_sampler == "heun":

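The `conditional_loss` hunk in `train_util.py` above replaces the fixed `view(-1, 1, 1, 1)` with a rank-agnostic reshape, so the per-sample Huber threshold broadcasts over both 4D image latents and 5D video latents. A minimal sketch of the shape arithmetic in plain Python (the helper name is ours, not from the repository):

```python
def huber_broadcast_shape(batch: int, pred_ndim: int) -> tuple:
    # Equivalent of huber_c.view(-1, *([1] * (model_pred.ndim - 1))):
    # keep the batch axis, then add one singleton axis for every
    # remaining dimension of the prediction tensor.
    return (batch, *([1] * (pred_ndim - 1)))

print(huber_broadcast_shape(4, 4))  # (4, 1, 1, 1) for (B, C, H, W)
print(huber_broadcast_shape(4, 5))  # (4, 1, 1, 1, 1) for (B, C, F, H, W)
```

With this shape, elementwise operations against the prediction tensor broadcast the per-sample threshold over all spatial (and temporal) axes.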
@@ -370,19 +370,25 @@ def train(args):
    grouped_params = []
    param_group = {}
    for group in params_to_optimize:
        named_parameters = list(nextdit.named_parameters())
        named_parameters = [(n, p) for n, p in nextdit.named_parameters() if p.requires_grad]
        assert len(named_parameters) == len(
            group["params"]
        ), "number of parameters does not match"
        ), f"number of trainable parameters ({len(named_parameters)}) does not match optimizer group ({len(group['params'])})"
        for p, np in zip(group["params"], named_parameters):
            # determine target layer and block index for each parameter
            block_type = "other"  # double, single or other
            if np[0].startswith("double_blocks"):
            # Lumina NextDiT architecture:
            # - "layers.{i}.*" : main transformer blocks (e.g. 32 blocks for 2B)
            # - "context_refiner.{i}.*" : context refiner blocks (2 blocks)
            # - "noise_refiner.{i}.*" : noise refiner blocks (2 blocks)
            # - others: t_embedder, cap_embedder, x_embedder, norm_final, final_layer
            block_type = "other"
            if np[0].startswith("layers."):
                block_index = int(np[0].split(".")[1])
                block_type = "double"
            elif np[0].startswith("single_blocks"):
                block_index = int(np[0].split(".")[1])
                block_type = "single"
                block_type = "main"
            elif np[0].startswith("context_refiner.") or np[0].startswith("noise_refiner."):
                # All refiner blocks (context + noise) grouped together
                block_index = -1
                block_type = "refiner"
            else:
                block_index = -1

@@ -759,7 +765,7 @@ def train(args):

                # calculate loss
                huber_c = train_util.get_huber_threshold_if_needed(
                    args, timesteps, noise_scheduler
                    args, 1000 - timesteps, noise_scheduler
                )
                loss = train_util.conditional_loss(
                    model_pred.float(), target.float(), args.loss_type, "none", huber_c

@@ -43,9 +43,9 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
            logger.warning("Enabling cache_text_encoder_outputs due to disk caching")
            args.cache_text_encoder_outputs = True

        train_dataset_group.verify_bucket_reso_steps(32)
        train_dataset_group.verify_bucket_reso_steps(16)
        if val_dataset_group is not None:
            val_dataset_group.verify_bucket_reso_steps(32)
            val_dataset_group.verify_bucket_reso_steps(16)

        self.train_gemma2 = not args.network_train_unet_only

@@ -134,13 +134,16 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):

            # When the TE is not trained, it will not be prepared, so we need to use explicit autocast
            logger.info("move text encoders to gpu")
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)  # always not fp8
            # Lumina uses a single text encoder (Gemma2) at index 0.
            # Check original dtype BEFORE casting to preserve fp8 detection.
            gemma2_original_dtype = text_encoders[0].dtype
            text_encoders[0].to(accelerator.device)

            if text_encoders[0].dtype == torch.float8_e4m3fn:
                # if we load fp8 weights, the model is already fp8, so we use it as is
                self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
            if gemma2_original_dtype == torch.float8_e4m3fn:
                # Model was loaded as fp8 — apply fp8 optimization
                self.prepare_text_encoder_fp8(0, text_encoders[0], gemma2_original_dtype, weight_dtype)
            else:
                # otherwise, we need to convert it to target dtype
                # Otherwise, cast to target dtype
                text_encoders[0].to(weight_dtype)

            with accelerator.autocast():

networks/convert_anima_lora_to_comfy.py (new file, 160 lines)
@@ -0,0 +1,160 @@
import argparse
|
||||
from safetensors.torch import save_file
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
from library import train_util
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COMFYUI_DIT_PREFIX = "diffusion_model."
|
||||
COMFYUI_QWEN3_PREFIX = "text_encoders.qwen3_06b.transformer.model."
|
||||
|
||||
|
||||
def main(args):
|
||||
# load source safetensors
|
||||
logger.info(f"Loading source file {args.src_path}")
|
||||
state_dict = {}
|
||||
with safe_open(args.src_path, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
|
||||
logger.info(f"Converting...")
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
count = 0
|
||||
|
||||
for k in keys:
|
||||
if not args.reverse:
|
||||
is_dit_lora = k.startswith("lora_unet_")
|
||||
module_and_weight_name = "_".join(k.split("_")[2:]) # Remove `lora_unet_`or `lora_te_` prefix
|
||||
|
||||
# Split at the first dot, e.g., "block1_linear.weight" -> "block1_linear", "weight"
|
||||
module_name, weight_name = module_and_weight_name.split(".", 1)
|
||||
|
||||
# Weight name conversion: lora_up/lora_down to lora_A/lora_B
|
||||
if weight_name.startswith("lora_up"):
|
||||
weight_name = weight_name.replace("lora_up", "lora_B")
|
||||
elif weight_name.startswith("lora_down"):
|
||||
weight_name = weight_name.replace("lora_down", "lora_A")
|
||||
else:
|
||||
# Keep other weight names as-is: e.g. alpha
|
||||
pass
|
||||
|
||||
# Module name conversion: convert dots to underscores
|
||||
original_module_name = module_name.replace("_", ".") # Convert to dot notation
|
||||
|
||||
# Convert back illegal dots in module names
|
||||
# DiT
|
||||
original_module_name = original_module_name.replace("llm.adapter", "llm_adapter")
|
||||
original_module_name = original_module_name.replace(".linear.", ".linear_")
|
||||
original_module_name = original_module_name.replace("t.embedding.norm", "t_embedding_norm")
|
||||
original_module_name = original_module_name.replace("x.embedder", "x_embedder")
|
||||
original_module_name = original_module_name.replace("adaln.modulation.cross_attn", "adaln_modulation_cross_attn")
|
||||
original_module_name = original_module_name.replace("adaln.modulation.mlp", "adaln_modulation_mlp")
|
||||
original_module_name = original_module_name.replace("cross.attn", "cross_attn")
|
||||
original_module_name = original_module_name.replace("k.proj", "k_proj")
|
||||
original_module_name = original_module_name.replace("k.norm", "k_norm")
|
||||
original_module_name = original_module_name.replace("q.proj", "q_proj")
|
||||
original_module_name = original_module_name.replace("q.norm", "q_norm")
|
||||
original_module_name = original_module_name.replace("v.proj", "v_proj")
|
||||
original_module_name = original_module_name.replace("o.proj", "o_proj")
|
||||
original_module_name = original_module_name.replace("output.proj", "output_proj")
|
||||
original_module_name = original_module_name.replace("self.attn", "self_attn")
|
||||
original_module_name = original_module_name.replace("final.layer", "final_layer")
|
||||
original_module_name = original_module_name.replace("adaln.modulation", "adaln_modulation")
|
||||
original_module_name = original_module_name.replace("norm.cross.attn", "norm_cross_attn")
|
||||
original_module_name = original_module_name.replace("norm.mlp", "norm_mlp")
|
||||
original_module_name = original_module_name.replace("norm.self.attn", "norm_self_attn")
|
||||
original_module_name = original_module_name.replace("out.proj", "out_proj")
|
||||
|
||||
# Qwen3
|
||||
original_module_name = original_module_name.replace("embed.tokens", "embed_tokens")
|
||||
original_module_name = original_module_name.replace("input.layernorm", "input_layernorm")
|
||||
original_module_name = original_module_name.replace("down.proj", "down_proj")
|
||||
original_module_name = original_module_name.replace("gate.proj", "gate_proj")
|
||||
original_module_name = original_module_name.replace("up.proj", "up_proj")
|
||||
original_module_name = original_module_name.replace("post.attention.layernorm", "post_attention_layernorm")
|
||||
|
||||
# Prefix conversion
|
||||
new_prefix = COMFYUI_DIT_PREFIX if is_dit_lora else COMFYUI_QWEN3_PREFIX
|
||||
|
||||
new_k = f"{new_prefix}{original_module_name}.{weight_name}"
|
||||
else:
|
||||
if k.startswith(COMFYUI_DIT_PREFIX):
|
||||
is_dit_lora = True
|
||||
module_and_weight_name = k[len(COMFYUI_DIT_PREFIX) :]
|
||||
elif k.startswith(COMFYUI_QWEN3_PREFIX):
|
||||
is_dit_lora = False
|
||||
module_and_weight_name = k[len(COMFYUI_QWEN3_PREFIX) :]
|
||||
else:
|
||||
logger.warning(f"Skipping unrecognized key {k}")
|
||||
continue
|
||||
|
||||
# Get weight name
|
||||
if ".lora_" in module_and_weight_name:
|
||||
module_name, weight_name = module_and_weight_name.rsplit(".lora_", 1)
|
||||
weight_name = "lora_" + weight_name
|
||||
else:
|
||||
module_name, weight_name = module_and_weight_name.rsplit(".", 1) # Keep other weight names as-is: e.g. alpha
|
||||
|
||||
        # Weight name conversion: lora_A/lora_B to lora_up/lora_down
        # Note: we only convert lora_A and lora_B weights, other weights are kept as-is
        if weight_name.startswith("lora_B"):
            weight_name = weight_name.replace("lora_B", "lora_up")
        elif weight_name.startswith("lora_A"):
            weight_name = weight_name.replace("lora_A", "lora_down")

        # Module name conversion: convert dots to underscores
        module_name = module_name.replace(".", "_")  # Convert to underscore notation

        # Prefix conversion
        prefix = "lora_unet_" if is_dit_lora else "lora_te_"

        new_k = f"{prefix}{module_name}.{weight_name}"

        state_dict[new_k] = state_dict.pop(k)
        count += 1

    logger.info(f"Converted {count} keys")
    if count == 0:
        logger.warning("No keys were converted. Please check if the source file is in the expected format.")
    elif count > 0 and count < len(keys):
        logger.warning(
            f"Only {count} out of {len(keys)} keys were converted. Please check if there are unexpected keys in the source file."
        )

    # Calculate hash
    if metadata is not None:
        logger.info("Calculating hashes and creating metadata...")
        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
        metadata["sshs_model_hash"] = model_hash
        metadata["sshs_legacy_hash"] = legacy_hash

    # save destination safetensors
    logger.info(f"Saving destination file {args.dst_path}")
    save_file(state_dict, args.dst_path, metadata=metadata)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert LoRA format")
    parser.add_argument(
        "src_path",
        type=str,
        default=None,
        help="source path, sd-scripts format (or ComfyUI compatible format if --reverse is set, only supported for LoRAs converted by this script)",
    )
    parser.add_argument(
        "dst_path",
        type=str,
        default=None,
        help="destination path, ComfyUI compatible format (or sd-scripts format if --reverse is set)",
    )
    parser.add_argument("--reverse", action="store_true", help="reverse conversion direction")
    args = parser.parse_args()
    main(args)
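The per-key mapping above can be sketched as a standalone helper (a sketch only: `convert_key` is a hypothetical name, and the real script derives `module_name`, `weight_name`, and `is_dit_lora` from each key before this step):

```python
def convert_key(module_name: str, weight_name: str, is_dit_lora: bool = True) -> str:
    # lora_B is the "up" matrix, lora_A the "down" matrix
    if weight_name.startswith("lora_B"):
        weight_name = weight_name.replace("lora_B", "lora_up")
    elif weight_name.startswith("lora_A"):
        weight_name = weight_name.replace("lora_A", "lora_down")
    module_name = module_name.replace(".", "_")  # dots -> underscores
    prefix = "lora_unet_" if is_dit_lora else "lora_te_"
    return f"{prefix}{module_name}.{weight_name}"

print(convert_key("double_blocks.0.img_attn.qkv", "lora_A.weight"))
# -> lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight
```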
643
networks/loha.py
Normal file
@@ -0,0 +1,643 @@
# LoHa (Low-rank Hadamard Product) network module
# Reference: https://arxiv.org/abs/2108.06098
#
# Based on the LyCORIS project by KohakuBlueleaf
# https://github.com/KohakuBlueleaf/LyCORIS

import ast
import os
import logging
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .network_base import ArchConfig, AdditionalNetwork, detect_arch_config, _parse_kv_pairs
from library.utils import setup_logging

setup_logging()
logger = logging.getLogger(__name__)


class HadaWeight(torch.autograd.Function):
    """Efficient Hadamard product forward/backward for LoHa.

    Computes ((w1a @ w1b) * (w2a @ w2b)) * scale with a custom backward
    that recomputes intermediates instead of storing them.
    """

    @staticmethod
    def forward(ctx, w1a, w1b, w2a, w2b, scale=None):
        if scale is None:
            scale = torch.tensor(1, device=w1a.device, dtype=w1a.dtype)
        ctx.save_for_backward(w1a, w1b, w2a, w2b, scale)
        diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale
        return diff_weight

    @staticmethod
    def backward(ctx, grad_out):
        (w1a, w1b, w2a, w2b, scale) = ctx.saved_tensors
        grad_out = grad_out * scale
        temp = grad_out * (w2a @ w2b)
        grad_w1a = temp @ w1b.T
        grad_w1b = w1a.T @ temp

        temp = grad_out * (w1a @ w1b)
        grad_w2a = temp @ w2b.T
        grad_w2b = w2a.T @ temp

        del temp
        return grad_w1a, grad_w1b, grad_w2a, grad_w2b, None
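A quick numeric check of the closed-form gradients that `HadaWeight.backward` implements, written against the naive expression so it runs standalone (shapes here are arbitrary):

```python
import torch

torch.manual_seed(0)
w1a = torch.randn(6, 3, requires_grad=True)
w1b = torch.randn(3, 5, requires_grad=True)
w2a = torch.randn(6, 3, requires_grad=True)
w2b = torch.randn(3, 5, requires_grad=True)
scale = 0.5

# naive LoHa delta: dW = ((w1a @ w1b) * (w2a @ w2b)) * scale
out = ((w1a @ w1b) * (w2a @ w2b)) * scale
grad_out = torch.randn_like(out)
out.backward(grad_out)

# the closed-form gradients used in HadaWeight.backward
g = grad_out * scale
t = g * (w2a @ w2b)
assert torch.allclose(w1a.grad, t @ w1b.T, atol=1e-5)
assert torch.allclose(w1b.grad, w1a.T @ t, atol=1e-5)
t = g * (w1a @ w1b)
assert torch.allclose(w2a.grad, t @ w2b.T, atol=1e-5)
assert torch.allclose(w2b.grad, w2a.T @ t, atol=1e-5)
```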


class HadaWeightTucker(torch.autograd.Function):
    """Tucker-decomposed Hadamard product forward/backward for LoHa Conv2d 3x3+.

    Computes (rebuild(t1, w1b, w1a) * rebuild(t2, w2b, w2a)) * scale
    where rebuild = einsum("i j ..., j r, i p -> p r ...", t, wb, wa).
    Compatible with the LyCORIS parameter naming convention.
    """

    @staticmethod
    def forward(ctx, t1, w1b, w1a, t2, w2b, w2a, scale=None):
        if scale is None:
            scale = torch.tensor(1, device=t1.device, dtype=t1.dtype)
        ctx.save_for_backward(t1, w1b, w1a, t2, w2b, w2a, scale)

        rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
        rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)

        return rebuild1 * rebuild2 * scale

    @staticmethod
    def backward(ctx, grad_out):
        (t1, w1b, w1a, t2, w2b, w2a, scale) = ctx.saved_tensors
        grad_out = grad_out * scale

        # Gradients for w1a, w1b, t1 (using rebuild2)
        temp = torch.einsum("i j ..., j r -> i r ...", t2, w2b)
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2a)

        grad_w = rebuild * grad_out
        del rebuild

        grad_w1a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1a.T)
        del grad_w, temp

        grad_w1b = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
        grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1b.T)
        del grad_temp

        # Gradients for w2a, w2b, t2 (using rebuild1)
        temp = torch.einsum("i j ..., j r -> i r ...", t1, w1b)
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1a)

        grad_w = rebuild * grad_out
        del rebuild

        grad_w2a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2a.T)
        del grad_w, temp

        grad_w2b = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
        grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2b.T)
        del grad_temp

        return grad_t1, grad_w1b, grad_w1a, grad_t2, grad_w2b, grad_w2a, None


class LoHaModule(torch.nn.Module):
    """LoHa module for training. Replaces the forward method of the original Linear/Conv2d."""

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
        use_tucker=False,
        **kwargs,
    ):
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        is_conv2d = org_module.__class__.__name__ == "Conv2d"
        if is_conv2d:
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
            kernel_size = org_module.kernel_size
            self.is_conv = True
            self.stride = org_module.stride
            self.padding = org_module.padding
            self.dilation = org_module.dilation
            self.groups = org_module.groups
            self.kernel_size = kernel_size

            self.tucker = use_tucker and any(k != 1 for k in kernel_size)

            if kernel_size == (1, 1):
                self.conv_mode = "1x1"
            elif self.tucker:
                self.conv_mode = "tucker"
            else:
                self.conv_mode = "flat"
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.is_conv = False
            self.tucker = False
            self.conv_mode = None
            self.kernel_size = None

        self.in_dim = in_dim
        self.out_dim = out_dim

        # Create parameters based on mode
        if self.conv_mode == "tucker":
            # Tucker decomposition for Conv2d 3x3+
            # Shapes follow LyCORIS convention: w_a = (rank, out_dim), w_b = (rank, in_dim)
            self.hada_t1 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size))
            self.hada_w1_a = nn.Parameter(torch.empty(lora_dim, out_dim))
            self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim))
            self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size))
            self.hada_w2_a = nn.Parameter(torch.empty(lora_dim, out_dim))
            self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim))

            # LyCORIS init: w1_a = 0 (ensures ΔW = 0), t1/t2 normal(0.1)
            torch.nn.init.normal_(self.hada_t1, std=0.1)
            torch.nn.init.normal_(self.hada_t2, std=0.1)
            torch.nn.init.normal_(self.hada_w1_b, std=1.0)
            torch.nn.init.constant_(self.hada_w1_a, 0)
            torch.nn.init.normal_(self.hada_w2_b, std=1.0)
            torch.nn.init.normal_(self.hada_w2_a, std=0.1)
        elif self.conv_mode == "flat":
            # Non-Tucker Conv2d 3x3+: flatten kernel into in_dim
            k_prod = 1
            for k in kernel_size:
                k_prod *= k
            flat_in = in_dim * k_prod

            self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim))
            self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, flat_in))
            self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim))
            self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, flat_in))

            torch.nn.init.normal_(self.hada_w1_a, std=0.1)
            torch.nn.init.normal_(self.hada_w1_b, std=1.0)
            torch.nn.init.constant_(self.hada_w2_a, 0)
            torch.nn.init.normal_(self.hada_w2_b, std=1.0)
        else:
            # Linear or Conv2d 1x1
            self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim))
            self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim))
            self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim))
            self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim))

            torch.nn.init.normal_(self.hada_w1_a, std=0.1)
            torch.nn.init.normal_(self.hada_w1_b, std=1.0)
            torch.nn.init.constant_(self.hada_w2_a, 0)
            torch.nn.init.normal_(self.hada_w2_b, std=1.0)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))

        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def get_diff_weight(self):
        """Return the materialized weight delta.

        Returns:
            - Linear: 2D tensor (out_dim, in_dim)
            - Conv2d 1x1: 2D tensor (out_dim, in_dim) — caller should unsqueeze for F.conv2d
            - Conv2d 3x3+ Tucker: 4D tensor (out_dim, in_dim, k1, k2)
            - Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2)
        """
        if self.tucker:
            scale = torch.tensor(self.scale, dtype=self.hada_t1.dtype, device=self.hada_t1.device)
            return HadaWeightTucker.apply(
                self.hada_t1, self.hada_w1_b, self.hada_w1_a,
                self.hada_t2, self.hada_w2_b, self.hada_w2_a, scale
            )
        elif self.conv_mode == "flat":
            scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device)
            diff = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
            return diff.reshape(self.out_dim, self.in_dim, *self.kernel_size)
        else:
            scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device)
            return HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)

    def forward(self, x):
        org_forwarded = self.org_forward(x)

        # module dropout
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return org_forwarded

        diff_weight = self.get_diff_weight()

        # rank dropout (applied on the output dimension)
        if self.rank_dropout is not None and self.training:
            drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype)
            drop = drop.view(-1, *([1] * (diff_weight.dim() - 1)))
            diff_weight = diff_weight * drop
            scale = 1.0 / (1.0 - self.rank_dropout)
        else:
            scale = 1.0

        if self.is_conv:
            if self.conv_mode == "1x1":
                diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
                return org_forwarded + F.conv2d(
                    x, diff_weight, stride=self.stride, padding=self.padding,
                    dilation=self.dilation, groups=self.groups
                ) * self.multiplier * scale
            else:
                # Conv2d 3x3+: diff_weight is already 4D from get_diff_weight
                return org_forwarded + F.conv2d(
                    x, diff_weight, stride=self.stride, padding=self.padding,
                    dilation=self.dilation, groups=self.groups
                ) * self.multiplier * scale
        else:
            return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype


class LoHaInfModule(LoHaModule):
    """LoHa module for inference. Supports merge_to and get_weight."""

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        **kwargs,
    ):
        # no dropout for inference; pass use_tucker from kwargs
        use_tucker = kwargs.pop("use_tucker", False)
        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, use_tucker=use_tucker)

        self.org_module_ref = [org_module]
        self.enabled = True
        self.network: AdditionalNetwork = None

    def set_network(self, network):
        self.network = network

    def merge_to(self, sd, dtype, device):
        # extract weight from org_module
        org_sd = self.org_module.state_dict()
        weight = org_sd["weight"]
        org_dtype = weight.dtype
        org_device = weight.device
        weight = weight.to(torch.float)

        if dtype is None:
            dtype = org_dtype
        if device is None:
            device = org_device

        # get LoHa weights
        w1a = sd["hada_w1_a"].to(torch.float).to(device)
        w1b = sd["hada_w1_b"].to(torch.float).to(device)
        w2a = sd["hada_w2_a"].to(torch.float).to(device)
        w2b = sd["hada_w2_b"].to(torch.float).to(device)

        if self.tucker:
            # Tucker mode
            t1 = sd["hada_t1"].to(torch.float).to(device)
            t2 = sd["hada_t2"].to(torch.float).to(device)
            rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
            rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
            diff_weight = rebuild1 * rebuild2 * self.scale
        else:
            diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale
        # reshape diff_weight to match the original weight shape if needed
        if diff_weight.shape != weight.shape:
            diff_weight = diff_weight.reshape(weight.shape)

        weight = weight.to(device) + self.multiplier * diff_weight

        org_sd["weight"] = weight.to(dtype)
        self.org_module.load_state_dict(org_sd)

    def get_weight(self, multiplier=None):
        if multiplier is None:
            multiplier = self.multiplier

        if self.tucker:
            t1 = self.hada_t1.to(torch.float)
            w1a = self.hada_w1_a.to(torch.float)
            w1b = self.hada_w1_b.to(torch.float)
            t2 = self.hada_t2.to(torch.float)
            w2a = self.hada_w2_a.to(torch.float)
            w2b = self.hada_w2_b.to(torch.float)
            rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
            rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
            weight = rebuild1 * rebuild2 * self.scale * multiplier
        else:
            w1a = self.hada_w1_a.to(torch.float)
            w1b = self.hada_w1_b.to(torch.float)
            w2a = self.hada_w2_a.to(torch.float)
            w2b = self.hada_w2_b.to(torch.float)
            weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale * multiplier

        if self.is_conv:
            if self.conv_mode == "1x1":
                weight = weight.unsqueeze(2).unsqueeze(3)
            elif self.conv_mode == "flat":
                weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size)

        return weight

    def default_forward(self, x):
        diff_weight = self.get_diff_weight()
        if self.is_conv:
            if self.conv_mode == "1x1":
                diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
            return self.org_forward(x) + F.conv2d(
                x, diff_weight, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups
            ) * self.multiplier
        else:
            return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier

    def forward(self, x):
        if not self.enabled:
            return self.org_forward(x)
        return self.default_forward(x)
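A standalone sketch of the identity `merge_to` relies on: adding `multiplier * ΔW` into the base weight offline is equivalent to adding the LoHa branch output at runtime (shapes and values here are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
W = torch.randn(8, 16)                       # base Linear weight (out_features, in_features)
w1a, w1b = torch.randn(8, 4), torch.randn(4, 16)
w2a, w2b = torch.randn(8, 4), torch.randn(4, 16)
scale, multiplier = 4 / 4, 0.8               # alpha / lora_dim, network multiplier

delta = ((w1a @ w1b) * (w2a @ w2b)) * scale  # ΔW, as in merge_to
x = torch.randn(2, 16)

merged = F.linear(x, W + multiplier * delta)                 # weights merged offline
branched = F.linear(x, W) + multiplier * F.linear(x, delta)  # runtime LoHa branch
assert torch.allclose(merged, branched, atol=1e-5)
```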


def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae,
    text_encoder,
    unet,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    """Create a LoHa network. Called by train_network.py via network_module.create_network()."""
    if network_dim is None:
        network_dim = 4
    if network_alpha is None:
        network_alpha = 1.0

    # handle text_encoder as a list
    text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]

    # detect architecture
    arch_config = detect_arch_config(unet, text_encoders)

    # train LLM adapter
    train_llm_adapter = kwargs.get("train_llm_adapter", "false")
    if train_llm_adapter is not None:
        train_llm_adapter = str(train_llm_adapter).lower() == "true"

    # exclude patterns
    exclude_patterns = kwargs.get("exclude_patterns", None)
    if exclude_patterns is None:
        exclude_patterns = []
    else:
        exclude_patterns = ast.literal_eval(exclude_patterns)
        if not isinstance(exclude_patterns, list):
            exclude_patterns = [exclude_patterns]

    # add default exclude patterns from arch config
    exclude_patterns.extend(arch_config.default_excludes)

    # include patterns
    include_patterns = kwargs.get("include_patterns", None)
    if include_patterns is not None:
        include_patterns = ast.literal_eval(include_patterns)
        if not isinstance(include_patterns, list):
            include_patterns = [include_patterns]

    # rank/module dropout
    rank_dropout = kwargs.get("rank_dropout", None)
    if rank_dropout is not None:
        rank_dropout = float(rank_dropout)
    module_dropout = kwargs.get("module_dropout", None)
    if module_dropout is not None:
        module_dropout = float(module_dropout)

    # conv dim/alpha for Conv2d 3x3
    conv_lora_dim = kwargs.get("conv_dim", None)
    conv_alpha = kwargs.get("conv_alpha", None)
    if conv_lora_dim is not None:
        conv_lora_dim = int(conv_lora_dim)
        if conv_alpha is None:
            conv_alpha = 1.0
        else:
            conv_alpha = float(conv_alpha)

    # Tucker decomposition for Conv2d 3x3
    use_tucker = kwargs.get("use_tucker", "false")
    if use_tucker is not None:
        use_tucker = str(use_tucker).lower() == "true"

    # verbose
    verbose = kwargs.get("verbose", "false")
    if verbose is not None:
        verbose = str(verbose).lower() == "true"

    # regex-specific learning rates / dimensions
    network_reg_lrs = kwargs.get("network_reg_lrs", None)
    reg_lrs = _parse_kv_pairs(network_reg_lrs, is_int=False) if network_reg_lrs is not None else None

    network_reg_dims = kwargs.get("network_reg_dims", None)
    reg_dims = _parse_kv_pairs(network_reg_dims, is_int=True) if network_reg_dims is not None else None

    network = AdditionalNetwork(
        text_encoders,
        unet,
        arch_config=arch_config,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
        dropout=neuron_dropout,
        rank_dropout=rank_dropout,
        module_dropout=module_dropout,
        module_class=LoHaModule,
        module_kwargs={"use_tucker": use_tucker},
        conv_lora_dim=conv_lora_dim,
        conv_alpha=conv_alpha,
        train_llm_adapter=train_llm_adapter,
        exclude_patterns=exclude_patterns,
        include_patterns=include_patterns,
        reg_dims=reg_dims,
        reg_lrs=reg_lrs,
        verbose=verbose,
    )

    # LoRA+ support
    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

    return network


def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
    """Create a LoHa network from saved weights. Called by train_network.py."""
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

    # detect dim/alpha from weights
    modules_dim = {}
    modules_alpha = {}
    train_llm_adapter = False
    for key, value in weights_sd.items():
        if "." not in key:
            continue

        lora_name = key.split(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = value
        elif "hada_w1_b" in key:
            dim = value.shape[0]
            modules_dim[lora_name] = dim

        if "llm_adapter" in lora_name:
            train_llm_adapter = True

    # detect Tucker mode from weights
    use_tucker = any("hada_t1" in key for key in weights_sd.keys())

    # handle text_encoder as a list
    text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]

    # detect architecture
    arch_config = detect_arch_config(unet, text_encoders)

    module_class = LoHaInfModule if for_inference else LoHaModule
    module_kwargs = {"use_tucker": use_tucker}

    network = AdditionalNetwork(
        text_encoders,
        unet,
        arch_config=arch_config,
        multiplier=multiplier,
        modules_dim=modules_dim,
        modules_alpha=modules_alpha,
        module_class=module_class,
        module_kwargs=module_kwargs,
        train_llm_adapter=train_llm_adapter,
    )
    return network, weights_sd


def merge_weights_to_tensor(
    model_weight: torch.Tensor,
    lora_name: str,
    lora_sd: Dict[str, torch.Tensor],
    lora_weight_keys: set,
    multiplier: float,
    calc_device: torch.device,
) -> torch.Tensor:
    """Merge LoHa weights directly into a model weight tensor.

    Supports standard LoHa, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3.
    No Module/Network creation needed. Consumed keys are removed from lora_weight_keys.
    Returns model_weight unchanged if no matching LoHa keys are found.
    """
    w1a_key = lora_name + ".hada_w1_a"
    w1b_key = lora_name + ".hada_w1_b"
    w2a_key = lora_name + ".hada_w2_a"
    w2b_key = lora_name + ".hada_w2_b"
    t1_key = lora_name + ".hada_t1"
    t2_key = lora_name + ".hada_t2"
    alpha_key = lora_name + ".alpha"

    if w1a_key not in lora_weight_keys:
        return model_weight

    w1a = lora_sd[w1a_key].to(calc_device)
    w1b = lora_sd[w1b_key].to(calc_device)
    w2a = lora_sd[w2a_key].to(calc_device)
    w2b = lora_sd[w2b_key].to(calc_device)

    has_tucker = t1_key in lora_weight_keys

    dim = w1b.shape[0]
    alpha = lora_sd.get(alpha_key, torch.tensor(dim))
    if isinstance(alpha, torch.Tensor):
        alpha = alpha.item()
    scale = alpha / dim

    original_dtype = model_weight.dtype
    if original_dtype.itemsize == 1:  # fp8
        model_weight = model_weight.to(torch.float16)
        w1a, w1b = w1a.to(torch.float16), w1b.to(torch.float16)
        w2a, w2b = w2a.to(torch.float16), w2b.to(torch.float16)

    if has_tucker:
        # Tucker decomposition: rebuild via einsum
        t1 = lora_sd[t1_key].to(calc_device)
        t2 = lora_sd[t2_key].to(calc_device)
        if original_dtype.itemsize == 1:
            t1, t2 = t1.to(torch.float16), t2.to(torch.float16)
        rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
        rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
        diff_weight = rebuild1 * rebuild2 * scale
    else:
        # Standard LoHa: ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale
        diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale

    # Reshape diff_weight to match model_weight shape if needed
    # (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.)
    if diff_weight.shape != model_weight.shape:
        diff_weight = diff_weight.reshape(model_weight.shape)

    model_weight = model_weight + multiplier * diff_weight

    if original_dtype.itemsize == 1:
        model_weight = model_weight.to(original_dtype)

    # remove consumed keys
    consumed = [w1a_key, w1b_key, w2a_key, w2b_key, alpha_key]
    if has_tucker:
        consumed.extend([t1_key, t2_key])
    for key in consumed:
        lora_weight_keys.discard(key)

    return model_weight
683
networks/lokr.py
Normal file
@@ -0,0 +1,683 @@
# LoKr (Low-rank Kronecker Product) network module
# Reference: https://arxiv.org/abs/2309.14859
#
# Based on the LyCORIS project by KohakuBlueleaf
# https://github.com/KohakuBlueleaf/LyCORIS

import ast
import math
import os
import logging
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .network_base import ArchConfig, AdditionalNetwork, detect_arch_config, _parse_kv_pairs
from library.utils import setup_logging

setup_logging()
logger = logging.getLogger(__name__)


def factorization(dimension: int, factor: int = -1) -> tuple:
    """Return a tuple of two values whose product equals dimension,
    optimized for balanced factors.

    In LoKr, the first value is used for the weight scale (smaller),
    and the second value for the weight itself (larger).

    Examples:
        factor=-1: 128 -> (8, 16), 512 -> (16, 32), 1024 -> (32, 32)
        factor=4: 128 -> (4, 32), 512 -> (4, 128)
    """
    if factor > 0 and (dimension % factor) == 0:
        m = factor
        n = dimension // factor
        if m > n:
            n, m = m, n
        return m, n
    if factor < 0:
        factor = dimension
    m, n = 1, dimension
    length = m + n
    while m < n:
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m > factor:
            break
        else:
            m, n = new_m, new_n
    if m > n:
        n, m = m, n
    return m, n


def make_kron(w1, w2, scale):
    """Compute the Kronecker product of w1 and w2, scaled by scale."""
    if w1.dim() != w2.dim():
        for _ in range(w2.dim() - w1.dim()):
            w1 = w1.unsqueeze(-1)
    w2 = w2.contiguous()
    rebuild = torch.kron(w1, w2)
    if scale != 1:
        rebuild = rebuild * scale
    return rebuild
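A standalone illustration of why `factorization` pairs with `torch.kron`: for factors `in_dim = in_m * in_n` and `out_dim = out_l * out_k`, the Kronecker product of an `(out_l, in_m)` matrix with an `(out_k, in_n)` matrix has exactly the full weight shape `(out_dim, in_dim)` (the dimensions below match the docstring examples 128 -> (8, 16) and 512 -> (16, 32)):

```python
import torch

w1 = torch.randn(8, 16)    # (out_l, in_m): the small "scale" factor
w2 = torch.randn(16, 32)   # (out_k, in_n): the larger factor, usually itself low-rank
full = torch.kron(w1, w2)  # the unscaled weight delta
assert full.shape == (8 * 16, 16 * 32)  # == (out_dim, in_dim) == (128, 512)
```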


def rebuild_tucker(t, wa, wb):
    """Rebuild a weight from its Tucker decomposition:
    einsum("i j ..., i p, j r -> p r ...", t, wa, wb).

    Compatible with the LyCORIS convention.
    """
    return torch.einsum("i j ..., i p, j r -> p r ...", t, wa, wb)


class LoKrModule(torch.nn.Module):
    """LoKr module for training. Replaces the forward method of the original Linear/Conv2d."""

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
        factor=-1,
        use_tucker=False,
        **kwargs,
    ):
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        is_conv2d = org_module.__class__.__name__ == "Conv2d"
        if is_conv2d:
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
            kernel_size = org_module.kernel_size
            self.is_conv = True
            self.stride = org_module.stride
            self.padding = org_module.padding
            self.dilation = org_module.dilation
            self.groups = org_module.groups
            self.kernel_size = kernel_size

            self.tucker = use_tucker and any(k != 1 for k in kernel_size)

            if kernel_size == (1, 1):
                self.conv_mode = "1x1"
            elif self.tucker:
                self.conv_mode = "tucker"
            else:
                self.conv_mode = "flat"
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.is_conv = False
            self.tucker = False
            self.conv_mode = None
            self.kernel_size = None

        self.in_dim = in_dim
        self.out_dim = out_dim

        factor = int(factor)
        self.use_w2 = False

        # Factorize dimensions
        in_m, in_n = factorization(in_dim, factor)
        out_l, out_k = factorization(out_dim, factor)

        # w1 is always a full matrix (the "scale" factor, small)
        self.lokr_w1 = nn.Parameter(torch.empty(out_l, in_m))

        # w2: depends on mode
        if self.conv_mode in ("tucker", "flat"):
            # Conv2d 3x3+ modes
            k_size = kernel_size

            if lora_dim >= max(out_k, in_n) / 2:
                # Full matrix mode (includes kernel dimensions)
                self.use_w2 = True
                self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n, *k_size))
                logger.warning(
                    f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} "
                    f"and factor={factor}, using full matrix mode for Conv2d."
                )
            elif self.tucker:
                # Tucker mode: separate kernel into the t2 tensor
                self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *k_size))
                self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, out_k))
                self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n))
            else:
                # Non-Tucker: flatten kernel into w2_b
                k_prod = 1
                for k in k_size:
                    k_prod *= k
                self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim))
                self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n * k_prod))
        else:
            # Linear or Conv2d 1x1
            if lora_dim < max(out_k, in_n) / 2:
                self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim))
                self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n))
            else:
                self.use_w2 = True
                self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n))
                logger.warning(
                    f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} "
                    f"and factor={factor}, using full matrix mode."
                )

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        # if both w1 and w2 are full matrices, use scale = 1
        if self.use_w2:
            alpha = lora_dim
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))

        # Initialization
        torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
        if self.use_w2:
            torch.nn.init.constant_(self.lokr_w2, 0)
        else:
            if self.tucker:
                torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
            torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
            torch.nn.init.constant_(self.lokr_w2_b, 0)
        # Ensures ΔW = kron(w1, 0) = 0 at init

        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def get_diff_weight(self):
        """Return the materialized weight delta.

        Returns:
            - Linear: 2D tensor (out_dim, in_dim)
            - Conv2d 1x1: 2D tensor (out_dim, in_dim) — caller should unsqueeze for F.conv2d
            - Conv2d 3x3+ Tucker/full: 4D tensor (out_dim, in_dim, k1, k2)
            - Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2) — reshaped from 2D
        """
        w1 = self.lokr_w1

        if self.use_w2:
            w2 = self.lokr_w2
        elif self.tucker:
            w2 = rebuild_tucker(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b)
        else:
            w2 = self.lokr_w2_a @ self.lokr_w2_b

        result = make_kron(w1, w2, self.scale)

        # For non-Tucker Conv2d 3x3+, the result is 2D; reshape to 4D
        if self.conv_mode == "flat" and result.dim() == 2:
            result = result.reshape(self.out_dim, self.in_dim, *self.kernel_size)

        return result

    def forward(self, x):
        org_forwarded = self.org_forward(x)

        # module dropout
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return org_forwarded

        diff_weight = self.get_diff_weight()

        # rank dropout
        if self.rank_dropout is not None and self.training:
            drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype)
            drop = drop.view(-1, *([1] * (diff_weight.dim() - 1)))
            diff_weight = diff_weight * drop
            scale = 1.0 / (1.0 - self.rank_dropout)
        else:
            scale = 1.0

        if self.is_conv:
            if self.conv_mode == "1x1":
                diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
                return org_forwarded + F.conv2d(
                    x, diff_weight, stride=self.stride, padding=self.padding,
                    dilation=self.dilation, groups=self.groups
                ) * self.multiplier * scale
            else:
                # Conv2d 3x3+: diff_weight is already 4D from get_diff_weight
                return org_forwarded + F.conv2d(
                    x, diff_weight, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups
|
||||
) * self.multiplier * scale
|
||||
else:
|
||||
return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
|
||||
class LoKrInfModule(LoKrModule):
    """LoKr module for inference. Supports merge_to and get_weight."""

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        **kwargs,
    ):
        # no dropout for inference; pass factor and use_tucker from kwargs
        factor = kwargs.pop("factor", -1)
        use_tucker = kwargs.pop("use_tucker", False)
        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, factor=factor, use_tucker=use_tucker)

        self.org_module_ref = [org_module]
        self.enabled = True
        self.network: AdditionalNetwork = None

    def set_network(self, network):
        self.network = network

    def merge_to(self, sd, dtype, device):
        # extract weight from org_module
        org_sd = self.org_module.state_dict()
        weight = org_sd["weight"]
        org_dtype = weight.dtype
        org_device = weight.device
        weight = weight.to(torch.float)

        if dtype is None:
            dtype = org_dtype
        if device is None:
            device = org_device

        # get LoKr weights
        w1 = sd["lokr_w1"].to(torch.float).to(device)

        if "lokr_w2" in sd:
            w2 = sd["lokr_w2"].to(torch.float).to(device)
        elif "lokr_t2" in sd:
            # Tucker mode
            t2 = sd["lokr_t2"].to(torch.float).to(device)
            w2a = sd["lokr_w2_a"].to(torch.float).to(device)
            w2b = sd["lokr_w2_b"].to(torch.float).to(device)
            w2 = rebuild_tucker(t2, w2a, w2b)
        else:
            w2a = sd["lokr_w2_a"].to(torch.float).to(device)
            w2b = sd["lokr_w2_b"].to(torch.float).to(device)
            w2 = w2a @ w2b

        # compute ΔW via Kronecker product
        diff_weight = make_kron(w1, w2, self.scale)

        # reshape diff_weight to match original weight shape if needed
        if diff_weight.shape != weight.shape:
            diff_weight = diff_weight.reshape(weight.shape)

        weight = weight.to(device) + self.multiplier * diff_weight

        org_sd["weight"] = weight.to(dtype)
        self.org_module.load_state_dict(org_sd)

    def get_weight(self, multiplier=None):
        if multiplier is None:
            multiplier = self.multiplier

        w1 = self.lokr_w1.to(torch.float)

        if self.use_w2:
            w2 = self.lokr_w2.to(torch.float)
        elif self.tucker:
            w2 = rebuild_tucker(
                self.lokr_t2.to(torch.float),
                self.lokr_w2_a.to(torch.float),
                self.lokr_w2_b.to(torch.float),
            )
        else:
            w2 = (self.lokr_w2_a @ self.lokr_w2_b).to(torch.float)

        weight = make_kron(w1, w2, self.scale) * multiplier

        # reshape to match original weight shape if needed
        if self.is_conv:
            if self.conv_mode == "1x1":
                weight = weight.unsqueeze(2).unsqueeze(3)
            elif self.conv_mode == "flat" and weight.dim() == 2:
                weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size)
            # Tucker and full matrix modes: already 4D from kron

        return weight

    def default_forward(self, x):
        diff_weight = self.get_diff_weight()
        if self.is_conv:
            if self.conv_mode == "1x1":
                diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
            return self.org_forward(x) + F.conv2d(
                x, diff_weight, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups
            ) * self.multiplier
        else:
            return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier

    def forward(self, x):
        if not self.enabled:
            return self.org_forward(x)
        return self.default_forward(x)

def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae,
    text_encoder,
    unet,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    """Create a LoKr network. Called by train_network.py via network_module.create_network()."""
    if network_dim is None:
        network_dim = 4
    if network_alpha is None:
        network_alpha = 1.0

    # handle text_encoder as list
    text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]

    # detect architecture
    arch_config = detect_arch_config(unet, text_encoders)

    # train LLM adapter
    train_llm_adapter = kwargs.get("train_llm_adapter", "false")
    if train_llm_adapter is not None:
        train_llm_adapter = True if str(train_llm_adapter).lower() == "true" else False

    # exclude patterns
    exclude_patterns = kwargs.get("exclude_patterns", None)
    if exclude_patterns is None:
        exclude_patterns = []
    else:
        exclude_patterns = ast.literal_eval(exclude_patterns)
        if not isinstance(exclude_patterns, list):
            exclude_patterns = [exclude_patterns]

    # add default exclude patterns from arch config
    exclude_patterns.extend(arch_config.default_excludes)

    # include patterns
    include_patterns = kwargs.get("include_patterns", None)
    if include_patterns is not None:
        include_patterns = ast.literal_eval(include_patterns)
        if not isinstance(include_patterns, list):
            include_patterns = [include_patterns]

    # rank/module dropout
    rank_dropout = kwargs.get("rank_dropout", None)
    if rank_dropout is not None:
        rank_dropout = float(rank_dropout)
    module_dropout = kwargs.get("module_dropout", None)
    if module_dropout is not None:
        module_dropout = float(module_dropout)

    # conv dim/alpha for Conv2d 3x3
    conv_lora_dim = kwargs.get("conv_dim", None)
    conv_alpha = kwargs.get("conv_alpha", None)
    if conv_lora_dim is not None:
        conv_lora_dim = int(conv_lora_dim)
        if conv_alpha is None:
            conv_alpha = 1.0
        else:
            conv_alpha = float(conv_alpha)

    # Tucker decomposition for Conv2d 3x3
    use_tucker = kwargs.get("use_tucker", "false")
    if use_tucker is not None:
        use_tucker = True if str(use_tucker).lower() == "true" else False

    # factor for LoKr
    factor = int(kwargs.get("factor", -1))

    # verbose
    verbose = kwargs.get("verbose", "false")
    if verbose is not None:
        verbose = True if str(verbose).lower() == "true" else False

    # regex-specific learning rates / dimensions
    network_reg_lrs = kwargs.get("network_reg_lrs", None)
    reg_lrs = _parse_kv_pairs(network_reg_lrs, is_int=False) if network_reg_lrs is not None else None

    network_reg_dims = kwargs.get("network_reg_dims", None)
    reg_dims = _parse_kv_pairs(network_reg_dims, is_int=True) if network_reg_dims is not None else None

    network = AdditionalNetwork(
        text_encoders,
        unet,
        arch_config=arch_config,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
        dropout=neuron_dropout,
        rank_dropout=rank_dropout,
        module_dropout=module_dropout,
        module_class=LoKrModule,
        module_kwargs={"factor": factor, "use_tucker": use_tucker},
        conv_lora_dim=conv_lora_dim,
        conv_alpha=conv_alpha,
        train_llm_adapter=train_llm_adapter,
        exclude_patterns=exclude_patterns,
        include_patterns=include_patterns,
        reg_dims=reg_dims,
        reg_lrs=reg_lrs,
        verbose=verbose,
    )

    # LoRA+ support
    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

    return network

def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
    """Create a LoKr network from saved weights. Called by train_network.py."""
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

    # detect dim/alpha from weights
    modules_dim = {}
    modules_alpha = {}
    train_llm_adapter = False
    use_tucker = False
    for key, value in weights_sd.items():
        if "." not in key:
            continue

        lora_name = key.split(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = value
        elif "lokr_w2_a" in key:
            # low-rank mode: dim detection depends on Tucker vs non-Tucker
            if "lokr_t2" in key.replace("lokr_w2_a", "lokr_t2") and lora_name + ".lokr_t2" in weights_sd:
                # Tucker: w2_a = (rank, out_k) → dim = w2_a.shape[0]
                dim = value.shape[0]
            else:
                # Non-Tucker: w2_a = (out_k, rank) → dim = w2_a.shape[1]
                dim = value.shape[1]
            modules_dim[lora_name] = dim
        elif "lokr_w2" in key and "lokr_w2_a" not in key and "lokr_w2_b" not in key:
            # full matrix mode: set dim large enough to trigger full-matrix path
            if lora_name not in modules_dim:
                modules_dim[lora_name] = max(value.shape[0], value.shape[1])

        if "lokr_t2" in key:
            use_tucker = True

        if "llm_adapter" in lora_name:
            train_llm_adapter = True

    # handle text_encoder as list
    text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]

    # detect architecture
    arch_config = detect_arch_config(unet, text_encoders)

    # extract factor for LoKr
    factor = int(kwargs.get("factor", -1))

    module_class = LoKrInfModule if for_inference else LoKrModule
    module_kwargs = {"factor": factor, "use_tucker": use_tucker}

    network = AdditionalNetwork(
        text_encoders,
        unet,
        arch_config=arch_config,
        multiplier=multiplier,
        modules_dim=modules_dim,
        modules_alpha=modules_alpha,
        module_class=module_class,
        module_kwargs=module_kwargs,
        train_llm_adapter=train_llm_adapter,
    )
    return network, weights_sd

def merge_weights_to_tensor(
    model_weight: torch.Tensor,
    lora_name: str,
    lora_sd: Dict[str, torch.Tensor],
    lora_weight_keys: set,
    multiplier: float,
    calc_device: torch.device,
) -> torch.Tensor:
    """Merge LoKr weights directly into a model weight tensor.

    Supports standard LoKr, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3.
    No Module/Network creation needed. Consumed keys are removed from lora_weight_keys.
    Returns model_weight unchanged if no matching LoKr keys found.
    """
    w1_key = lora_name + ".lokr_w1"
    w2_key = lora_name + ".lokr_w2"
    w2a_key = lora_name + ".lokr_w2_a"
    w2b_key = lora_name + ".lokr_w2_b"
    t2_key = lora_name + ".lokr_t2"
    alpha_key = lora_name + ".alpha"

    if w1_key not in lora_weight_keys:
        return model_weight

    w1 = lora_sd[w1_key].to(calc_device)

    # determine mode: full matrix vs Tucker vs low-rank
    has_tucker = t2_key in lora_weight_keys

    if w2a_key in lora_weight_keys:
        w2a = lora_sd[w2a_key].to(calc_device)
        w2b = lora_sd[w2b_key].to(calc_device)

        if has_tucker:
            # Tucker: w2a = (rank, out_k), dim = rank
            dim = w2a.shape[0]
        else:
            # Non-Tucker low-rank: w2a = (out_k, rank), dim = rank
            dim = w2a.shape[1]

        consumed_keys = [w1_key, w2a_key, w2b_key, alpha_key]
        if has_tucker:
            consumed_keys.append(t2_key)
    elif w2_key in lora_weight_keys:
        # full matrix mode
        w2a = None
        w2b = None
        dim = None
        consumed_keys = [w1_key, w2_key, alpha_key]
    else:
        return model_weight

    alpha = lora_sd.get(alpha_key, None)
    if alpha is not None and isinstance(alpha, torch.Tensor):
        alpha = alpha.item()

    # compute scale
    if w2a is not None:
        if alpha is None:
            alpha = dim
        scale = alpha / dim
    else:
        # full matrix mode: scale = 1.0
        scale = 1.0

    original_dtype = model_weight.dtype
    if original_dtype.itemsize == 1:  # fp8
        model_weight = model_weight.to(torch.float16)
        w1 = w1.to(torch.float16)
        if w2a is not None:
            w2a, w2b = w2a.to(torch.float16), w2b.to(torch.float16)

    # compute w2
    if w2a is not None:
        if has_tucker:
            t2 = lora_sd[t2_key].to(calc_device)
            if original_dtype.itemsize == 1:
                t2 = t2.to(torch.float16)
            w2 = rebuild_tucker(t2, w2a, w2b)
        else:
            w2 = w2a @ w2b
    else:
        w2 = lora_sd[w2_key].to(calc_device)
        if original_dtype.itemsize == 1:
            w2 = w2.to(torch.float16)

    # ΔW = kron(w1, w2) * scale
    diff_weight = make_kron(w1, w2, scale)

    # Reshape diff_weight to match model_weight shape if needed
    # (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.)
    if diff_weight.shape != model_weight.shape:
        diff_weight = diff_weight.reshape(model_weight.shape)

    model_weight = model_weight + multiplier * diff_weight

    if original_dtype.itemsize == 1:
        model_weight = model_weight.to(original_dtype)

    # remove consumed keys
    for key in consumed_keys:
        lora_weight_keys.discard(key)

    return model_weight
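The merge path above hinges on one identity: the full delta weight is the Kronecker product of the small factor `w1` and a second factor `w2` (itself possibly a low-rank pair `w2_a @ w2_b`), scaled by `alpha / dim`. A minimal pure-Python sketch of that reconstruction, as a stand-in for `make_kron` (which is defined elsewhere in this file); the shapes and values here are illustrative only:

```python
def kron(a, b):
    """Kronecker product of two 2D lists: (p, q) x (m, n) -> (p*m, q*n)."""
    p, q = len(a), len(a[0])
    m, n = len(b), len(b[0])
    out = [[0.0] * (q * n) for _ in range(p * m)]
    for i in range(p):
        for j in range(q):
            for k in range(m):
                for l in range(n):
                    out[i * m + k][j * n + l] = a[i][j] * b[k][l]
    return out


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


# w1: (2, 2) small factor; w2 is rebuilt from a rank-1 pair, giving (3, 3)
w1 = [[1.0, 0.0], [0.0, 2.0]]
w2_a = [[1.0], [2.0], [3.0]]   # (out_k, rank)
w2_b = [[1.0, 0.0, -1.0]]      # (rank, in_n)
w2 = matmul(w2_a, w2_b)        # (3, 3)

scale = 1.0  # alpha / lora_dim
delta_w = [[v * scale for v in row] for row in kron(w1, w2)]  # (6, 6) full delta
print(len(delta_w), len(delta_w[0]))  # 6 6
```

This is why LoKr needs far fewer parameters than the (out, in) delta it produces: here 4 + 3 + 3 values reconstruct a 6x6 matrix.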
846
networks/lora_anima.py
Normal file
@@ -0,0 +1,846 @@
# LoRA network module for Anima
import ast
import math
import os
import re
from typing import Dict, List, Optional, Tuple, Type, Union
import torch
from library.utils import setup_logging

import logging

setup_logging()
logger = logging.getLogger(__name__)


class LoRAModule(torch.nn.Module):
    """
    replaces forward method of the original Linear, instead of replacing the original Linear module.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
    ):
        """
        if alpha == 0 or None, alpha is rank (no scaling).
        """
        super().__init__()
        self.lora_name = lora_name

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        self.lora_dim = lora_dim

        if org_module.__class__.__name__ == "Conv2d":
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        # same as microsoft's
        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward

        del self.org_module

    def forward(self, x):
        org_forwarded = self.org_forward(x)

        # module dropout
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return org_forwarded

        lx = self.lora_down(x)

        # normal dropout
        if self.dropout is not None and self.training:
            lx = torch.nn.functional.dropout(lx, p=self.dropout)

        # rank dropout
        if self.rank_dropout is not None and self.training:
            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
            if isinstance(self.lora_down, torch.nn.Conv2d):
                # Conv2d: lora_dim is at dim 1 → [B, dim, 1, 1]
                mask = mask.unsqueeze(-1).unsqueeze(-1)
            else:
                # Linear: lora_dim is at last dim → [B, 1, ..., 1, dim]
                for _ in range(len(lx.size()) - 2):
                    mask = mask.unsqueeze(1)
            lx = lx * mask

            # scaling for rank dropout: treat as if the rank is changed
            # the scale could also be derived from the mask, but rank_dropout is used here expecting an augmentation-like effect
            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
        else:
            scale = self.scale

        lx = self.lora_up(lx)

        return org_forwarded + lx * self.multiplier * scale

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

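`LoRAModule.forward` above compensates rank dropout by rescaling the surviving rank channels with `1 / (1 - rank_dropout)`, so the expected magnitude of the low-rank branch is unchanged on average. A small sketch of that scaling with a fixed mask standing in for `torch.rand` (values are illustrative):

```python
lora_dim = 4
rank_dropout = 0.5
lx = [2.0, 2.0, 2.0, 2.0]  # activations of the rank channels for one sample
mask = [1, 1, 0, 0]        # half of the channels dropped, as expected for p=0.5

scale = 1.0 / (1.0 - rank_dropout)  # compensates the dropped mass on average
dropped = [v * m * scale for v, m in zip(lx, mask)]

print(sum(dropped))  # 8.0, same total as sum(lx)
```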
class LoRAInfModule(LoRAModule):
    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        **kwargs,
    ):
        # no dropout for inference
        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)

        self.org_module_ref = [org_module]  # keep a reference for later
        self.enabled = True
        self.network: LoRANetwork = None

    def set_network(self, network):
        self.network = network

    # freeze and merge
    def merge_to(self, sd, dtype, device):
        # extract weight from org_module
        org_sd = self.org_module.state_dict()
        weight = org_sd["weight"]
        org_dtype = weight.dtype
        org_device = weight.device
        weight = weight.to(torch.float)  # calc in float

        if dtype is None:
            dtype = org_dtype
        if device is None:
            device = org_device

        # get up/down weight
        down_weight = sd["lora_down.weight"].to(torch.float).to(device)
        up_weight = sd["lora_up.weight"].to(torch.float).to(device)

        # merge weight
        if len(weight.size()) == 2:
            # linear
            weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
        elif down_weight.size()[2:4] == (1, 1):
            # conv2d 1x1
            weight = (
                weight
                + self.multiplier
                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                * self.scale
            )
        else:
            # conv2d 3x3
            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
            # logger.info(conved.size(), weight.size(), module.stride, module.padding)
            weight = weight + self.multiplier * conved * self.scale

        # set weight to org_module
        org_sd["weight"] = weight.to(dtype)
        self.org_module.load_state_dict(org_sd)

    # return this module's weight so that the merge can be restored later
    def get_weight(self, multiplier=None):
        if multiplier is None:
            multiplier = self.multiplier

        # get up/down weight from module
        up_weight = self.lora_up.weight.to(torch.float)
        down_weight = self.lora_down.weight.to(torch.float)

        # pre-calculated weight
        if len(down_weight.size()) == 2:
            # linear
            weight = self.multiplier * (up_weight @ down_weight) * self.scale
        elif down_weight.size()[2:4] == (1, 1):
            # conv2d 1x1
            weight = (
                self.multiplier
                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                * self.scale
            )
        else:
            # conv2d 3x3
            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
            weight = self.multiplier * conved * self.scale

        return weight

    def default_forward(self, x):
        # logger.info(f"default_forward {self.lora_name} {x.size()}")
        lx = self.lora_down(x)
        lx = self.lora_up(lx)
        return self.org_forward(x) + lx * self.multiplier * self.scale

    def forward(self, x):
        if not self.enabled:
            return self.org_forward(x)
        return self.default_forward(x)

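`merge_to` and `get_weight` above rely on the Linear-case identity that applying `lora_up(lora_down(x))` gives the same result as multiplying by the merged delta `(up @ down)`: folding the two small matrices into one weight does not change the output. A pure-Python check of that equivalence with toy shapes (names and values here are illustrative, following the `nn.Linear` convention that a weight is `(out, in)` and `y = x @ W.T`):

```python
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(r) for r in zip(*m)]


down = [[1.0, 2.0, 0.0], [0.0, 1.0, 1.0]]               # (rank=2, in=3)
up = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]   # (out=4, rank=2)
scale = 0.5                                             # alpha / rank
x = [[1.0, 1.0, 1.0]]                                   # batch of 1, in=3

# path 1: sequential low-rank application (what forward() does)
h = matmul(x, transpose(down))                          # x @ down.T -> (1, rank)
y1 = [[v * scale for v in row] for row in matmul(h, transpose(up))]

# path 2: merged delta weight (what merge_to computes)
delta = matmul(up, down)                                # (out, in)
y2 = [[v * scale for v in row] for row in matmul(x, transpose(delta))]

print(y1 == y2)  # True
```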
def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae,
    text_encoders: list,
    unet,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    if network_dim is None:
        network_dim = 4
    if network_alpha is None:
        network_alpha = 1.0

    # train LLM adapter
    train_llm_adapter = kwargs.get("train_llm_adapter", "false")
    if train_llm_adapter is not None:
        train_llm_adapter = True if train_llm_adapter.lower() == "true" else False

    exclude_patterns = kwargs.get("exclude_patterns", None)
    if exclude_patterns is None:
        exclude_patterns = []
    else:
        exclude_patterns = ast.literal_eval(exclude_patterns)
        if not isinstance(exclude_patterns, list):
            exclude_patterns = [exclude_patterns]

    # add default exclude patterns
    exclude_patterns.append(r".*(_modulation|_norm|_embedder|final_layer).*")

    # regular expression for module selection: exclude and include
    include_patterns = kwargs.get("include_patterns", None)
    if include_patterns is not None:
        include_patterns = ast.literal_eval(include_patterns)
        if not isinstance(include_patterns, list):
            include_patterns = [include_patterns]

    # rank/module dropout
    rank_dropout = kwargs.get("rank_dropout", None)
    if rank_dropout is not None:
        rank_dropout = float(rank_dropout)
    module_dropout = kwargs.get("module_dropout", None)
    if module_dropout is not None:
        module_dropout = float(module_dropout)

    # verbose
    verbose = kwargs.get("verbose", "false")
    if verbose is not None:
        verbose = True if verbose.lower() == "true" else False

    # regex-specific learning rates / dimensions
    def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]:
        """
        Parse a string of key-value pairs separated by commas.
        """
        pairs = {}
        for pair in kv_pair_str.split(","):
            pair = pair.strip()
            if not pair:
                continue
            if "=" not in pair:
                logger.warning(f"Invalid format: {pair}, expected 'key=value'")
                continue
            key, value = pair.split("=", 1)
            key = key.strip()
            value = value.strip()
            try:
                pairs[key] = int(value) if is_int else float(value)
            except ValueError:
                logger.warning(f"Invalid value for {key}: {value}")
        return pairs

    network_reg_lrs = kwargs.get("network_reg_lrs", None)
    if network_reg_lrs is not None:
        reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False)
    else:
        reg_lrs = None

    network_reg_dims = kwargs.get("network_reg_dims", None)
    if network_reg_dims is not None:
        reg_dims = parse_kv_pairs(network_reg_dims, is_int=True)
    else:
        reg_dims = None

    network = LoRANetwork(
        text_encoders,
        unet,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
        dropout=neuron_dropout,
        rank_dropout=rank_dropout,
        module_dropout=module_dropout,
        train_llm_adapter=train_llm_adapter,
        exclude_patterns=exclude_patterns,
        include_patterns=include_patterns,
        reg_dims=reg_dims,
        reg_lrs=reg_lrs,
        verbose=verbose,
    )

    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

    return network

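The `network_reg_lrs` and `network_reg_dims` arguments above take a comma-separated string of `regex=value` pairs, parsed by the nested `parse_kv_pairs` helper. A standalone copy of that parser (logging dropped) to show the expected input format; the patterns below are illustrative:

```python
def parse_kv_pairs(kv_pair_str, is_int):
    """Parse 'key=value,key=value' into a dict; values become int or float."""
    pairs = {}
    for pair in kv_pair_str.split(","):
        pair = pair.strip()
        if not pair or "=" not in pair:
            continue  # the original logs a warning here
        key, value = pair.split("=", 1)
        try:
            pairs[key.strip()] = int(value.strip()) if is_int else float(value.strip())
        except ValueError:
            pass  # the original logs a warning here
    return pairs


# per-pattern learning rates: float values
reg_lrs = parse_kv_pairs(".*attn.*=1e-4, .*mlp.*=5e-5", is_int=False)
# per-pattern ranks: int values
reg_dims = parse_kv_pairs(".*attn.*=16, .*mlp.*=8", is_int=True)
print(reg_lrs)   # {'.*attn.*': 0.0001, '.*mlp.*': 5e-05}
print(reg_dims)  # {'.*attn.*': 16, '.*mlp.*': 8}
```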
def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weights_sd=None, for_inference=False, **kwargs):
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

    modules_dim = {}
    modules_alpha = {}
    train_llm_adapter = False
    for key, value in weights_sd.items():
        if "." not in key:
            continue

        lora_name = key.split(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = value
        elif "lora_down" in key:
            dim = value.size()[0]
            modules_dim[lora_name] = dim

        if "llm_adapter" in lora_name:
            train_llm_adapter = True

    module_class = LoRAInfModule if for_inference else LoRAModule

    network = LoRANetwork(
        text_encoders,
        unet,
        multiplier=multiplier,
        modules_dim=modules_dim,
        modules_alpha=modules_alpha,
        module_class=module_class,
        train_llm_adapter=train_llm_adapter,
    )
    return network, weights_sd

class LoRANetwork(torch.nn.Module):
    # Target modules: DiT blocks, embedders, final layer. embedders and final layer are excluded by default.
    ANIMA_TARGET_REPLACE_MODULE = ["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"]
    # Target modules: LLM Adapter blocks
    ANIMA_ADAPTER_TARGET_REPLACE_MODULE = ["LLMAdapterTransformerBlock"]
    # Target modules for text encoder (Qwen3)
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Qwen3Attention", "Qwen3MLP", "Qwen3SdpaAttention", "Qwen3FlashAttention2"]

    LORA_PREFIX_ANIMA = "lora_unet"  # ComfyUI compatible
    LORA_PREFIX_TEXT_ENCODER = "lora_te"  # Qwen3

    def __init__(
        self,
        text_encoders: list,
        unet,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: float = 1,
        dropout: Optional[float] = None,
        rank_dropout: Optional[float] = None,
        module_dropout: Optional[float] = None,
        module_class: Type[object] = LoRAModule,
        modules_dim: Optional[Dict[str, int]] = None,
        modules_alpha: Optional[Dict[str, int]] = None,
        train_llm_adapter: bool = False,
        exclude_patterns: Optional[List[str]] = None,
        include_patterns: Optional[List[str]] = None,
        reg_dims: Optional[Dict[str, int]] = None,
        reg_lrs: Optional[Dict[str, float]] = None,
        verbose: Optional[bool] = False,
    ) -> None:
        super().__init__()
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.alpha = alpha
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.train_llm_adapter = train_llm_adapter
        self.reg_dims = reg_dims
        self.reg_lrs = reg_lrs

        self.loraplus_lr_ratio = None
        self.loraplus_unet_lr_ratio = None
        self.loraplus_text_encoder_lr_ratio = None

        if modules_dim is not None:
            logger.info("create LoRA network from weights")
        else:
            logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
            logger.info(
                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
            )

        # compile regular expression if specified
        def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]:
            re_patterns = []
            if patterns is not None:
                for pattern in patterns:
                    try:
                        re_pattern = re.compile(pattern)
                    except re.error as e:
                        logger.error(f"Invalid pattern '{pattern}': {e}")
                        continue
                    re_patterns.append(re_pattern)
            return re_patterns

        exclude_re_patterns = str_to_re_patterns(exclude_patterns)
        include_re_patterns = str_to_re_patterns(include_patterns)

        # create module instances
        def create_modules(
            is_unet: bool,
            text_encoder_idx: Optional[int],
            root_module: torch.nn.Module,
            target_replace_modules: List[str],
            default_dim: Optional[int] = None,
        ) -> Tuple[List[LoRAModule], List[str]]:
            prefix = self.LORA_PREFIX_ANIMA if is_unet else self.LORA_PREFIX_TEXT_ENCODER

            loras = []
            skipped = []
            for name, module in root_module.named_modules():
                if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
                    if target_replace_modules is None:
                        module = root_module

                    for child_name, child_module in module.named_modules():
                        is_linear = child_module.__class__.__name__ == "Linear"
                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                        if is_linear or is_conv2d:
                            original_name = (name + "." if name else "") + child_name
                            lora_name = f"{prefix}.{original_name}".replace(".", "_")

                            # exclude/include filter (fullmatch: pattern must match the entire original_name)
                            excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns)
                            included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns)
                            if excluded and not included:
if verbose:
|
||||
logger.info(f"exclude: {original_name}")
|
||||
continue
|
||||
|
||||
dim = None
|
||||
alpha_val = None
|
||||
|
||||
if modules_dim is not None:
|
||||
if lora_name in modules_dim:
|
||||
dim = modules_dim[lora_name]
|
||||
alpha_val = modules_alpha[lora_name]
|
||||
else:
|
||||
if self.reg_dims is not None:
|
||||
for reg, d in self.reg_dims.items():
|
||||
if re.fullmatch(reg, original_name):
|
||||
dim = d
|
||||
alpha_val = self.alpha
|
||||
logger.info(f"Module {original_name} matched with regex '{reg}' -> dim: {dim}")
|
||||
break
|
||||
# fallback to default dim if not matched by reg_dims or reg_dims is not specified
|
||||
if dim is None:
|
||||
if is_linear or is_conv2d_1x1:
|
||||
dim = default_dim if default_dim is not None else self.lora_dim
|
||||
alpha_val = self.alpha
|
||||
|
||||
if dim is None or dim == 0:
|
||||
if is_linear or is_conv2d_1x1:
|
||||
skipped.append(lora_name)
|
||||
continue
|
||||
|
||||
lora = module_class(
|
||||
lora_name,
|
||||
child_module,
|
||||
self.multiplier,
|
||||
dim,
|
||||
alpha_val,
|
||||
dropout=dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
)
|
||||
lora.original_name = original_name
|
||||
loras.append(lora)
|
||||
|
||||
if target_replace_modules is None:
|
||||
break
|
||||
return loras, skipped
|
||||
|
||||
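The `create_modules` helper above derives each module's LoRA key from its dotted path and applies `fullmatch`-based exclude/include filtering, with include winning over exclude. A minimal standalone sketch of that naming and filtering logic (the module name and patterns below are hypothetical examples, not taken from the repo):

```python
import re

def lora_name_for(prefix: str, original_name: str) -> str:
    # Dots are replaced with underscores, matching the ComfyUI-compatible key format
    return f"{prefix}.{original_name}".replace(".", "_")

def is_trained(original_name: str, exclude_patterns, include_patterns) -> bool:
    # fullmatch: a pattern must cover the entire dotted module path
    excluded = any(re.fullmatch(p, original_name) for p in exclude_patterns)
    included = any(re.fullmatch(p, original_name) for p in include_patterns)
    return not (excluded and not included)  # include wins over exclude

key = lora_name_for("lora_unet", "blocks.0.attn.qkv")
trained = is_trained("blocks.0.norm1", [r".*norm.*"], [r"blocks\.0\.norm1"])
```

With an exclude pattern of `.*norm.*`, `blocks.0.norm1` is skipped unless an include pattern re-admits it, while `blocks.0.attn.qkv` is unaffected.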
        # Create LoRA for text encoders (Qwen3 - typically not trained for Anima)
        self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
        skipped_te = []
        if text_encoders is not None:
            for i, text_encoder in enumerate(text_encoders):
                if text_encoder is None:
                    continue
                logger.info(f"create LoRA for Text Encoder {i+1}:")
                te_loras, te_skipped = create_modules(False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
                logger.info(f"create LoRA for Text Encoder {i+1}: {len(te_loras)} modules.")
                self.text_encoder_loras.extend(te_loras)
                skipped_te += te_skipped

        # Create LoRA for DiT blocks
        target_modules = list(LoRANetwork.ANIMA_TARGET_REPLACE_MODULE)
        if train_llm_adapter:
            target_modules.extend(LoRANetwork.ANIMA_ADAPTER_TARGET_REPLACE_MODULE)

        self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
        self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)

        logger.info(f"create LoRA for Anima DiT: {len(self.unet_loras)} modules.")
        if verbose:
            for lora in self.unet_loras:
                logger.info(f"\t{lora.lora_name:60} {lora.lora_dim}, {lora.alpha}")

        skipped = skipped_te + skipped_un
        if verbose and len(skipped) > 0:
            logger.warning(f"dim (rank) is 0, {len(skipped)} LoRA modules are skipped:")
            for name in skipped:
                logger.info(f"\t{name}")

        # assertion: no duplicate names
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

    def set_multiplier(self, multiplier):
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def set_enabled(self, is_enabled):
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.enabled = is_enabled

    def load_weights(self, file):
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

        info = self.load_state_dict(weights_sd, False)
        return info

    def apply_to(self, text_encoders, unet, apply_text_encoder=True, apply_unet=True):
        if apply_text_encoder:
            logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            logger.info(f"enable LoRA for DiT: {len(self.unet_loras)} modules")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

    def is_mergeable(self):
        return True

    def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None):
        apply_text_encoder = apply_unet = False
        for key in weights_sd.keys():
            if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
                apply_text_encoder = True
            elif key.startswith(LoRANetwork.LORA_PREFIX_ANIMA):
                apply_unet = True

        if apply_text_encoder:
            logger.info("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            logger.info("enable LoRA for DiT")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            sd_for_lora = {}
            for key in weights_sd.keys():
                if key.startswith(lora.lora_name):
                    sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
            lora.merge_to(sd_for_lora, dtype, device)

        logger.info("weights are merged")

    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
        self.loraplus_lr_ratio = loraplus_lr_ratio
        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio

        logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
        logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")

    def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
        if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
            text_encoder_lr = [default_lr]
        elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
            text_encoder_lr = [float(text_encoder_lr)]
        elif len(text_encoder_lr) == 1:
            pass  # already a list with one element

        self.requires_grad_(True)

        all_params = []
        lr_descriptions = []

        def assemble_params(loras, lr, loraplus_ratio):
            param_groups = {"lora": {}, "plus": {}}
            reg_groups = {}
            reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []

            for lora in loras:
                matched_reg_lr = None
                for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
                    if re.fullmatch(regex_str, lora.original_name):
                        matched_reg_lr = (i, reg_lr)
                        logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
                        break

                for name, param in lora.named_parameters():
                    if matched_reg_lr is not None:
                        reg_idx, reg_lr = matched_reg_lr
                        group_key = f"reg_lr_{reg_idx}"
                        if group_key not in reg_groups:
                            reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
                        if loraplus_ratio is not None and "lora_up" in name:
                            reg_groups[group_key]["plus"][f"{lora.lora_name}.{name}"] = param
                        else:
                            reg_groups[group_key]["lora"][f"{lora.lora_name}.{name}"] = param
                        continue

                    if loraplus_ratio is not None and "lora_up" in name:
                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
                    else:
                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param

            params = []
            descriptions = []
            for group_key, group in reg_groups.items():
                reg_lr = group["lr"]
                for key in ("lora", "plus"):
                    param_data = {"params": group[key].values()}
                    if len(param_data["params"]) == 0:
                        continue
                    if key == "plus":
                        param_data["lr"] = reg_lr * loraplus_ratio if loraplus_ratio is not None else reg_lr
                    else:
                        param_data["lr"] = reg_lr
                    if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
                        logger.info("NO LR skipping!")
                        continue
                    params.append(param_data)
                    desc = f"reg_lr_{group_key.split('_')[-1]}"
                    descriptions.append(desc + (" plus" if key == "plus" else ""))

            for key in param_groups.keys():
                param_data = {"params": param_groups[key].values()}
                if len(param_data["params"]) == 0:
                    continue
                if lr is not None:
                    if key == "plus":
                        param_data["lr"] = lr * loraplus_ratio
                    else:
                        param_data["lr"] = lr
                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
                    logger.info("NO LR skipping!")
                    continue
                params.append(param_data)
                descriptions.append("plus" if key == "plus" else "")
            return params, descriptions

        if self.text_encoder_loras:
            loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
            te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)]
            if len(te1_loras) > 0:
                logger.info(f"Text Encoder 1 (Qwen3): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
                params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_ratio)
                all_params.extend(params)
                lr_descriptions.extend(["textencoder 1" + (" " + d if d else "") for d in descriptions])

        if self.unet_loras:
            params, descriptions = assemble_params(
                self.unet_loras,
                unet_lr if unet_lr is not None else default_lr,
                self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
            )
            all_params.extend(params)
            lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])

        return all_params, lr_descriptions

    def enable_gradient_checkpointing(self):
        pass  # not supported

    def prepare_grad_etc(self, text_encoder, unet):
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        self.train()

    def get_trainable_params(self):
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            from library import train_util

            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    def backup_weights(self):
        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            if not hasattr(org_module, "_lora_org_weight"):
                sd = org_module.state_dict()
                org_module._lora_org_weight = sd["weight"].detach().clone()
            org_module._lora_restored = True

    def restore_weights(self):
        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            if not org_module._lora_restored:
                sd = org_module.state_dict()
                sd["weight"] = org_module._lora_org_weight
                org_module.load_state_dict(sd)
                org_module._lora_restored = True

    def pre_calculation(self):
        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            sd = org_module.state_dict()

            org_weight = sd["weight"]
            lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
            sd["weight"] = org_weight + lora_weight
            assert sd["weight"].shape == org_weight.shape
            org_module.load_state_dict(sd)

            org_module._lora_restored = False
            lora.enabled = False

    def apply_max_norm_regularization(self, max_norm_value, device):
        downkeys = []
        upkeys = []
        alphakeys = []
        norms = []
        keys_scaled = 0

        state_dict = self.state_dict()
        for key in state_dict.keys():
            if "lora_down" in key and "weight" in key:
                downkeys.append(key)
                upkeys.append(key.replace("lora_down", "lora_up"))
                alphakeys.append(key.replace("lora_down.weight", "alpha"))

        for i in range(len(downkeys)):
            down = state_dict[downkeys[i]].to(device)
            up = state_dict[upkeys[i]].to(device)
            alpha = state_dict[alphakeys[i]].to(device)
            dim = down.shape[0]
            scale = alpha / dim

            if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
            elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
                updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
            else:
                updown = up @ down

            updown *= scale

            norm = updown.norm().clamp(min=max_norm_value / 2)
            desired = torch.clamp(norm, max=max_norm_value)
            ratio = desired.cpu() / norm.cpu()
            sqrt_ratio = ratio**0.5
            if ratio != 1:
                keys_scaled += 1
                state_dict[upkeys[i]] *= sqrt_ratio
                state_dict[downkeys[i]] *= sqrt_ratio
            scalednorm = updown.norm() * ratio
            norms.append(scalednorm.item())

        return keys_scaled, sum(norms) / len(norms), max(norms)
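`apply_max_norm_regularization` above clamps each effective update's norm, `|up @ down| * (alpha / dim)`, to `max_norm_value`, and distributes the correction as `sqrt(ratio)` onto both factors so their product is scaled by `ratio`. A small self-contained sketch of that scaling step (the tensor shapes here are illustrative, not from the repo):

```python
import torch

def scale_to_max_norm(up, down, alpha, max_norm_value):
    dim = down.shape[0]
    scale = alpha / dim
    updown = (up @ down) * scale
    norm = updown.norm().clamp(min=max_norm_value / 2)
    desired = torch.clamp(norm, max=max_norm_value)
    ratio = desired / norm
    # applying sqrt(ratio) to both factors scales the product (up @ down) by ratio
    sqrt_ratio = ratio**0.5
    return up * sqrt_ratio, down * sqrt_ratio, ratio.item()

torch.manual_seed(0)
up = torch.randn(16, 4) * 10.0  # deliberately large so clamping triggers
down = torch.randn(4, 16)
up2, down2, ratio = scale_to_max_norm(up, down, alpha=4.0, max_norm_value=1.0)
new_norm = ((up2 @ down2) * (4.0 / 4)).norm().item()
```

After scaling, the rescaled update's norm sits at the cap while the up/down factorization is preserved.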
@@ -141,10 +141,13 @@ class LoRAModule(torch.nn.Module):
         # rank dropout
         if self.rank_dropout is not None and self.training:
             mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
-            if len(lx.size()) == 3:
-                mask = mask.unsqueeze(1)  # for Text Encoder
-            elif len(lx.size()) == 4:
-                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
+            if isinstance(self.lora_down, torch.nn.Conv2d):
+                # Conv2d: lora_dim is at dim 1 → [B, dim, 1, 1]
+                mask = mask.unsqueeze(-1).unsqueeze(-1)
+            else:
+                # Linear: lora_dim is at last dim → [B, 1, ..., 1, dim]
+                for _ in range(len(lx.size()) - 2):
+                    mask = mask.unsqueeze(1)
             lx = lx * mask

             # scaling for rank dropout: treat as if the rank is changed
@@ -1445,4 +1448,4 @@ class LoRANetwork(torch.nn.Module):
             scalednorm = updown.norm() * ratio
             norms.append(scalednorm.item())

-        return keys_scaled, sum(norms) / len(norms), max(norms)
\ No newline at end of file
+        return keys_scaled, sum(norms) / len(norms), max(norms)
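The rank-dropout fix above picks the mask shape from the layer type rather than the activation rank: Conv2d activations carry `lora_dim` on dim 1, Linear activations on the last dim. A small shape check under that convention (tensor sizes below are illustrative):

```python
import torch

def rank_dropout_mask(lx: torch.Tensor, lora_dim: int, p: float, is_conv2d: bool) -> torch.Tensor:
    mask = torch.rand((lx.size(0), lora_dim), device=lx.device) > p
    if is_conv2d:
        # Conv2d: lora_dim is at dim 1 -> [B, dim, 1, 1]
        mask = mask.unsqueeze(-1).unsqueeze(-1)
    else:
        # Linear: lora_dim is at last dim -> [B, 1, ..., 1, dim]
        for _ in range(len(lx.size()) - 2):
            mask = mask.unsqueeze(1)
    return mask

conv_act = torch.randn(2, 4, 8, 8)  # [B, lora_dim, H, W]
lin_act = torch.randn(2, 77, 4)     # [B, seq, lora_dim]
conv_mask = rank_dropout_mask(conv_act, 4, 0.5, True)
lin_mask = rank_dropout_mask(lin_act, 4, 0.5, False)
```

Both masks broadcast cleanly against their activations, which the old `len(lx.size())`-based branching did not guarantee for every layout.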
@@ -227,19 +227,16 @@ class LoRAInfModule(LoRAModule):
             org_sd["weight"] = weight.to(dtype)
             self.org_module.load_state_dict(org_sd)
         else:
-            # split_dims
             total_dims = sum(self.split_dims)
+            # split_dims: merge each split's LoRA into the correct slice of the fused QKV weight
             for i in range(len(self.split_dims)):
                 # get up/down weight
                 down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device)  # (rank, in_dim)
-                up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device)  # (split dim, rank)
+                up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device)  # (split_dim, rank)

-                # pad up_weight -> (total_dims, rank)
-                padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float)
-                padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight
-
-                # merge weight
-                weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
+                # merge into the correct slice of the fused weight
+                start = sum(self.split_dims[:i])
+                end = sum(self.split_dims[:i + 1])
+                weight[start:end] += self.multiplier * (up_weight @ down_weight) * self.scale

             # set weight to org_module
             org_sd["weight"] = weight.to(dtype)
@@ -250,6 +247,17 @@ class LoRAInfModule(LoRAModule):
         if multiplier is None:
             multiplier = self.multiplier

+        # Handle split_dims case where lora_down/lora_up are ModuleList
+        if self.split_dims is not None:
+            # Each sub-module produces a partial weight; concatenate along output dim
+            weights = []
+            for lora_up, lora_down in zip(self.lora_up, self.lora_down):
+                up_w = lora_up.weight.to(torch.float)
+                down_w = lora_down.weight.to(torch.float)
+                weights.append(up_w @ down_w)
+            weight = self.multiplier * torch.cat(weights, dim=0) * self.scale
+            return weight
+
         # get up/down weight from module
         up_weight = self.lora_up.weight.to(torch.float)
         down_weight = self.lora_down.weight.to(torch.float)
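The `merge_to` change above writes each split's delta into its own slice of the fused weight, and the new `get_weight` branch concatenates the per-split deltas along the output dim. A toy check that the two formulations agree (dimensions below are made up for the demo):

```python
import torch

torch.manual_seed(0)
split_dims = [6, 2, 2]  # toy stand-in for fused Q/K/V output widths
rank, in_dim = 3, 5
ups = [torch.randn(d, rank) for d in split_dims]
downs = [torch.randn(rank, in_dim) for _ in split_dims]

# merge_to style: add each delta into the matching slice of the fused weight
weight = torch.zeros(sum(split_dims), in_dim)
for i, (up, down) in enumerate(zip(ups, downs)):
    start = sum(split_dims[:i])
    end = sum(split_dims[: i + 1])
    weight[start:end] += up @ down

# get_weight style: concatenate per-split deltas along the output dim
delta = torch.cat([up @ down for up, down in zip(ups, downs)], dim=0)
```

Slice-wise addition into a zero tensor and concatenation produce the same fused delta, which is why the padded-matrix construction could be dropped.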
@@ -409,7 +417,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei

         weights_sd = load_file(file)
     else:
-        weights_sd = torch.load(file, map_location="cpu")
+        weights_sd = torch.load(file, map_location="cpu", weights_only=False)

     # get dim/alpha mapping, and train t5xxl
     modules_dim = {}
@@ -634,20 +642,30 @@ class LoRANetwork(torch.nn.Module):
             skipped_te += skipped

         # create LoRA for U-Net
         target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
+        # Filter by block type using name-based filtering in create_modules
+        # All block types use JointTransformerBlock, so we filter by module path name
+        block_filter = None  # None means no filtering (train all)
         if self.train_blocks == "all":
-            target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
-            # TODO: limit different blocks
+            block_filter = None
         elif self.train_blocks == "transformer":
-            target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
-        elif self.train_blocks == "refiners":
-            target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
+            block_filter = "layers_"  # main transformer blocks: "lora_unet_layers_N_..."
         elif self.train_blocks == "noise_refiner":
-            target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
-        elif self.train_blocks == "cap_refiner":
-            target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
+            block_filter = "noise_refiner"
+        elif self.train_blocks == "context_refiner":
+            block_filter = "context_refiner"
+        elif self.train_blocks == "refiners":
+            block_filter = None  # handled below with two calls
+
         self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
-        self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules)
+        if self.train_blocks == "refiners":
+            # Refiners = noise_refiner + context_refiner, need two calls
+            noise_loras, skipped_noise = create_modules(True, unet, target_replace_modules, filter="noise_refiner")
+            context_loras, skipped_context = create_modules(True, unet, target_replace_modules, filter="context_refiner")
+            self.unet_loras = noise_loras + context_loras
+            skipped_un = skipped_noise + skipped_context
+        else:
+            self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules, filter=block_filter)

         # Handle embedders
         if self.embedder_dims:
@@ -689,7 +707,7 @@ class LoRANetwork(torch.nn.Module):

             weights_sd = load_file(file)
         else:
-            weights_sd = torch.load(file, map_location="cpu")
+            weights_sd = torch.load(file, map_location="cpu", weights_only=False)

         info = self.load_state_dict(weights_sd, False)
         return info
@@ -751,10 +769,10 @@ class LoRANetwork(torch.nn.Module):
         state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
         new_state_dict = {}
         for key in list(state_dict.keys()):
-            if "double" in key and "qkv" in key:
-                split_dims = [3072] * 3
-            elif "single" in key and "linear1" in key:
-                split_dims = [3072] * 3 + [12288]
+            if "qkv" in key:
+                # Lumina 2B: dim=2304, n_heads=24, n_kv_heads=8, head_dim=96
+                # Q=24*96=2304, K=8*96=768, V=8*96=768
+                split_dims = [2304, 768, 768]
             else:
                 new_state_dict[key] = state_dict[key]
                 continue
@@ -1035,4 +1053,4 @@ class LoRANetwork(torch.nn.Module):
             scalednorm = updown.norm() * ratio
             norms.append(scalednorm.item())

-        return keys_scaled, sum(norms) / len(norms), max(norms)
\ No newline at end of file
+        return keys_scaled, sum(norms) / len(norms), max(norms)
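The new qkv split in the `state_dict` hook above follows Lumina 2B's grouped-query attention layout, where K and V have fewer heads than Q. The arithmetic from the diff's own comments, written out as a quick check:

```python
# Lumina 2B attention geometry as stated in the diff comments
n_heads, n_kv_heads, head_dim = 24, 8, 96

q_dim = n_heads * head_dim     # query output width
k_dim = n_kv_heads * head_dim  # key output width (GQA: fewer KV heads)
v_dim = n_kv_heads * head_dim  # value output width
split_dims = [q_dim, k_dim, v_dim]
```

The fused qkv projection's output width is the sum of the three slices, which is why a single LoRA on it must be split per-slice when exporting.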
545
networks/network_base.py
Normal file
@@ -0,0 +1,545 @@
# Shared network base for additional network modules (like LyCORIS-family modules: LoHa, LoKr, etc).
# Provides architecture detection and a generic AdditionalNetwork class.

import os
import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Type, Union

import torch
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


@dataclass
class ArchConfig:
    unet_target_modules: List[str]
    te_target_modules: List[str]
    unet_prefix: str
    te_prefixes: List[str]
    default_excludes: List[str] = field(default_factory=list)
    adapter_target_modules: List[str] = field(default_factory=list)
    unet_conv_target_modules: List[str] = field(default_factory=list)


def detect_arch_config(unet, text_encoders) -> ArchConfig:
    """Detect architecture from model structure and return ArchConfig."""
    from library.sdxl_original_unet import SdxlUNet2DConditionModel

    # Check SDXL first
    if unet is not None and (
        issubclass(unet.__class__, SdxlUNet2DConditionModel) or issubclass(unet.__class__, InferSdxlUNet2DConditionModel)
    ):
        return ArchConfig(
            unet_target_modules=["Transformer2DModel"],
            te_target_modules=["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"],
            unet_prefix="lora_unet",
            te_prefixes=["lora_te1", "lora_te2"],
            default_excludes=[],
            unet_conv_target_modules=["ResnetBlock2D", "Downsample2D", "Upsample2D"],
        )

    # Check Anima: look for Block class in named_modules
    module_class_names = set()
    if unet is not None:
        for module in unet.modules():
            module_class_names.add(type(module).__name__)

    if "Block" in module_class_names:
        return ArchConfig(
            unet_target_modules=["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"],
            te_target_modules=["Qwen3Attention", "Qwen3MLP", "Qwen3SdpaAttention", "Qwen3FlashAttention2"],
            unet_prefix="lora_unet",
            te_prefixes=["lora_te"],
            default_excludes=[r".*(_modulation|_norm|_embedder|final_layer).*"],
            adapter_target_modules=["LLMAdapterTransformerBlock"],
        )

    raise ValueError(f"Cannot auto-detect architecture for LyCORIS. Module classes found: {sorted(module_class_names)}")


def _parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, Union[int, float]]:
    """Parse a string of key-value pairs separated by commas."""
    pairs = {}
    for pair in kv_pair_str.split(","):
        pair = pair.strip()
        if not pair:
            continue
        if "=" not in pair:
            logger.warning(f"Invalid format: {pair}, expected 'key=value'")
            continue
        key, value = pair.split("=", 1)
        key = key.strip()
        value = value.strip()
        try:
            pairs[key] = int(value) if is_int else float(value)
        except ValueError:
            logger.warning(f"Invalid value for {key}: {value}")
    return pairs


class AdditionalNetwork(torch.nn.Module):
    """Generic Additional network that supports LoHa, LoKr, and similar module types.

    Constructed with a module_class parameter to inject the specific module type.
    Based on the lora_anima.py LoRANetwork, generalized for multiple architectures.
    """

    def __init__(
        self,
        text_encoders: list,
        unet,
        arch_config: ArchConfig,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: float = 1,
        dropout: Optional[float] = None,
        rank_dropout: Optional[float] = None,
        module_dropout: Optional[float] = None,
        module_class: Type[torch.nn.Module] = None,
        module_kwargs: Optional[Dict] = None,
        modules_dim: Optional[Dict[str, int]] = None,
        modules_alpha: Optional[Dict[str, int]] = None,
        conv_lora_dim: Optional[int] = None,
        conv_alpha: Optional[float] = None,
        exclude_patterns: Optional[List[str]] = None,
        include_patterns: Optional[List[str]] = None,
        reg_dims: Optional[Dict[str, int]] = None,
        reg_lrs: Optional[Dict[str, float]] = None,
        train_llm_adapter: bool = False,
        verbose: bool = False,
    ) -> None:
        super().__init__()
        assert module_class is not None, "module_class must be specified"

        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.alpha = alpha
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.conv_lora_dim = conv_lora_dim
        self.conv_alpha = conv_alpha
        self.train_llm_adapter = train_llm_adapter
        self.reg_dims = reg_dims
        self.reg_lrs = reg_lrs
        self.arch_config = arch_config

        self.loraplus_lr_ratio = None
        self.loraplus_unet_lr_ratio = None
        self.loraplus_text_encoder_lr_ratio = None

        if module_kwargs is None:
            module_kwargs = {}

        if modules_dim is not None:
            logger.info(f"create {module_class.__name__} network from weights")
        else:
            logger.info(f"create {module_class.__name__} network. base dim (rank): {lora_dim}, alpha: {alpha}")
            logger.info(
                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
            )

        # compile regular expressions
        def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]:
            re_patterns = []
            if patterns is not None:
                for pattern in patterns:
                    try:
                        re_pattern = re.compile(pattern)
                    except re.error as e:
                        logger.error(f"Invalid pattern '{pattern}': {e}")
                        continue
                    re_patterns.append(re_pattern)
            return re_patterns

        exclude_re_patterns = str_to_re_patterns(exclude_patterns)
        include_re_patterns = str_to_re_patterns(include_patterns)

        # create module instances
        def create_modules(
            prefix: str,
            root_module: torch.nn.Module,
            target_replace_modules: List[str],
            default_dim: Optional[int] = None,
        ) -> Tuple[List[torch.nn.Module], List[str]]:
            loras = []
            skipped = []
            for name, module in root_module.named_modules():
                if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
                    if target_replace_modules is None:
                        module = root_module

                    for child_name, child_module in module.named_modules():
                        is_linear = child_module.__class__.__name__ == "Linear"
                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                        if is_linear or is_conv2d:
                            original_name = (name + "." if name else "") + child_name
                            lora_name = f"{prefix}.{original_name}".replace(".", "_")

                            # exclude/include filter
                            excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns)
                            included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns)
                            if excluded and not included:
                                if verbose:
                                    logger.info(f"exclude: {original_name}")
                                continue

                            dim = None
                            alpha_val = None

                            if modules_dim is not None:
                                if lora_name in modules_dim:
                                    dim = modules_dim[lora_name]
                                    alpha_val = modules_alpha[lora_name]
                            else:
                                if self.reg_dims is not None:
                                    for reg, d in self.reg_dims.items():
                                        if re.fullmatch(reg, original_name):
                                            dim = d
                                            alpha_val = self.alpha
                                            logger.info(f"Module {original_name} matched with regex '{reg}' -> dim: {dim}")
                                            break
                                # fallback to default dim
                                if dim is None:
                                    if is_linear or is_conv2d_1x1:
                                        dim = default_dim if default_dim is not None else self.lora_dim
                                        alpha_val = self.alpha
                                    elif is_conv2d and self.conv_lora_dim is not None:
                                        dim = self.conv_lora_dim
                                        alpha_val = self.conv_alpha

                            if dim is None or dim == 0:
                                if is_linear or is_conv2d_1x1:
                                    skipped.append(lora_name)
                                continue

                            lora = module_class(
                                lora_name,
                                child_module,
                                self.multiplier,
                                dim,
                                alpha_val,
                                dropout=dropout,
                                rank_dropout=rank_dropout,
                                module_dropout=module_dropout,
                                **module_kwargs,
                            )
                            lora.original_name = original_name
                            loras.append(lora)

                    if target_replace_modules is None:
                        break
            return loras, skipped

        # Create modules for text encoders
        self.text_encoder_loras: List[torch.nn.Module] = []
        skipped_te = []
        if text_encoders is not None:
            for i, text_encoder in enumerate(text_encoders):
                if text_encoder is None:
                    continue

                # Determine prefix for this text encoder
                if i < len(arch_config.te_prefixes):
                    te_prefix = arch_config.te_prefixes[i]
                else:
                    te_prefix = arch_config.te_prefixes[0]

                logger.info(f"create {module_class.__name__} for Text Encoder {i+1} (prefix={te_prefix}):")
                te_loras, te_skipped = create_modules(te_prefix, text_encoder, arch_config.te_target_modules)
                logger.info(f"create {module_class.__name__} for Text Encoder {i+1}: {len(te_loras)} modules.")
                self.text_encoder_loras.extend(te_loras)
                skipped_te += te_skipped

        # Create modules for UNet/DiT
        target_modules = list(arch_config.unet_target_modules)
        if modules_dim is not None or conv_lora_dim is not None:
            target_modules.extend(arch_config.unet_conv_target_modules)
        if train_llm_adapter and arch_config.adapter_target_modules:
            target_modules.extend(arch_config.adapter_target_modules)

        self.unet_loras: List[torch.nn.Module]
        self.unet_loras, skipped_un = create_modules(arch_config.unet_prefix, unet, target_modules)
        logger.info(f"create {module_class.__name__} for UNet/DiT: {len(self.unet_loras)} modules.")

        if verbose:
            for lora in self.unet_loras:
                logger.info(f"\t{lora.lora_name:60} {lora.lora_dim}, {lora.alpha}")

        skipped = skipped_te + skipped_un
        if verbose and len(skipped) > 0:
            logger.warning(f"dim (rank) is 0, {len(skipped)} modules are skipped:")
            for name in skipped:
                logger.info(f"\t{name}")

        # assertion: no duplicate names
|
||||
names = set()
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
||||
names.add(lora.lora_name)
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
self.multiplier = multiplier
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.multiplier = self.multiplier
|
||||
|
||||
def set_enabled(self, is_enabled):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.enabled = is_enabled
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
info = self.load_state_dict(weights_sd, False)
|
||||
return info
|
||||
|
||||
def apply_to(self, text_encoders, unet, apply_text_encoder=True, apply_unet=True):
|
||||
if apply_text_encoder:
|
||||
logger.info(f"enable modules for text encoder: {len(self.text_encoder_loras)} modules")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
logger.info(f"enable modules for UNet/DiT: {len(self.unet_loras)} modules")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.apply_to()
|
||||
self.add_module(lora.lora_name, lora)
|
||||
|
||||
def is_mergeable(self):
|
||||
return True
|
||||
|
||||
def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None):
|
||||
apply_text_encoder = apply_unet = False
|
||||
te_prefixes = self.arch_config.te_prefixes
|
||||
unet_prefix = self.arch_config.unet_prefix
|
||||
|
||||
for key in weights_sd.keys():
|
||||
if any(key.startswith(p) for p in te_prefixes):
|
||||
apply_text_encoder = True
|
||||
elif key.startswith(unet_prefix):
|
||||
apply_unet = True
|
||||
|
||||
if apply_text_encoder:
|
||||
logger.info("enable modules for text encoder")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
logger.info("enable modules for UNet/DiT")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
sd_for_lora = {}
|
||||
for key in weights_sd.keys():
|
||||
if key.startswith(lora.lora_name):
|
||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
||||
lora.merge_to(sd_for_lora, dtype, device)
|
||||
|
||||
logger.info("weights are merged")
|
||||
|
||||
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
|
||||
self.loraplus_lr_ratio = loraplus_lr_ratio
|
||||
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
|
||||
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
|
||||
|
||||
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
|
||||
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
|
||||
|
||||
def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
|
||||
if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
|
||||
text_encoder_lr = [default_lr]
|
||||
elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
|
||||
text_encoder_lr = [float(text_encoder_lr)]
|
||||
elif len(text_encoder_lr) == 1:
|
||||
pass # already a list with one element
|
||||
|
||||
self.requires_grad_(True)
|
||||
|
||||
all_params = []
|
||||
lr_descriptions = []
|
||||
|
||||
def assemble_params(loras, lr, loraplus_ratio):
|
||||
param_groups = {"lora": {}, "plus": {}}
|
||||
reg_groups = {}
|
||||
reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []
|
||||
|
||||
for lora in loras:
|
||||
matched_reg_lr = None
|
||||
for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
|
||||
if re.fullmatch(regex_str, lora.original_name):
|
||||
matched_reg_lr = (i, reg_lr)
|
||||
logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
|
||||
break
|
||||
|
||||
for name, param in lora.named_parameters():
|
||||
if matched_reg_lr is not None:
|
||||
reg_idx, reg_lr = matched_reg_lr
|
||||
group_key = f"reg_lr_{reg_idx}"
|
||||
if group_key not in reg_groups:
|
||||
reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
|
||||
# LoRA+ detection: check for "up" weight parameters
|
||||
if loraplus_ratio is not None and self._is_plus_param(name):
|
||||
reg_groups[group_key]["plus"][f"{lora.lora_name}.{name}"] = param
|
||||
else:
|
||||
reg_groups[group_key]["lora"][f"{lora.lora_name}.{name}"] = param
|
||||
continue
|
||||
|
||||
if loraplus_ratio is not None and self._is_plus_param(name):
|
||||
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
||||
else:
|
||||
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
||||
|
||||
params = []
|
||||
descriptions = []
|
||||
for group_key, group in reg_groups.items():
|
||||
reg_lr = group["lr"]
|
||||
for key in ("lora", "plus"):
|
||||
param_data = {"params": group[key].values()}
|
||||
if len(param_data["params"]) == 0:
|
||||
continue
|
||||
if key == "plus":
|
||||
param_data["lr"] = reg_lr * loraplus_ratio if loraplus_ratio is not None else reg_lr
|
||||
else:
|
||||
param_data["lr"] = reg_lr
|
||||
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
||||
logger.info("NO LR skipping!")
|
||||
continue
|
||||
params.append(param_data)
|
||||
desc = f"reg_lr_{group_key.split('_')[-1]}"
|
||||
descriptions.append(desc + (" plus" if key == "plus" else ""))
|
||||
|
||||
for key in param_groups.keys():
|
||||
param_data = {"params": param_groups[key].values()}
|
||||
if len(param_data["params"]) == 0:
|
||||
continue
|
||||
if lr is not None:
|
||||
if key == "plus":
|
||||
param_data["lr"] = lr * loraplus_ratio
|
||||
else:
|
||||
param_data["lr"] = lr
|
||||
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
||||
logger.info("NO LR skipping!")
|
||||
continue
|
||||
params.append(param_data)
|
||||
descriptions.append("plus" if key == "plus" else "")
|
||||
return params, descriptions
|
||||
|
||||
if self.text_encoder_loras:
|
||||
loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
|
||||
# Group TE loras by prefix
|
||||
for te_idx, te_prefix in enumerate(self.arch_config.te_prefixes):
|
||||
te_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(te_prefix)]
|
||||
if len(te_loras) > 0:
|
||||
te_lr = text_encoder_lr[te_idx] if te_idx < len(text_encoder_lr) else text_encoder_lr[0]
|
||||
logger.info(f"Text Encoder {te_idx+1} ({te_prefix}): {len(te_loras)} modules, LR {te_lr}")
|
||||
params, descriptions = assemble_params(te_loras, te_lr, loraplus_ratio)
|
||||
all_params.extend(params)
|
||||
lr_descriptions.extend([f"textencoder {te_idx+1}" + (" " + d if d else "") for d in descriptions])
|
||||
|
||||
if self.unet_loras:
|
||||
params, descriptions = assemble_params(
|
||||
self.unet_loras,
|
||||
unet_lr if unet_lr is not None else default_lr,
|
||||
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
|
||||
)
|
||||
all_params.extend(params)
|
||||
lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
|
||||
|
||||
return all_params, lr_descriptions
|
||||
|
||||
def _is_plus_param(self, name: str) -> bool:
|
||||
"""Check if a parameter name corresponds to a 'plus' (higher LR) param for LoRA+.
|
||||
|
||||
For LoRA: lora_up. For LoHa: hada_w2_a (the second pair). For LoKr: lokr_w1 (the scale factor).
|
||||
Override in subclass if needed. Default: check for common 'up' patterns.
|
||||
"""
|
||||
return "lora_up" in name or "hada_w2_a" in name or "lokr_w1" in name
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
pass # not supported
|
||||
|
||||
def prepare_grad_etc(self, text_encoder, unet):
|
||||
self.requires_grad_(True)
|
||||
|
||||
def on_epoch_start(self, text_encoder, unet):
|
||||
self.train()
|
||||
|
||||
def get_trainable_params(self):
|
||||
return self.parameters()
|
||||
|
||||
def save_weights(self, file, dtype, metadata):
|
||||
if metadata is not None and len(metadata) == 0:
|
||||
metadata = None
|
||||
|
||||
state_dict = self.state_dict()
|
||||
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
from library import train_util
|
||||
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
save_file(state_dict, file, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file)
|
||||
|
||||
def backup_weights(self):
|
||||
loras = self.text_encoder_loras + self.unet_loras
|
||||
for lora in loras:
|
||||
org_module = lora.org_module_ref[0]
|
||||
if not hasattr(org_module, "_lora_org_weight"):
|
||||
sd = org_module.state_dict()
|
||||
org_module._lora_org_weight = sd["weight"].detach().clone()
|
||||
org_module._lora_restored = True
|
||||
|
||||
def restore_weights(self):
|
||||
loras = self.text_encoder_loras + self.unet_loras
|
||||
for lora in loras:
|
||||
org_module = lora.org_module_ref[0]
|
||||
if not org_module._lora_restored:
|
||||
sd = org_module.state_dict()
|
||||
sd["weight"] = org_module._lora_org_weight
|
||||
org_module.load_state_dict(sd)
|
||||
org_module._lora_restored = True
|
||||
|
||||
def pre_calculation(self):
|
||||
loras = self.text_encoder_loras + self.unet_loras
|
||||
for lora in loras:
|
||||
org_module = lora.org_module_ref[0]
|
||||
sd = org_module.state_dict()
|
||||
|
||||
org_weight = sd["weight"]
|
||||
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
|
||||
sd["weight"] = org_weight + lora_weight
|
||||
assert sd["weight"].shape == org_weight.shape
|
||||
org_module.load_state_dict(sd)
|
||||
|
||||
org_module._lora_restored = False
|
||||
lora.enabled = False
|
||||
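As an aside, the regex-driven rank override used in `create_modules` above (`self.reg_dims` patterns matched with `re.fullmatch` against the module's dotted name, falling back to the network-wide default) can be sketched in isolation. `resolve_dim` and the sample module names below are hypothetical illustrations, not part of the repository API:

```python
import re

def resolve_dim(original_name: str, reg_dims: dict, default_dim: int) -> int:
    # The first pattern that fully matches the dotted module name wins,
    # mirroring the fullmatch loop in create_modules; otherwise fall back
    # to the network-wide default rank.
    for pattern, dim in reg_dims.items():
        if re.fullmatch(pattern, original_name):
            return dim
    return default_dim

reg_dims = {r".*attn.*to_q": 16, r".*mlp.*": 4}
print(resolve_dim("blocks.0.attn.to_q", reg_dims, 8))   # 16 (first pattern)
print(resolve_dim("blocks.0.mlp.fc1", reg_dims, 8))     # 4 (second pattern)
print(resolve_dim("blocks.0.proj_out", reg_dims, 8))    # 8 (default)
```

Because `fullmatch` is used rather than `search`, a pattern such as `attn` alone matches nothing; patterns must cover the whole dotted path.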
@@ -59,8 +59,8 @@ def save_to_file(file_name, state_dict, metadata):
 def index_sv_cumulative(S, target):
     original_sum = float(torch.sum(S))
     cumulative_sums = torch.cumsum(S, dim=0) / original_sum
-    index = int(torch.searchsorted(cumulative_sums, target)) + 1
-    index = max(1, min(index, len(S) - 1))
+    index = int(torch.searchsorted(cumulative_sums, target))
+    index = max(0, min(index, len(S) - 1))

     return index

@@ -69,8 +69,8 @@ def index_sv_fro(S, target):
     S_squared = S.pow(2)
     S_fro_sq = float(torch.sum(S_squared))
     sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
-    index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
-    index = max(1, min(index, len(S) - 1))
+    index = int(torch.searchsorted(sum_S_squared, target**2))
+    index = max(0, min(index, len(S) - 1))

     return index

@@ -78,16 +78,23 @@ def index_sv_fro(S, target):
 def index_sv_ratio(S, target):
     max_sv = S[0]
     min_sv = max_sv / target
-    index = int(torch.sum(S > min_sv).item())
-    index = max(1, min(index, len(S) - 1))
+    index = int(torch.sum(S > min_sv).item()) - 1
+    index = max(0, min(index, len(S) - 1))

     return index


 # Modified from Kohaku-blueleaf's extract/merge functions
-def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
+def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2):
     out_size, in_size, kernel_size, _ = weight.size()
-    U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
+    weight = weight.reshape(out_size, -1)
+    _in_size = in_size * kernel_size * kernel_size
+
+    if svd_lowrank_niter > 0 and out_size > 2048 and _in_size > 2048:
+        U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size), niter=svd_lowrank_niter)
+        Vh = V.T
+    else:
+        U, S, Vh = torch.linalg.svd(weight.to(device))

     param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
     lora_rank = param_dict["new_rank"]
@@ -103,10 +110,14 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
     return param_dict


-def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
+def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2):
     out_size, in_size = weight.size()

-    U, S, Vh = torch.linalg.svd(weight.to(device))
+    if svd_lowrank_niter > 0 and out_size > 2048 and in_size > 2048:
+        U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size), niter=svd_lowrank_niter)
+        Vh = V.T
+    else:
+        U, S, Vh = torch.linalg.svd(weight.to(device))

     param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
     lora_rank = param_dict["new_rank"]
@@ -198,10 +209,9 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
     return param_dict


-def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
+def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose, svd_lowrank_niter=2):
     max_old_rank = None
     new_alpha = None
-    verbose_str = "\n"
     fro_list = []

     if dynamic_method:
@@ -262,10 +272,10 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna

         if conv2d:
             full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
-            param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale)
+            param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter)
         else:
             full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
-            param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
+            param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter)

         if verbose:
             max_ratio = param_dict["max_ratio"]
@@ -274,15 +284,13 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
             if not np.isnan(fro_retained):
                 fro_list.append(float(fro_retained))

-            verbose_str += f"{block_down_name:75} | "
+            verbose_str = f"{block_down_name:75} | "
             verbose_str += (
                 f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
             )

-        if verbose and dynamic_method:
-            verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
-        else:
-            verbose_str += "\n"
+            if dynamic_method:
+                verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}"
+            tqdm.write(verbose_str)

         new_alpha = param_dict["new_alpha"]
         o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
@@ -297,7 +305,6 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
         del param_dict

     if verbose:
-        print(verbose_str)
         print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
     logger.info("resizing complete")
     return o_lora_sd, max_old_rank, new_alpha
@@ -336,7 +343,7 @@ def resize(args):

     logger.info("Resizing Lora...")
     state_dict, old_dim, new_alpha = resize_lora_model(
-        lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose
+        lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose, args.svd_lowrank_niter
     )

     # update metadata
@@ -414,6 +421,13 @@ def setup_parser() -> argparse.ArgumentParser:
         help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank",
     )
     parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction")
+    parser.add_argument(
+        "--svd_lowrank_niter",
+        type=int,
+        default=2,
+        help="Number of iterations for svd_lowrank on large matrices (>2048 dims). 0 to disable and use full SVD"
+        " / 大行列(2048次元超)に対するsvd_lowrankの反復回数。0で無効化し完全SVDを使用",
+    )

     return parser
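The `index_sv_cumulative` change in the hunk above moves from a 1-based `searchsorted` result to a 0-based index with a matching clamp. A pure-Python sketch of the corrected selection (a standalone illustration equivalent under `searchsorted`'s left-insertion semantics, not the repository function, which operates on torch tensors):

```python
def index_sv_cumulative(S, target):
    # S: singular values in descending order. Return the index of the first
    # value at which the normalized cumulative sum reaches `target`, clamped
    # to a valid position -- the 0-based form introduced by the hunk above.
    total = sum(S)
    cumulative = 0.0
    for i, s in enumerate(S):
        cumulative += s / total
        if cumulative >= target:
            return max(0, min(i, len(S) - 1))
    return len(S) - 1

S = [4.0, 3.0, 2.0, 1.0]  # cumulative fractions: 0.4, 0.7, 0.9, 1.0
print(index_sv_cumulative(S, 0.5))   # 1
print(index_sv_cumulative(S, 0.95))  # 3
```

The old `+ 1` / `max(1, ...)` form could never select only the first singular value; the 0-based form can, which matters for aggressively small dynamic ranks.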
@@ -9,7 +9,7 @@ einops==0.7.0
 bitsandbytes
 lion-pytorch==0.2.3
 schedulefree==1.4
-pytorch-optimizer==3.9.0
+pytorch-optimizer==3.10.0
 prodigy-plus-schedule-free==1.9.2
 prodigyopt==1.1.2
 tensorboard
@@ -15,6 +15,12 @@ import random
 import re

 import diffusers

+# Compatible import for diffusers old/new UNet path
+try:
+    from diffusers.models.unet_2d_condition import UNet2DConditionModel
+except ImportError:
+    from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 import numpy as np

 import torch
@@ -80,7 +86,7 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
 """


-def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
+def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
     if mem_eff_attn:
         logger.info("Enable memory efficient attention for U-Net")
342 sdxl_train_leco.py Normal file
@@ -0,0 +1,342 @@
import argparse
import importlib
import random

import torch
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from tqdm import tqdm

from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from library import custom_train_functions, sdxl_model_util, sdxl_train_util, strategy_sdxl, train_util
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training
from library.leco_train_util import (
    PromptEmbedsCache,
    apply_noise_offset,
    batch_add_time_ids,
    build_network_kwargs,
    concat_embeddings_xl,
    diffusion_xl,
    encode_prompt_sdxl,
    get_add_time_ids,
    get_initial_latents,
    get_random_resolution,
    load_prompt_settings,
    predict_noise_xl,
    save_weights,
)
from library.utils import add_logging_arguments, setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    train_util.add_sd_models_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    train_util.add_training_arguments(parser, support_dreambooth=False)
    custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False)
    sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False)
    add_logging_arguments(parser)

    parser.add_argument(
        "--save_model_as",
        type=str,
        default="safetensors",
        choices=[None, "ckpt", "pt", "safetensors"],
        help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
    )
    parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない")

    parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt toml / LECO用のprompt toml")
    parser.add_argument(
        "--max_denoising_steps",
        type=int,
        default=40,
        help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数",
    )
    parser.add_argument(
        "--leco_denoise_guidance_scale",
        type=float,
        default=3.0,
        help="guidance scale for the partial denoising pass / 部分デノイズ時のguidance scale",
    )

    parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network")
    parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train")
    parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank")
    parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha")
    parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout")
    parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments")
    parser.add_argument(
        "--network_train_text_encoder_only",
        action="store_true",
        help="unsupported for LECO; kept for compatibility / LECOでは未対応",
    )
    parser.add_argument(
        "--network_train_unet_only",
        action="store_true",
        help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習",
    )
    parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata")
    parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights")
    parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")

    # dummy arguments required by train_util.verify_training_args / deepspeed_utils (LECO does not use datasets or deepspeed)
    parser.add_argument("--cache_latents", action="store_true", default=False, help=argparse.SUPPRESS)
    parser.add_argument("--cache_latents_to_disk", action="store_true", default=False, help=argparse.SUPPRESS)
    parser.add_argument("--deepspeed", action="store_true", default=False, help=argparse.SUPPRESS)

    return parser

def main():
    parser = setup_parser()
    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)
    train_util.verify_training_args(args)
    sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)

    if args.output_dir is None:
        raise ValueError("--output_dir is required")
    if args.network_train_text_encoder_only:
        raise ValueError("LECO does not support text encoder LoRA training")

    if args.seed is None:
        args.seed = random.randint(0, 2**32 - 1)
    set_seed(args.seed)

    accelerator = train_util.prepare_accelerator(args)
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    prompt_settings = load_prompt_settings(args.prompts_file)
    logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}")

    _, text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_train_util.load_target_model(
        args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype
    )
    del vae
    text_encoders = [text_encoder1, text_encoder2]

    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
    unet.requires_grad_(False)
    unet.to(accelerator.device, dtype=weight_dtype)
    unet.train()

    tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
    text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()

    for text_encoder in text_encoders:
        text_encoder.to(accelerator.device, dtype=weight_dtype)
        text_encoder.requires_grad_(False)
        text_encoder.eval()

    prompt_cache = PromptEmbedsCache()
    unique_prompts = sorted(
        {
            prompt
            for setting in prompt_settings
            for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral)
        }
    )
    with torch.no_grad():
        for prompt in unique_prompts:
            prompt_cache[prompt] = encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt)

    for text_encoder in text_encoders:
        text_encoder.to("cpu", dtype=torch.float32)
    clean_memory_on_device(accelerator.device)

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        clip_sample=False,
    )
    prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
    if args.zero_terminal_snr:
        custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)

    network_module = importlib.import_module(args.network_module)
    net_kwargs = build_network_kwargs(args)
    if args.dim_from_weights:
        if args.network_weights is None:
            raise ValueError("--dim_from_weights requires --network_weights")
        network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoders, unet, **net_kwargs)
    else:
        network = network_module.create_network(
            1.0,
            args.network_dim,
            args.network_alpha,
            None,
            text_encoders,
            unet,
            neuron_dropout=args.network_dropout,
            **net_kwargs,
        )

    network.apply_to(text_encoders, unet, apply_text_encoder=False, apply_unet=True)
    network.set_multiplier(0.0)

    if args.network_weights is not None:
        info = network.load_weights(args.network_weights)
        logger.info(f"loaded network weights from {args.network_weights}: {info}")

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        network.enable_gradient_checkpointing()

    unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate
    trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate)
    _, _, optimizer = train_util.get_optimizer(args, trainable_params)
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler)
    accelerator.unwrap_model(network).prepare_grad_etc(text_encoders, unet)

    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)

    optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args)
    optimizer_train_fn()
    train_util.init_trackers(accelerator, args, "sdxl_leco_train")

    progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    while global_step < args.max_train_steps:
        with accelerator.accumulate(network):
            optimizer.zero_grad(set_to_none=True)

            setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()]
            noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device)

            timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item()
            height, width = get_random_resolution(setting)

            latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to(
                accelerator.device, dtype=weight_dtype
            )
            latents = apply_noise_offset(latents, args.noise_offset)
            add_time_ids = get_add_time_ids(
                height,
                width,
                dynamic_crops=setting.dynamic_crops,
                dtype=weight_dtype,
                device=accelerator.device,
            )
            batched_time_ids = batch_add_time_ids(add_time_ids, setting.batch_size)

            network_multiplier = accelerator.unwrap_model(network)
            network_multiplier.set_multiplier(setting.multiplier)
            with accelerator.autocast():
                denoised_latents = diffusion_xl(
                    unet,
                    noise_scheduler,
                    latents,
                    concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
                    add_time_ids=batched_time_ids,
                    total_timesteps=timesteps_to,
                    guidance_scale=args.leco_denoise_guidance_scale,
                )

            noise_scheduler.set_timesteps(1000, device=accelerator.device)
            current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)
            current_timestep = noise_scheduler.timesteps[current_timestep_index]

            network_multiplier.set_multiplier(0.0)
            with torch.no_grad(), accelerator.autocast():
                positive_latents = predict_noise_xl(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size),
                    add_time_ids=batched_time_ids,
                    guidance_scale=1.0,
                )
                neutral_latents = predict_noise_xl(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
                    add_time_ids=batched_time_ids,
                    guidance_scale=1.0,
                )
                unconditional_latents = predict_noise_xl(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
                    add_time_ids=batched_time_ids,
                    guidance_scale=1.0,
                )

            network_multiplier.set_multiplier(setting.multiplier)
            with accelerator.autocast():
                target_latents = predict_noise_xl(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
                    add_time_ids=batched_time_ids,
                    guidance_scale=1.0,
                )

            target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
            loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
            loss = loss.mean(dim=(1, 2, 3))
            if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
                timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long)
                loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
            loss = loss.mean() * setting.weight

            accelerator.backward(loss)

            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)

            optimizer.step()
            lr_scheduler.step()

        if accelerator.sync_gradients:
            global_step += 1
            progress_bar.update(1)
            network_multiplier = accelerator.unwrap_model(network)
            network_multiplier.set_multiplier(0.0)

            logs = {
|
||||
"loss": loss.detach().item(),
|
||||
"lr": lr_scheduler.get_last_lr()[0],
|
||||
"guidance_scale": setting.guidance_scale,
|
||||
"network_multiplier": setting.multiplier,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")
|
||||
|
||||
if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
sdxl_extra = {"ss_base_model_version": sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0}
|
||||
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False, extra_metadata=sdxl_extra)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
sdxl_extra = {"ss_base_model_version": sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0}
|
||||
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True, extra_metadata=sdxl_extra)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
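For context, `apply_snr_weight` in the loss path above implements min-SNR loss weighting. A minimal pure-Python sketch of the weighting rule, assuming the common epsilon-prediction form `min(SNR, γ)/SNR` and the v-prediction form `min(SNR, γ)/(SNR + 1)` (the exact form used by this repo's helper is not shown in this diff):

```python
def min_snr_weight(snr: float, gamma: float, v_parameterization: bool = False) -> float:
    # Down-weights low-noise (high-SNR) timesteps so they do not dominate the loss.
    capped = min(snr, gamma)
    return capped / (snr + 1.0) if v_parameterization else capped / snr

# High-SNR timesteps are scaled down to gamma/snr < 1;
# low-SNR timesteps keep full weight 1.
assert min_snr_weight(snr=20.0, gamma=5.0) == 0.25
assert min_snr_weight(snr=2.0, gamma=5.0) == 1.0
```

The per-sample loss would then be multiplied by this weight before the final mean, matching how `apply_snr_weight` is applied to the `(B,)`-shaped loss tensor above.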
116  tests/library/test_leco_train_util.py  (new file)
@@ -0,0 +1,116 @@
from pathlib import Path

import torch

from library.leco_train_util import load_prompt_settings


def test_load_prompt_settings_with_original_format(tmp_path: Path):
    prompt_file = tmp_path / "prompts.toml"
    prompt_file.write_text(
        """
[[prompts]]
target = "van gogh"
guidance_scale = 1.5
resolution = 512
""".strip(),
        encoding="utf-8",
    )

    prompts = load_prompt_settings(prompt_file)

    assert len(prompts) == 1
    assert prompts[0].target == "van gogh"
    assert prompts[0].positive == "van gogh"
    assert prompts[0].unconditional == ""
    assert prompts[0].neutral == ""
    assert prompts[0].action == "erase"
    assert prompts[0].guidance_scale == 1.5


def test_load_prompt_settings_with_slider_targets(tmp_path: Path):
    prompt_file = tmp_path / "slider.toml"
    prompt_file.write_text(
        """
guidance_scale = 2.0
resolution = 768
neutral = ""

[[targets]]
target_class = ""
positive = "high detail"
negative = "low detail"
multiplier = 1.25
weight = 0.5
""".strip(),
        encoding="utf-8",
    )

    prompts = load_prompt_settings(prompt_file)

    assert len(prompts) == 4

    first = prompts[0]
    second = prompts[1]
    third = prompts[2]
    fourth = prompts[3]

    assert first.target == ""
    assert first.positive == "low detail"
    assert first.unconditional == "high detail"
    assert first.action == "erase"
    assert first.multiplier == 1.25
    assert first.weight == 0.5
    assert first.get_resolution() == (768, 768)

    assert second.positive == "high detail"
    assert second.unconditional == "low detail"
    assert second.action == "enhance"
    assert second.multiplier == 1.25

    assert third.action == "erase"
    assert third.multiplier == -1.25

    assert fourth.action == "enhance"
    assert fourth.multiplier == -1.25


def test_predict_noise_xl_uses_vector_embedding_from_add_time_ids():
    from library import sdxl_train_util
    from library.leco_train_util import PromptEmbedsXL, predict_noise_xl

    class DummyScheduler:
        def scale_model_input(self, latent_model_input, timestep):
            return latent_model_input

    class DummyUNet:
        def __call__(self, x, timesteps, context, y):
            self.x = x
            self.timesteps = timesteps
            self.context = context
            self.y = y
            return torch.zeros_like(x)

    latents = torch.randn(1, 4, 8, 8)
    prompt_embeds = PromptEmbedsXL(
        text_embeds=torch.randn(2, 77, 2048),
        pooled_embeds=torch.randn(2, 1280),
    )
    add_time_ids = torch.tensor(
        [
            [1024, 1024, 0, 0, 1024, 1024],
            [1024, 1024, 0, 0, 1024, 1024],
        ],
        dtype=prompt_embeds.pooled_embeds.dtype,
    )

    unet = DummyUNet()
    noise_pred = predict_noise_xl(unet, DummyScheduler(), torch.tensor(10), latents, prompt_embeds, add_time_ids)

    expected_size_embeddings = sdxl_train_util.get_size_embeddings(
        add_time_ids[:, :2], add_time_ids[:, 2:4], add_time_ids[:, 4:6], latents.device
    ).to(prompt_embeds.pooled_embeds.dtype)

    assert noise_pred.shape == latents.shape
    assert unet.context is prompt_embeds.text_embeds
    assert torch.equal(unet.y, torch.cat([prompt_embeds.pooled_embeds, expected_size_embeddings], dim=1))
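The slider-target expansion the second test exercises can be sketched in plain Python. This is a hypothetical `expand_slider_target` helper mirroring only what the assertions above establish about `load_prompt_settings`: one `[[targets]]` entry yields four settings, an erase/enhance pair at `+multiplier` and the same pair at `-multiplier` (the exact prompt pairing of the negative-multiplier entries is an assumption):

```python
def expand_slider_target(positive: str, negative: str, multiplier: float, weight: float):
    # Erase pushes toward the negative concept, enhance toward the positive one.
    base = [
        {"positive": negative, "unconditional": positive, "action": "erase"},
        {"positive": positive, "unconditional": negative, "action": "enhance"},
    ]
    out = []
    for sign in (1.0, -1.0):
        for setting in base:
            out.append({**setting, "multiplier": sign * multiplier, "weight": weight})
    return out

settings = expand_slider_target("high detail", "low detail", 1.25, 0.5)
assert len(settings) == 4
assert settings[0]["action"] == "erase" and settings[0]["multiplier"] == 1.25
assert settings[3]["action"] == "enhance" and settings[3]["multiplier"] == -1.25
```

The ordering (erase, enhance, erase, enhance) matches the order asserted by `test_load_prompt_settings_with_slider_targets`.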
607  tests/manual_test_anima_cache.py  (new file)
@@ -0,0 +1,607 @@
"""
|
||||
Diagnostic script to test Anima latent & text encoder caching independently.
|
||||
|
||||
Usage:
|
||||
python manual_test_anima_cache.py \
|
||||
--image_dir /path/to/images \
|
||||
--qwen3_path /path/to/qwen3 \
|
||||
--vae_path /path/to/vae.safetensors \
|
||||
[--t5_tokenizer_path /path/to/t5] \
|
||||
[--cache_to_disk]
|
||||
|
||||
The image_dir should contain pairs of:
|
||||
image1.png + image1.txt
|
||||
image2.jpg + image2.txt
|
||||
...
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
# Helpers
|
||||
|
||||
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff"}
|
||||
|
||||
IMAGE_TRANSFORMS = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(), # [0,1]
|
||||
transforms.Normalize([0.5], [0.5]), # [-1,1]
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def find_image_caption_pairs(image_dir: str):
|
||||
"""Find (image_path, caption_text) pairs from a directory."""
|
||||
pairs = []
|
||||
for f in sorted(os.listdir(image_dir)):
|
||||
ext = os.path.splitext(f)[1].lower()
|
||||
if ext not in IMAGE_EXTENSIONS:
|
||||
continue
|
||||
img_path = os.path.join(image_dir, f)
|
||||
txt_path = os.path.splitext(img_path)[0] + ".txt"
|
||||
if os.path.exists(txt_path):
|
||||
with open(txt_path, "r", encoding="utf-8") as fh:
|
||||
caption = fh.read().strip()
|
||||
else:
|
||||
caption = ""
|
||||
pairs.append((img_path, caption))
|
||||
return pairs
|
||||
|
||||
|
||||
def print_tensor_info(name: str, t, indent=2):
|
||||
prefix = " " * indent
|
||||
if t is None:
|
||||
print(f"{prefix}{name}: None")
|
||||
return
|
||||
if isinstance(t, np.ndarray):
|
||||
print(f"{prefix}{name}: numpy {t.dtype} shape={t.shape} " f"min={t.min():.4f} max={t.max():.4f} mean={t.mean():.4f}")
|
||||
elif isinstance(t, torch.Tensor):
|
||||
print(
|
||||
f"{prefix}{name}: torch {t.dtype} shape={tuple(t.shape)} "
|
||||
f"min={t.min().item():.4f} max={t.max().item():.4f} mean={t.float().mean().item():.4f}"
|
||||
)
|
||||
else:
|
||||
print(f"{prefix}{name}: type={type(t)} value={t}")
|
||||
|
||||
|
||||
# Test 1: Latent Cache
|
||||
|
||||
|
||||
def test_latent_cache(args, pairs):
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST 1: LATENT CACHING (VAE encode -> cache -> reload)")
|
||||
print("=" * 70)
|
||||
|
||||
from library import qwen_image_autoencoder_kl
|
||||
|
||||
# Load VAE
|
||||
print("\n[1.1] Loading VAE...")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
vae_dtype = torch.float32
|
||||
vae = qwen_image_autoencoder_kl.load_vae(args.vae_path, dtype=vae_dtype, device=device)
|
||||
print(f" VAE loaded on {device}, dtype={vae_dtype}")
|
||||
|
||||
for img_path, caption in pairs:
|
||||
print(f"\n[1.2] Processing: {os.path.basename(img_path)}")
|
||||
|
||||
# Load image
|
||||
img = Image.open(img_path).convert("RGB")
|
||||
img_np = np.array(img)
|
||||
print(f" Raw image: {img_np.shape} dtype={img_np.dtype} " f"min={img_np.min()} max={img_np.max()}")
|
||||
|
||||
# Apply IMAGE_TRANSFORMS (same as sd-scripts training)
|
||||
img_tensor = IMAGE_TRANSFORMS(img_np)
|
||||
print(
|
||||
f" After IMAGE_TRANSFORMS: shape={tuple(img_tensor.shape)} " f"min={img_tensor.min():.4f} max={img_tensor.max():.4f}"
|
||||
)
|
||||
|
||||
# Check range is [-1, 1]
|
||||
if img_tensor.min() < -1.01 or img_tensor.max() > 1.01:
|
||||
print(" ** WARNING: tensor out of [-1, 1] range!")
|
||||
else:
|
||||
print(" OK: tensor in [-1, 1] range")
|
||||
|
||||
# Encode with VAE
|
||||
img_batch = img_tensor.unsqueeze(0).to(device, dtype=vae_dtype) # (1, C, H, W)
|
||||
img_5d = img_batch.unsqueeze(2) # (1, C, 1, H, W) - add temporal dim
|
||||
print(f" VAE input: shape={tuple(img_5d.shape)} dtype={img_5d.dtype}")
|
||||
|
||||
with torch.no_grad():
|
||||
latents = vae.encode_pixels_to_latents(img_5d)
|
||||
latents_cpu = latents.cpu()
|
||||
print_tensor_info("Encoded latents", latents_cpu)
|
||||
|
||||
# Check for NaN/Inf
|
||||
if torch.any(torch.isnan(latents_cpu)):
|
||||
print(" ** ERROR: NaN in latents!")
|
||||
elif torch.any(torch.isinf(latents_cpu)):
|
||||
print(" ** ERROR: Inf in latents!")
|
||||
else:
|
||||
print(" OK: no NaN/Inf")
|
||||
|
||||
# Test disk cache round-trip
|
||||
if args.cache_to_disk:
|
||||
npz_path = os.path.splitext(img_path)[0] + "_test_latent.npz"
|
||||
latents_np = latents_cpu.float().numpy()
|
||||
h, w = img_np.shape[:2]
|
||||
np.savez(
|
||||
npz_path,
|
||||
latents=latents_np,
|
||||
original_size=np.array([w, h]),
|
||||
crop_ltrb=np.array([0, 0, 0, 0]),
|
||||
)
|
||||
print(f" Saved to: {npz_path}")
|
||||
|
||||
# Reload
|
||||
loaded = np.load(npz_path)
|
||||
loaded_latents = loaded["latents"]
|
||||
print_tensor_info("Reloaded latents", loaded_latents)
|
||||
|
||||
# Compare
|
||||
diff = np.abs(latents_np - loaded_latents).max()
|
||||
print(f" Max diff (save vs load): {diff:.2e}")
|
||||
if diff > 1e-5:
|
||||
print(" ** WARNING: latent cache round-trip has significant diff!")
|
||||
else:
|
||||
print(" OK: round-trip matches")
|
||||
|
||||
os.remove(npz_path)
|
||||
print(f" Cleaned up {npz_path}")
|
||||
|
||||
vae.to("cpu")
|
||||
del vae
|
||||
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
||||
print("\n[1.3] Latent cache test DONE.")
|
||||
|
||||
|
||||
# Test 2: Text Encoder Output Cache
|
||||
|
||||
|
||||
def test_text_encoder_cache(args, pairs):
|
||||
# TODO Rewrite this
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST 2: TEXT ENCODER OUTPUT CACHING")
|
||||
print("=" * 70)
|
||||
|
||||
from library import anima_utils
|
||||
|
||||
# Load tokenizers
|
||||
print("\n[2.1] Loading tokenizers...")
|
||||
qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(args.qwen3_path)
|
||||
t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
|
||||
print(f" Qwen3 tokenizer vocab: {qwen3_tokenizer.vocab_size}")
|
||||
print(f" T5 tokenizer vocab: {t5_tokenizer.vocab_size}")
|
||||
|
||||
# Load text encoder
|
||||
print("\n[2.2] Loading Qwen3 text encoder...")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
te_dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
||||
qwen3_model, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=te_dtype, device=device)
|
||||
qwen3_model.eval()
|
||||
|
||||
# Create strategy objects
|
||||
from library.strategy_anima import AnimaTokenizeStrategy, AnimaTextEncodingStrategy
|
||||
|
||||
tokenize_strategy = AnimaTokenizeStrategy(
|
||||
qwen3_tokenizer=qwen3_tokenizer,
|
||||
t5_tokenizer=t5_tokenizer,
|
||||
qwen3_max_length=args.qwen3_max_length,
|
||||
t5_max_length=args.t5_max_length,
|
||||
)
|
||||
text_encoding_strategy = AnimaTextEncodingStrategy()
|
||||
|
||||
captions = [cap for _, cap in pairs]
|
||||
print(f"\n[2.3] Tokenizing {len(captions)} captions...")
|
||||
for i, cap in enumerate(captions):
|
||||
print(f" [{i}] \"{cap[:80]}{'...' if len(cap) > 80 else ''}\"")
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens_and_masks
|
||||
|
||||
print(f"\n Tokenization results:")
|
||||
print_tensor_info("qwen3_input_ids", qwen3_input_ids)
|
||||
print_tensor_info("qwen3_attn_mask", qwen3_attn_mask)
|
||||
print_tensor_info("t5_input_ids", t5_input_ids)
|
||||
print_tensor_info("t5_attn_mask", t5_attn_mask)
|
||||
|
||||
# Encode
|
||||
print(f"\n[2.4] Encoding with Qwen3 text encoder...")
|
||||
with torch.no_grad():
|
||||
prompt_embeds, attn_mask, t5_ids_out, t5_mask_out = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, [qwen3_model], tokens_and_masks
|
||||
)
|
||||
|
||||
print(f" Encoding results:")
|
||||
print_tensor_info("prompt_embeds", prompt_embeds)
|
||||
print_tensor_info("attn_mask", attn_mask)
|
||||
print_tensor_info("t5_input_ids", t5_ids_out)
|
||||
print_tensor_info("t5_attn_mask", t5_mask_out)
|
||||
|
||||
# Check for NaN/Inf
|
||||
if torch.any(torch.isnan(prompt_embeds)):
|
||||
print(" ** ERROR: NaN in prompt_embeds!")
|
||||
elif torch.any(torch.isinf(prompt_embeds)):
|
||||
print(" ** ERROR: Inf in prompt_embeds!")
|
||||
else:
|
||||
print(" OK: no NaN/Inf in prompt_embeds")
|
||||
|
||||
# Test cache round-trip (simulate what AnimaTextEncoderOutputsCachingStrategy does)
|
||||
print(f"\n[2.5] Testing cache round-trip (encode -> numpy -> npz -> reload -> tensor)...")
|
||||
|
||||
# Convert to numpy (same as cache_batch_outputs in strategy_anima.py)
|
||||
pe_cpu = prompt_embeds.cpu()
|
||||
if pe_cpu.dtype == torch.bfloat16:
|
||||
pe_cpu = pe_cpu.float()
|
||||
pe_np = pe_cpu.numpy()
|
||||
am_np = attn_mask.cpu().numpy()
|
||||
t5_ids_np = t5_ids_out.cpu().numpy().astype(np.int32)
|
||||
t5_mask_np = t5_mask_out.cpu().numpy().astype(np.int32)
|
||||
|
||||
print(f" Numpy conversions:")
|
||||
print_tensor_info("prompt_embeds_np", pe_np)
|
||||
print_tensor_info("attn_mask_np", am_np)
|
||||
print_tensor_info("t5_input_ids_np", t5_ids_np)
|
||||
print_tensor_info("t5_attn_mask_np", t5_mask_np)
|
||||
|
||||
if args.cache_to_disk:
|
||||
npz_path = os.path.join(args.image_dir, "_test_te_cache.npz")
|
||||
# Save per-sample (simulating cache_batch_outputs)
|
||||
for i in range(len(captions)):
|
||||
sample_npz = os.path.splitext(pairs[i][0])[0] + "_test_te.npz"
|
||||
np.savez(
|
||||
sample_npz,
|
||||
prompt_embeds=pe_np[i],
|
||||
attn_mask=am_np[i],
|
||||
t5_input_ids=t5_ids_np[i],
|
||||
t5_attn_mask=t5_mask_np[i],
|
||||
)
|
||||
print(f" Saved: {sample_npz}")
|
||||
|
||||
# Reload (simulating load_outputs_npz)
|
||||
data = np.load(sample_npz)
|
||||
print(f" Reloaded keys: {list(data.keys())}")
|
||||
print_tensor_info(" loaded prompt_embeds", data["prompt_embeds"], indent=4)
|
||||
print_tensor_info(" loaded attn_mask", data["attn_mask"], indent=4)
|
||||
print_tensor_info(" loaded t5_input_ids", data["t5_input_ids"], indent=4)
|
||||
print_tensor_info(" loaded t5_attn_mask", data["t5_attn_mask"], indent=4)
|
||||
|
||||
# Check diff
|
||||
diff_pe = np.abs(pe_np[i] - data["prompt_embeds"]).max()
|
||||
diff_t5 = np.abs(t5_ids_np[i] - data["t5_input_ids"]).max()
|
||||
print(f" Max diff prompt_embeds: {diff_pe:.2e}")
|
||||
print(f" Max diff t5_input_ids: {diff_t5:.2e}")
|
||||
if diff_pe > 1e-5 or diff_t5 > 0:
|
||||
print(" ** WARNING: cache round-trip mismatch!")
|
||||
else:
|
||||
print(" OK: round-trip matches")
|
||||
|
||||
os.remove(sample_npz)
|
||||
print(f" Cleaned up {sample_npz}")
|
||||
|
||||
# Test in-memory cache round-trip (simulating what __getitem__ does)
|
||||
print(f"\n[2.6] Testing in-memory cache simulation (tuple -> none_or_stack_elements -> batch)...")
|
||||
|
||||
# Simulate per-sample storage (like info.text_encoder_outputs = tuple)
|
||||
per_sample_cached = []
|
||||
for i in range(len(captions)):
|
||||
per_sample_cached.append((pe_np[i], am_np[i], t5_ids_np[i], t5_mask_np[i]))
|
||||
|
||||
# Simulate none_or_stack_elements with torch.FloatTensor converter
|
||||
# This is what train_util.py __getitem__ does at line 1784
|
||||
stacked = []
|
||||
for elem_idx in range(4):
|
||||
arrays = [sample[elem_idx] for sample in per_sample_cached]
|
||||
stacked.append(torch.stack([torch.FloatTensor(a) for a in arrays]))
|
||||
|
||||
print(f" Stacked batch (like batch['text_encoder_outputs_list']):")
|
||||
names = ["prompt_embeds", "attn_mask", "t5_input_ids", "t5_attn_mask"]
|
||||
for name, tensor in zip(names, stacked):
|
||||
print_tensor_info(name, tensor)
|
||||
|
||||
# Check condition: len(text_encoder_conds) == 0 or text_encoder_conds[0] is None
|
||||
text_encoder_conds = stacked
|
||||
cond_check_1 = len(text_encoder_conds) == 0
|
||||
cond_check_2 = text_encoder_conds[0] is None
|
||||
print(f"\n Condition check (should both be False when caching works):")
|
||||
print(f" len(text_encoder_conds) == 0 : {cond_check_1}")
|
||||
print(f" text_encoder_conds[0] is None: {cond_check_2}")
|
||||
if not cond_check_1 and not cond_check_2:
|
||||
print(" OK: cached text encoder outputs would be used")
|
||||
else:
|
||||
print(" ** BUG: code would try to re-encode (and crash on None input_ids_list)!")
|
||||
|
||||
# Test unpack for get_noise_pred_and_target (line 311)
|
||||
print(f"\n[2.7] Testing unpack: prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds")
|
||||
try:
|
||||
pe_batch, am_batch, t5_ids_batch, t5_mask_batch = text_encoder_conds
|
||||
print(f" Unpack OK")
|
||||
print_tensor_info("prompt_embeds", pe_batch)
|
||||
print_tensor_info("attn_mask", am_batch)
|
||||
print_tensor_info("t5_input_ids", t5_ids_batch)
|
||||
print_tensor_info("t5_attn_mask", t5_mask_batch)
|
||||
|
||||
# Check t5_input_ids are integers (they were converted to FloatTensor!)
|
||||
if t5_ids_batch.dtype != torch.long and t5_ids_batch.dtype != torch.int32:
|
||||
print(f"\n ** NOTE: t5_input_ids dtype is {t5_ids_batch.dtype}, will be cast to long at line 316")
|
||||
t5_ids_long = t5_ids_batch.to(dtype=torch.long)
|
||||
# Check if any precision was lost
|
||||
diff = (t5_ids_batch - t5_ids_long.float()).abs().max()
|
||||
print(f" Float->Long precision loss: {diff:.2e}")
|
||||
if diff > 0.5:
|
||||
print(" ** ERROR: token IDs corrupted by float conversion!")
|
||||
else:
|
||||
print(" OK: float->long conversion is lossless for these IDs")
|
||||
except Exception as e:
|
||||
print(f" ** ERROR unpacking: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# Test drop_cached_text_encoder_outputs
|
||||
print(f"\n[2.8] Testing drop_cached_text_encoder_outputs (caption dropout)...")
|
||||
dropout_strategy = AnimaTextEncodingStrategy(
|
||||
dropout_rate=0.5, # high rate to ensure some drops
|
||||
)
|
||||
dropped = dropout_strategy.drop_cached_text_encoder_outputs(*stacked)
|
||||
print(f" Returned {len(dropped)} tensors")
|
||||
for name, tensor in zip(names, dropped):
|
||||
print_tensor_info(f"dropped_{name}", tensor)
|
||||
|
||||
# Check which items were dropped
|
||||
for i in range(len(captions)):
|
||||
is_zero = (dropped[0][i].abs().sum() == 0).item()
|
||||
print(f" Sample {i}: {'DROPPED' if is_zero else 'KEPT'}")
|
||||
|
||||
qwen3_model.to("cpu")
|
||||
del qwen3_model
|
||||
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
||||
print("\n[2.8] Text encoder cache test DONE.")
|
||||
|
||||
|
||||
# Test 3: Full batch simulation
|
||||
|
||||
|
||||
def test_full_batch_simulation(args, pairs):
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST 3: FULL BATCH SIMULATION (mimics process_batch flow)")
|
||||
print("=" * 70)
|
||||
|
||||
from library import anima_utils
|
||||
from library.strategy_anima import AnimaTokenizeStrategy, AnimaTextEncodingStrategy
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
te_dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
||||
vae_dtype = torch.float32
|
||||
|
||||
# Load all models
|
||||
print("\n[3.1] Loading models...")
|
||||
qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(args.qwen3_path)
|
||||
t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
|
||||
qwen3_model, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=te_dtype, device=device)
|
||||
qwen3_model.eval()
|
||||
vae, _, _, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=vae_dtype, device=device)
|
||||
|
||||
tokenize_strategy = AnimaTokenizeStrategy(
|
||||
qwen3_tokenizer=qwen3_tokenizer,
|
||||
t5_tokenizer=t5_tokenizer,
|
||||
qwen3_max_length=args.qwen3_max_length,
|
||||
t5_max_length=args.t5_max_length,
|
||||
)
|
||||
text_encoding_strategy = AnimaTextEncodingStrategy(dropout_rate=0.0)
|
||||
|
||||
captions = [cap for _, cap in pairs]
|
||||
|
||||
# --- Simulate caching phase ---
|
||||
print("\n[3.2] Simulating text encoder caching phase...")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
te_outputs = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
[qwen3_model],
|
||||
tokens_and_masks,
|
||||
enable_dropout=False,
|
||||
)
|
||||
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = te_outputs
|
||||
|
||||
# Convert to numpy (same as cache_batch_outputs)
|
||||
pe_np = prompt_embeds.cpu().float().numpy()
|
||||
am_np = attn_mask.cpu().numpy()
|
||||
t5_ids_np = t5_input_ids.cpu().numpy().astype(np.int32)
|
||||
t5_mask_np = t5_attn_mask.cpu().numpy().astype(np.int32)
|
||||
|
||||
# Per-sample storage (like info.text_encoder_outputs)
|
||||
per_sample_te = [(pe_np[i], am_np[i], t5_ids_np[i], t5_mask_np[i]) for i in range(len(captions))]
|
||||
|
||||
print(f"\n[3.3] Simulating latent caching phase...")
|
||||
per_sample_latents = []
|
||||
for img_path, _ in pairs:
|
||||
img = Image.open(img_path).convert("RGB")
|
||||
img_np = np.array(img)
|
||||
img_tensor = IMAGE_TRANSFORMS(img_np).unsqueeze(0).unsqueeze(2) # (1,C,1,H,W)
|
||||
img_tensor = img_tensor.to(device, dtype=vae_dtype)
|
||||
with torch.no_grad():
|
||||
lat = vae.encode(img_tensor, vae_scale).cpu()
|
||||
per_sample_latents.append(lat.squeeze(0)) # (C,1,H,W)
|
||||
print(f" {os.path.basename(img_path)}: latent shape={tuple(lat.shape)}")
|
||||
|
||||
# --- Simulate batch construction (__getitem__) ---
|
||||
print(f"\n[3.4] Simulating batch construction...")
|
||||
|
||||
# Use first image's latents only (images may have different resolutions)
|
||||
latents_batch = per_sample_latents[0].unsqueeze(0) # (1,C,1,H,W)
|
||||
print(f" Using first image latent for simulation: shape={tuple(latents_batch.shape)}")
|
||||
|
||||
# Stack text encoder outputs (none_or_stack_elements)
|
||||
text_encoder_outputs_list = []
|
||||
for elem_idx in range(4):
|
||||
arrays = [s[elem_idx] for s in per_sample_te]
|
||||
text_encoder_outputs_list.append(torch.stack([torch.FloatTensor(a) for a in arrays]))
|
||||
|
||||
# input_ids_list is None when caching
|
||||
input_ids_list = None
|
||||
|
||||
batch = {
|
||||
"latents": latents_batch,
|
||||
"text_encoder_outputs_list": text_encoder_outputs_list,
|
||||
"input_ids_list": input_ids_list,
|
||||
"loss_weights": torch.ones(len(captions)),
|
||||
}
|
||||
|
||||
print(f" batch keys: {list(batch.keys())}")
|
||||
print(f" batch['latents']: shape={tuple(batch['latents'].shape)}")
|
||||
print(f" batch['text_encoder_outputs_list']: {len(batch['text_encoder_outputs_list'])} tensors")
|
||||
print(f" batch['input_ids_list']: {batch['input_ids_list']}")
|
||||
|
||||
# --- Simulate process_batch logic ---
|
||||
print(f"\n[3.5] Simulating process_batch logic...")
|
||||
|
||||
text_encoder_conds = []
|
||||
te_out = batch.get("text_encoder_outputs_list", None)
|
||||
if te_out is not None:
|
||||
text_encoder_conds = te_out
|
||||
print(f" text_encoder_conds loaded from cache: {len(text_encoder_conds)} tensors")
|
||||
else:
|
||||
print(f" text_encoder_conds: empty (no cache)")
|
||||
|
||||
# The critical condition
|
||||
train_text_encoder_TRUE = True # OLD behavior (base class default, no override)
|
||||
train_text_encoder_FALSE = False # NEW behavior (with is_train_text_encoder override)
|
||||
|
||||
cond_old = len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder_TRUE
|
||||
cond_new = len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder_FALSE
|
||||
|
||||
print(f"\n === CRITICAL CONDITION CHECK ===")
|
||||
print(f" len(text_encoder_conds) == 0 : {len(text_encoder_conds) == 0}")
|
||||
print(f" text_encoder_conds[0] is None: {text_encoder_conds[0] is None}")
|
||||
print(f" train_text_encoder (OLD=True) : {train_text_encoder_TRUE}")
|
||||
print(f" train_text_encoder (NEW=False): {train_text_encoder_FALSE}")
|
||||
print(f"")
|
||||
print(f" Condition with OLD behavior (no override): {cond_old}")
|
||||
msg = (
|
||||
"ENTERS re-encode block -> accesses batch['input_ids_list'] -> CRASH!"
|
||||
if cond_old
|
||||
else "SKIPS re-encode block -> uses cache -> OK"
|
||||
)
|
||||
|
||||
print(f" -> {msg}")
|
||||
print(f" Condition with NEW behavior (override): {cond_new}")
|
||||
print(f" -> {'ENTERS re-encode block' if cond_new else 'SKIPS re-encode block -> uses cache -> OK'}")
|
||||
|
||||
if cond_old and not cond_new:
|
||||
print(f"\n ** CONFIRMED: the is_train_text_encoder override fixes the crash **")
|
||||
|
||||
# Simulate the rest of process_batch
|
||||
print(f"\n[3.6] Simulating get_noise_pred_and_target unpack...")
|
||||
try:
|
||||
pe, am, t5_ids, t5_mask = text_encoder_conds
|
||||
pe = pe.to(device, dtype=te_dtype)
|
||||
am = am.to(device)
|
||||
t5_ids = t5_ids.to(device, dtype=torch.long)
|
||||
t5_mask = t5_mask.to(device)
|
||||
|
||||
print(f" Unpack + device transfer OK:")
|
||||
print_tensor_info("prompt_embeds", pe)
|
||||
print_tensor_info("attn_mask", am)
|
||||
print_tensor_info("t5_input_ids", t5_ids)
|
||||
print_tensor_info("t5_attn_mask", t5_mask)
|
||||
|
||||
# Verify t5_input_ids didn't get corrupted by float conversion
|
||||
t5_ids_orig = torch.tensor(t5_ids_np, dtype=torch.long, device=device)
|
||||
id_match = torch.all(t5_ids == t5_ids_orig).item()
|
||||
print(f"\n t5_input_ids integrity (float->long roundtrip): {'OK' if id_match else '** MISMATCH **'}")
|
||||
if not id_match:
|
||||
diff_count = (t5_ids != t5_ids_orig).sum().item()
|
||||
print(f" {diff_count} token IDs differ!")
|
||||
# Show example
|
||||
idx = torch.where(t5_ids != t5_ids_orig)
|
||||
if len(idx[0]) > 0:
|
||||
i, j = idx[0][0].item(), idx[1][0].item()
|
||||
print(f" Example: position [{i},{j}] original={t5_ids_orig[i,j].item()} loaded={t5_ids[i,j].item()}")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ** ERROR: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# Cleanup
|
||||
vae.to("cpu")
|
||||
qwen3_model.to("cpu")
|
||||
del vae, qwen3_model
|
||||
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
||||
print("\n[3.7] Full batch simulation DONE.")
|
||||
|
||||
|
||||
# Main
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Test Anima caching mechanisms")
|
||||
parser.add_argument("--image_dir", type=str, required=True, help="Directory with image+txt pairs")
|
||||
parser.add_argument("--qwen3_path", type=str, required=True, help="Path to Qwen3 model (directory or safetensors)")
|
||||
parser.add_argument("--vae_path", type=str, required=True, help="Path to WanVAE safetensors")
|
||||
parser.add_argument("--t5_tokenizer_path", type=str, default=None, help="Path to T5 tokenizer (optional, uses bundled config)")
|
||||
parser.add_argument("--qwen3_max_length", type=int, default=512)
|
||||
parser.add_argument("--t5_max_length", type=int, default=512)
|
||||
    parser.add_argument("--cache_to_disk", action="store_true", help="Also test disk cache round-trip")
    parser.add_argument("--skip_latent", action="store_true", help="Skip latent cache test")
    parser.add_argument("--skip_text", action="store_true", help="Skip text encoder cache test")
    parser.add_argument("--skip_full", action="store_true", help="Skip full batch simulation")
    args = parser.parse_args()

    # Find pairs
    pairs = find_image_caption_pairs(args.image_dir)
    if len(pairs) == 0:
        print(f"ERROR: No image+txt pairs found in {args.image_dir}")
        print("Expected: image.png + image.txt, image.jpg + image.txt, etc.")
        sys.exit(1)

    print(f"Found {len(pairs)} image-caption pairs:")
    for img_path, cap in pairs:
        print(f"  {os.path.basename(img_path)}: \"{cap[:60]}{'...' if len(cap) > 60 else ''}\"")

    results = {}

    if not args.skip_latent:
        try:
            test_latent_cache(args, pairs)
            results["latent_cache"] = "PASS"
        except Exception as e:
            print(f"\n** LATENT CACHE TEST FAILED: {e}")
            traceback.print_exc()
            results["latent_cache"] = f"FAIL: {e}"

    if not args.skip_text:
        try:
            test_text_encoder_cache(args, pairs)
            results["text_encoder_cache"] = "PASS"
        except Exception as e:
            print(f"\n** TEXT ENCODER CACHE TEST FAILED: {e}")
            traceback.print_exc()
            results["text_encoder_cache"] = f"FAIL: {e}"

    if not args.skip_full:
        try:
            test_full_batch_simulation(args, pairs)
            results["full_batch_sim"] = "PASS"
        except Exception as e:
            print(f"\n** FULL BATCH SIMULATION FAILED: {e}")
            traceback.print_exc()
            results["full_batch_sim"] = f"FAIL: {e}"

    # Summary
    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    for test, result in results.items():
        status = "OK" if result == "PASS" else "FAIL"
        print(f"  [{status}] {test}: {result}")
    print()


if __name__ == "__main__":
    main()
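The helper `find_image_caption_pairs` used above is defined earlier in this test file and is not shown in this diff. A minimal sketch of the pairing behavior it relies on (assumed, not the repository's exact implementation): each image file is matched with a same-stem `.txt` caption, and images without a caption are skipped.

```python
import os
import tempfile


def find_image_caption_pairs(image_dir):
    # Pair each image with a same-stem .txt caption file.
    # (Sketch only; the real helper appears earlier in the test file and may differ.)
    exts = {".png", ".jpg", ".jpeg", ".webp"}
    pairs = []
    for name in sorted(os.listdir(image_dir)):
        stem, ext = os.path.splitext(name)
        caption_path = os.path.join(image_dir, stem + ".txt")
        if ext.lower() in exts and os.path.exists(caption_path):
            with open(caption_path, "r", encoding="utf-8") as f:
                pairs.append((os.path.join(image_dir, name), f.read().strip()))
    return pairs


# demo: one captioned image, one without a caption
d = tempfile.mkdtemp()
open(os.path.join(d, "a.png"), "w").close()
with open(os.path.join(d, "a.txt"), "w", encoding="utf-8") as f:
    f.write("a cat")
open(os.path.join(d, "b.png"), "w").close()  # no caption: skipped
pairs = find_image_caption_pairs(d)
assert pairs == [(os.path.join(d, "a.png"), "a cat")]
```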
tests/manual_test_anima_real_training.py (Normal file, 242 lines)
@@ -0,0 +1,242 @@
"""
Test script that actually runs anima_train.py and anima_train_network.py
for a few steps to verify --cache_text_encoder_outputs works.

Usage:
    python test_anima_real_training.py \
        --image_dir /path/to/images_with_txt \
        --dit_path /path/to/dit.safetensors \
        --qwen3_path /path/to/qwen3 \
        --vae_path /path/to/vae.safetensors \
        [--t5_tokenizer_path /path/to/t5] \
        [--resolution 512]

This will run 4 tests:
1. anima_train.py (full finetune, no cache)
2. anima_train.py (full finetune, --cache_text_encoder_outputs)
3. anima_train_network.py (LoRA, no cache)
4. anima_train_network.py (LoRA, --cache_text_encoder_outputs)

Each test runs only 2 training steps then stops.
"""

import argparse
import os
import subprocess
import sys
import tempfile
import shutil


def create_dataset_toml(image_dir: str, resolution: int, toml_path: str):
    """Create a minimal dataset toml config."""
    content = f"""[general]
resolution = {resolution}
enable_bucket = true
bucket_reso_steps = 8
min_bucket_reso = 256
max_bucket_reso = 1024

[[datasets]]
batch_size = 1

[[datasets.subsets]]
image_dir = "{image_dir}"
num_repeats = 1
caption_extension = ".txt"
"""
    with open(toml_path, "w", encoding="utf-8") as f:
        f.write(content)
    return toml_path


def run_test(test_name: str, cmd: list, timeout: int = 300) -> dict:
    """Run a training command and capture result."""
    print(f"\n{'=' * 70}")
    print(f"TEST: {test_name}")
    print(f"{'=' * 70}")
    print(f"Command: {' '.join(cmd)}\n")

    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=timeout,
            cwd=os.path.dirname(os.path.abspath(__file__)),
        )

        stdout = result.stdout
        stderr = result.stderr
        returncode = result.returncode

        # Print last N lines of output
        all_output = stdout + "\n" + stderr
        lines = all_output.strip().split("\n")
        print("--- Last 30 lines of output ---")
        for line in lines[-30:]:
            print(f"  {line}")
        print("--- End output ---\n")

        if returncode == 0:
            print("RESULT: PASS (exit code 0)")
            return {"status": "PASS", "detail": "completed successfully"}
        else:
            # Check if it's a known error
            if "TypeError: 'NoneType' object is not iterable" in all_output:
                print("RESULT: FAIL - input_ids_list is None (the cache_text_encoder_outputs bug)")
                return {"status": "FAIL", "detail": "input_ids_list is None - cache TE outputs bug"}
            elif "steps: 0%" in all_output and "Error" in all_output:
                # Find the actual error
                error_lines = [l for l in lines if "Error" in l or "Traceback" in l or "raise" in l.lower()]
                detail = error_lines[-1] if error_lines else f"exit code {returncode}"
                print(f"RESULT: FAIL - {detail}")
                return {"status": "FAIL", "detail": detail}
            else:
                print(f"RESULT: FAIL (exit code {returncode})")
                return {"status": "FAIL", "detail": f"exit code {returncode}"}

    except subprocess.TimeoutExpired:
        print(f"RESULT: TIMEOUT (>{timeout}s)")
        return {"status": "TIMEOUT", "detail": f"exceeded {timeout}s"}
    except Exception as e:
        print(f"RESULT: ERROR - {e}")
        return {"status": "ERROR", "detail": str(e)}


def main():
    parser = argparse.ArgumentParser(description="Test Anima real training with cache flags")
    parser.add_argument("--image_dir", type=str, required=True, help="Directory with image+txt pairs")
    parser.add_argument("--dit_path", type=str, required=True, help="Path to Anima DiT safetensors")
    parser.add_argument("--qwen3_path", type=str, required=True, help="Path to Qwen3 model")
    parser.add_argument("--vae_path", type=str, required=True, help="Path to WanVAE safetensors")
    parser.add_argument("--t5_tokenizer_path", type=str, default=None)
    parser.add_argument("--resolution", type=int, default=512)
    parser.add_argument("--timeout", type=int, default=300, help="Timeout per test in seconds (default: 300)")
    parser.add_argument("--only", type=str, default=None, choices=["finetune", "lora"], help="Only run finetune or lora tests")
    args = parser.parse_args()

    # Validate paths
    for name, path in [
        ("image_dir", args.image_dir),
        ("dit_path", args.dit_path),
        ("qwen3_path", args.qwen3_path),
        ("vae_path", args.vae_path),
    ]:
        if not os.path.exists(path):
            print(f"ERROR: {name} does not exist: {path}")
            sys.exit(1)

    # Create temp dir for outputs
    tmp_dir = tempfile.mkdtemp(prefix="anima_test_")
    print(f"Temp directory: {tmp_dir}")

    # Create dataset toml
    toml_path = os.path.join(tmp_dir, "dataset.toml")
    create_dataset_toml(args.image_dir, args.resolution, toml_path)
    print(f"Dataset config: {toml_path}")

    output_dir = os.path.join(tmp_dir, "output")
    os.makedirs(output_dir, exist_ok=True)

    python = sys.executable

    # Common args for both scripts
    common_anima_args = [
        "--dit_path", args.dit_path,
        "--qwen3_path", args.qwen3_path,
        "--vae_path", args.vae_path,
        "--pretrained_model_name_or_path", args.dit_path,  # required by base parser
        "--output_dir", output_dir,
        "--output_name", "test",
        "--dataset_config", toml_path,
        "--max_train_steps", "2",
        "--learning_rate", "1e-5",
        "--mixed_precision", "bf16",
        "--save_every_n_steps", "999",  # don't save
        "--max_data_loader_n_workers", "0",  # single process for clarity
        "--logging_dir", os.path.join(tmp_dir, "logs"),
        "--cache_latents",
    ]
    if args.t5_tokenizer_path:
        common_anima_args += ["--t5_tokenizer_path", args.t5_tokenizer_path]

    results = {}

    # TEST 1: anima_train.py - NO cache_text_encoder_outputs
    if args.only is None or args.only == "finetune":
        cmd = [python, "anima_train.py"] + common_anima_args + [
            "--optimizer_type", "AdamW8bit",
        ]
        results["finetune_no_cache"] = run_test(
            "anima_train.py (full finetune, NO text encoder cache)",
            cmd, args.timeout,
        )

        # TEST 2: anima_train.py - WITH cache_text_encoder_outputs
        cmd = [python, "anima_train.py"] + common_anima_args + [
            "--optimizer_type", "AdamW8bit",
            "--cache_text_encoder_outputs",
        ]
        results["finetune_with_cache"] = run_test(
            "anima_train.py (full finetune, WITH --cache_text_encoder_outputs)",
            cmd, args.timeout,
        )

    # TEST 3: anima_train_network.py - NO cache_text_encoder_outputs
    if args.only is None or args.only == "lora":
        lora_args = common_anima_args + [
            "--optimizer_type", "AdamW8bit",
            "--network_module", "networks.lora_anima",
            "--network_dim", "4",
            "--network_alpha", "1",
        ]

        cmd = [python, "anima_train_network.py"] + lora_args
        results["lora_no_cache"] = run_test(
            "anima_train_network.py (LoRA, NO text encoder cache)",
            cmd, args.timeout,
        )

        # TEST 4: anima_train_network.py - WITH cache_text_encoder_outputs
        cmd = [python, "anima_train_network.py"] + lora_args + [
            "--cache_text_encoder_outputs",
        ]
        results["lora_with_cache"] = run_test(
            "anima_train_network.py (LoRA, WITH --cache_text_encoder_outputs)",
            cmd, args.timeout,
        )

    # SUMMARY
    print(f"\n{'=' * 70}")
    print("SUMMARY")
    print(f"{'=' * 70}")
    all_pass = True
    for test_name, result in results.items():
        status = result["status"]
        icon = "OK" if status == "PASS" else "FAIL"
        if status != "PASS":
            all_pass = False
        print(f"  [{icon:4s}] {test_name}: {result['detail']}")

    print(f"\nTemp directory (can delete): {tmp_dir}")

    # Cleanup
    try:
        shutil.rmtree(tmp_dir)
        print("Temp directory cleaned up.")
    except Exception:
        print(f"Note: could not clean up {tmp_dir}")

    if all_pass:
        print("\nAll tests PASSED!")
    else:
        print("\nSome tests FAILED!")
        sys.exit(1)


if __name__ == "__main__":
    main()
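The interesting part of `run_test` above is the output classification: exit code first, then known failure signatures in the combined stdout+stderr. Factored out, that rule can be sketched and exercised without spawning a subprocess (the function name here is illustrative, not part of the repository):

```python
def classify_result(returncode, all_output):
    # Mirrors the branching in run_test: a zero exit code passes;
    # otherwise look for the known cache_text_encoder_outputs failure
    # signature before falling back to the raw exit code.
    if returncode == 0:
        return "PASS"
    if "TypeError: 'NoneType' object is not iterable" in all_output:
        return "FAIL: input_ids_list is None - cache TE outputs bug"
    return f"FAIL: exit code {returncode}"


assert classify_result(0, "") == "PASS"
assert classify_result(1, "TypeError: 'NoneType' object is not iterable").endswith("cache TE outputs bug")
assert classify_result(2, "boom") == "FAIL: exit code 2"
```

Keeping the classification pure like this would let the manual harness grow new failure signatures with ordinary unit tests.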
tests/test_sdxl_train_leco.py (Normal file, 16 lines)
@@ -0,0 +1,16 @@
import sdxl_train_leco
from library import deepspeed_utils, sdxl_train_util, train_util


def test_syntax():
    assert sdxl_train_leco is not None


def test_setup_parser_supports_shared_training_validation():
    args = sdxl_train_leco.setup_parser().parse_args(["--prompts_file", "slider.yaml"])

    train_util.verify_training_args(args)
    sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)

    assert args.min_snr_gamma is None
    assert deepspeed_utils.prepare_deepspeed_plugin(args) is None
tests/test_train_leco.py (Normal file, 15 lines)
@@ -0,0 +1,15 @@
import train_leco
from library import deepspeed_utils, train_util


def test_syntax():
    assert train_leco is not None


def test_setup_parser_supports_shared_training_validation():
    args = train_leco.setup_parser().parse_args(["--prompts_file", "slider.yaml"])

    train_util.verify_training_args(args)

    assert args.min_snr_gamma is None
    assert deepspeed_utils.prepare_deepspeed_plugin(args) is None
train_leco.py (Normal file, 319 lines)
@@ -0,0 +1,319 @@
import argparse
import importlib
import random

import torch
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from tqdm import tqdm

from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from library import custom_train_functions, strategy_sd, train_util
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training
from library.leco_train_util import (
    PromptEmbedsCache,
    apply_noise_offset,
    build_network_kwargs,
    concat_embeddings,
    diffusion,
    encode_prompt_sd,
    get_initial_latents,
    get_random_resolution,
    get_save_extension,
    load_prompt_settings,
    predict_noise,
    save_weights,
)
from library.utils import add_logging_arguments, setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    train_util.add_sd_models_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    train_util.add_training_arguments(parser, support_dreambooth=False)
    custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False)
    add_logging_arguments(parser)

    parser.add_argument(
        "--save_model_as",
        type=str,
        default="safetensors",
        choices=[None, "ckpt", "pt", "safetensors"],
        help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
    )
    parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない")

    parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt toml / LECO用のprompt toml")
    parser.add_argument(
        "--max_denoising_steps",
        type=int,
        default=40,
        help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数",
    )
    parser.add_argument(
        "--leco_denoise_guidance_scale",
        type=float,
        default=3.0,
        help="guidance scale for the partial denoising pass / 部分デノイズ時のguidance scale",
    )

    parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network")
    parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train")
    parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank")
    parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha")
    parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout")
    parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments")
    parser.add_argument(
        "--network_train_text_encoder_only",
        action="store_true",
        help="unsupported for LECO; kept for compatibility / LECOでは未対応",
    )
    parser.add_argument(
        "--network_train_unet_only",
        action="store_true",
        help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習",
    )
    parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata")
    parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights")
    parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")

    # dummy arguments required by train_util.verify_training_args / deepspeed_utils (LECO does not use datasets or deepspeed)
    parser.add_argument("--cache_latents", action="store_true", default=False, help=argparse.SUPPRESS)
    parser.add_argument("--cache_latents_to_disk", action="store_true", default=False, help=argparse.SUPPRESS)
    parser.add_argument("--deepspeed", action="store_true", default=False, help=argparse.SUPPRESS)

    return parser


def main():
    parser = setup_parser()
    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)
    train_util.verify_training_args(args)

    if args.output_dir is None:
        raise ValueError("--output_dir is required")
    if args.network_train_text_encoder_only:
        raise ValueError("LECO does not support text encoder LoRA training")

    if args.seed is None:
        args.seed = random.randint(0, 2**32 - 1)
    set_seed(args.seed)

    accelerator = train_util.prepare_accelerator(args)
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    prompt_settings = load_prompt_settings(args.prompts_file)
    logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}")

    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
    del vae

    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
    unet.requires_grad_(False)
    unet.to(accelerator.device, dtype=weight_dtype)
    unet.train()

    tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
    text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)

    text_encoder.to(accelerator.device, dtype=weight_dtype)
    text_encoder.requires_grad_(False)
    text_encoder.eval()

    prompt_cache = PromptEmbedsCache()
    unique_prompts = sorted(
        {
            prompt
            for setting in prompt_settings
            for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral)
        }
    )
    with torch.no_grad():
        for prompt in unique_prompts:
            prompt_cache[prompt] = encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt)

    text_encoder.to("cpu")
    clean_memory_on_device(accelerator.device)

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        clip_sample=False,
    )
    prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
    if args.zero_terminal_snr:
        custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)

    network_module = importlib.import_module(args.network_module)
    net_kwargs = build_network_kwargs(args)
    if args.dim_from_weights:
        if args.network_weights is None:
            raise ValueError("--dim_from_weights requires --network_weights")
        network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoder, unet, **net_kwargs)
    else:
        network = network_module.create_network(
            1.0,
            args.network_dim,
            args.network_alpha,
            None,
            text_encoder,
            unet,
            neuron_dropout=args.network_dropout,
            **net_kwargs,
        )

    network.apply_to(text_encoder, unet, apply_text_encoder=False, apply_unet=True)
    network.set_multiplier(0.0)

    if args.network_weights is not None:
        info = network.load_weights(args.network_weights)
        logger.info(f"loaded network weights from {args.network_weights}: {info}")

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        network.enable_gradient_checkpointing()

    unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate
    trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate)
    _, _, optimizer = train_util.get_optimizer(args, trainable_params)
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler)
    accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)

    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)

    optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args)
    optimizer_train_fn()
    train_util.init_trackers(accelerator, args, "leco_train")

    progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    while global_step < args.max_train_steps:
        with accelerator.accumulate(network):
            optimizer.zero_grad(set_to_none=True)

            setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()]
            noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device)

            timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item()
            height, width = get_random_resolution(setting)

            latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to(
                accelerator.device, dtype=weight_dtype
            )
            latents = apply_noise_offset(latents, args.noise_offset)

            network_multiplier = accelerator.unwrap_model(network)
            network_multiplier.set_multiplier(setting.multiplier)
            with accelerator.autocast():
                denoised_latents = diffusion(
                    unet,
                    noise_scheduler,
                    latents,
                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
                    total_timesteps=timesteps_to,
                    guidance_scale=args.leco_denoise_guidance_scale,
                )

            noise_scheduler.set_timesteps(1000, device=accelerator.device)
            current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)
            current_timestep = noise_scheduler.timesteps[current_timestep_index]

            network_multiplier.set_multiplier(0.0)
            with torch.no_grad(), accelerator.autocast():
                positive_latents = predict_noise(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size),
                    guidance_scale=1.0,
                )
                neutral_latents = predict_noise(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
                    guidance_scale=1.0,
                )
                unconditional_latents = predict_noise(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
                    guidance_scale=1.0,
                )

            network_multiplier.set_multiplier(setting.multiplier)
            with accelerator.autocast():
                target_latents = predict_noise(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
                    guidance_scale=1.0,
                )

            target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
            loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
            loss = loss.mean(dim=(1, 2, 3))
            if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
                timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long)
                loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
            loss = loss.mean() * setting.weight

            accelerator.backward(loss)

            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)

            optimizer.step()
            lr_scheduler.step()

        if accelerator.sync_gradients:
            global_step += 1
            progress_bar.update(1)
            network_multiplier = accelerator.unwrap_model(network)
            network_multiplier.set_multiplier(0.0)

            logs = {
                "loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
                "guidance_scale": setting.guidance_scale,
                "network_multiplier": setting.multiplier,
            }
            accelerator.log(logs, step=global_step)
            progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")

            if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps:
                accelerator.wait_for_everyone()
                if accelerator.is_main_process:
                    save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False)

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True)

    accelerator.end_training()


if __name__ == "__main__":
    main()
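The loss above compares the trainable network's prediction against `setting.build_target(...)`, whose internals live in `library/leco_train_util.py` and are not shown in this diff. For an "erase" slider, the usual LECO form of the target is the neutral prediction pushed away from the concept direction; the sketch below assumes that standard formulation and is not the repository's exact implementation:

```python
import torch


def erase_target(positive, neutral, unconditional, guidance_scale=1.0):
    # Standard LECO erasure target (assumed, not taken from this repo):
    # move the target prediction opposite to the concept direction
    #   target = neutral - scale * (positive - unconditional)
    return neutral - guidance_scale * (positive - unconditional)


# toy latents: concept direction is all-ones, neutral/unconditional are zero
p = torch.ones(1, 4, 8, 8)
n = torch.zeros(1, 4, 8, 8)
u = torch.zeros(1, 4, 8, 8)
t = erase_target(p, n, u, guidance_scale=2.0)
assert torch.allclose(t, torch.full((1, 4, 8, 8), -2.0))
```

With `guidance_scale=0` the target collapses to the neutral prediction, i.e. training does nothing to the concept; larger scales erase it more aggressively.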
@@ -90,40 +90,23 @@ class NetworkTrainer:
            if lr_descriptions is not None:
                lr_desc = lr_descriptions[i]
            else:
                idx = i - (0 if args.network_train_unet_only else -1)
                idx = i - (0 if args.network_train_unet_only else 1)
                if idx == -1:
                    lr_desc = "textencoder"
                else:
                    if len(lrs) > 2:
                        lr_desc = f"group{idx}"
                        lr_desc = f"group{i}"
                    else:
                        lr_desc = "unet"

            logs[f"lr/{lr_desc}"] = lr

            if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
                # tracking d*lr value
                logs[f"lr/d*lr/{lr_desc}"] = (
                    lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                )
            if (
                args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
            ):  # tracking d*lr value of unet.
                logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
        else:
            idx = 0
            if not args.network_train_unet_only:
                logs["lr/textencoder"] = float(lrs[0])
                idx = 1

            for i in range(idx, len(lrs)):
                logs[f"lr/group{i}"] = float(lrs[i])
                if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
                    logs[f"lr/d*lr/group{i}"] = (
                        lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                    )
                if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
                    logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
            if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower().startswith("Prodigy".lower()):
                opt = lr_scheduler.optimizers[-1] if hasattr(lr_scheduler, "optimizers") else optimizer
                if opt is not None:
                    logs[f"lr/d*lr/{lr_desc}"] = opt.param_groups[i]["d"] * opt.param_groups[i]["lr"]
                    if "effective_lr" in opt.param_groups[i]:
                        logs[f"lr/d*eff_lr/{lr_desc}"] = opt.param_groups[i]["d"] * opt.param_groups[i]["effective_lr"]

        return logs
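The index arithmetic in the hunk above is easy to misread: when the text encoder is trained, its parameter group comes first, so the hunk changes the offset from `-1` to `1` so that group 0 maps to "textencoder" and the remaining groups to U-Net labels. A small standalone sketch of the intended mapping (the function name and `num_groups` parameter are illustrative, not from the repository):

```python
def lr_description(i, train_unet_only, num_groups):
    # param group 0 is the text encoder unless only the U-Net is trained;
    # the remaining groups belong to the U-Net
    idx = i - (0 if train_unet_only else 1)
    if idx == -1:
        return "textencoder"
    return f"group{i}" if num_groups > 2 else "unet"


assert lr_description(0, train_unet_only=False, num_groups=2) == "textencoder"
assert lr_description(1, train_unet_only=False, num_groups=2) == "unet"
assert lr_description(0, train_unet_only=True, num_groups=1) == "unet"
assert lr_description(2, train_unet_only=False, num_groups=3) == "group2"
```

With the old `else -1` offset, `idx` could never reach `-1` for group 0, so the text encoder learning rate was logged under a U-Net label; the `else 1` form restores the mapping shown above.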
@@ -470,7 +453,7 @@ class NetworkTrainer:
                 loss = loss * weighting
             if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                 loss = apply_masked_loss(loss, batch)
-            loss = loss.mean([1, 2, 3])
+            loss = loss.mean(dim=list(range(1, loss.ndim)))  # mean over all dims except batch
 
             loss_weights = batch["loss_weights"]  # 各sampleごとのweight
             loss = loss * loss_weights
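The replaced `loss.mean([1, 2, 3])` assumed a 4D `(B, C, H, W)` loss tensor; `mean(dim=list(range(1, loss.ndim)))` produces a per-sample mean for any rank, which matters once a model family's latents carry extra dimensions (for example a frame axis). A minimal sketch:

```python
import torch


def per_sample_mean(loss):
    # average over every dimension except the batch dimension,
    # regardless of the tensor's rank
    return loss.mean(dim=list(range(1, loss.ndim)))


loss_4d = torch.ones(2, 4, 8, 8)        # (B, C, H, W) image latents
loss_5d = torch.ones(2, 4, 3, 8, 8)     # e.g. latents with an extra frame dim
assert per_sample_mean(loss_4d).shape == (2,)
assert per_sample_mean(loss_5d).shape == (2,)
```

The result always has shape `(B,)`, so the per-sample `loss_weights` multiplication that follows in the hunk broadcasts correctly for every model family.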
@@ -1085,6 +1068,7 @@ class NetworkTrainer:
                 "enable_bucket": bool(dataset.enable_bucket),
                 "min_bucket_reso": dataset.min_bucket_reso,
                 "max_bucket_reso": dataset.max_bucket_reso,
+                "skip_image_resolution": dataset.skip_image_resolution,
                 "tag_frequency": dataset.tag_frequency,
                 "bucket_info": dataset.bucket_info,
                 "resize_interpolation": dataset.resize_interpolation,
@@ -1191,6 +1175,7 @@ class NetworkTrainer:
             "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
             "ss_min_bucket_reso": dataset.min_bucket_reso,
             "ss_max_bucket_reso": dataset.max_bucket_reso,
+            "ss_skip_image_resolution": dataset.skip_image_resolution,
             "ss_keep_tokens": args.keep_tokens,
             "ss_dataset_dirs": json.dumps(dataset_dirs_info),
             "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),