feat: add multi backend attention and related update for HI2.1 models and scripts

This commit is contained in:
Kohya S
2025-09-20 19:45:33 +09:00
parent f834b2e0d4
commit b090d15f7d
6 changed files with 286 additions and 102 deletions

View File

@@ -126,7 +126,8 @@ accelerate launch --num_cpu_threads_per_process 1 hunyuan_image_train_network.py
--learning_rate=1e-4 \
--optimizer_type="AdamW8bit" \
--lr_scheduler="constant" \
--sdpa \
--attn_mode="torch" \
--split_attn \
--max_train_epochs=10 \
--save_every_n_epochs=1 \
--mixed_precision="bf16" \
@@ -175,6 +176,10 @@ The script adds HunyuanImage-2.1 specific arguments. For common arguments (like
#### Memory/Speed Related
* `--attn_mode=<choice>`
- Specifies the attention implementation to use. Options are `torch`, `xformers`, `flash`, `sageattn`. Default is `torch` (use scaled dot product attention). Each library must be installed separately other than `torch`. If using `xformers`, also specify `--split_attn` if the batch size is more than 1.
* `--split_attn`
- Splits the batch during attention computation to process one item at a time, reducing VRAM usage by avoiding attention mask computation. Can improve speed when using `torch`. Required when using `xformers` with batch size greater than 1.
* `--fp8_scaled`
- Enables training the DiT model in scaled FP8 format. This can significantly reduce VRAM usage (can run with as little as 8GB VRAM when combined with `--blocks_to_swap`), but the training results may vary. This is a newer alternative to the unsupported `--fp8_base` option.
* `--fp8_vl`
@@ -429,6 +434,7 @@ python hunyuan_image_minimal_inference.py \
--vae "<path to hunyuan_image_2.1_vae_fp16.safetensors>" \
--lora_weight "<path to your trained LoRA>" \
--lora_multiplier 1.0 \
--attn_mode "torch" \
--prompt "A cute cartoon penguin in a snowy landscape" \
--image_size 2048 2048 \
--infer_steps 50 \
@@ -445,6 +451,8 @@ python hunyuan_image_minimal_inference.py \
- `--guidance_scale`: CFG scale (default: 3.5)
- `--flow_shift`: Flow matching shift parameter (default: 5.0)
`--split_attn` is not supported (since inference is done one at a time).
<details>
<summary>日本語</summary>
@@ -457,6 +465,8 @@ python hunyuan_image_minimal_inference.py \
- `--guidance_scale`: CFGスケール推奨: 3.5
- `--flow_shift`: Flow Matchingシフトパラメータデフォルト: 5.0
`--split_attn`はサポートされていません1件ずつ推論するため
</details>
## 9. Related Tools / 関連ツール