mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Compare commits
143 Commits
vae_batch_
...
feature-ch
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f9bc6aa14c | ||
|
|
30295c9668 | ||
|
|
999df5ec15 | ||
|
|
88960e6309 | ||
|
|
88dc3213a9 | ||
|
|
1a9bf2ab56 | ||
|
|
8a72f56c9f | ||
|
|
d0b335d8cf | ||
|
|
4e7dfc0b1b | ||
|
|
2e0fcc50cb | ||
|
|
0b90555916 | ||
|
|
9a50c96a68 | ||
|
|
7bd9a6b19e | ||
|
|
3f9eab4946 | ||
|
|
7fb0d30feb | ||
|
|
b4d1152293 | ||
|
|
2fffcb605c | ||
|
|
a87e999786 | ||
|
|
05f392fa27 | ||
|
|
6731d8a57f | ||
|
|
078ee28a94 | ||
|
|
5034c6f813 | ||
|
|
884c1f37c4 | ||
|
|
935e0037dc | ||
|
|
52d13373c0 | ||
|
|
8e4dc1f441 | ||
|
|
0e929f97b9 | ||
|
|
1db78559a6 | ||
|
|
3e6935a07e | ||
|
|
fc40a279fa | ||
|
|
cadcd3169b | ||
|
|
bcd3a5a60a | ||
|
|
77dbabe849 | ||
|
|
d94bed645a | ||
|
|
0145efc2f2 | ||
|
|
bb47f1ea89 | ||
|
|
61eda76278 | ||
|
|
e4d6923409 | ||
|
|
5753b8ff6b | ||
|
|
2bfda1271b | ||
|
|
5b38d07f03 | ||
|
|
e2ed265104 | ||
|
|
e85813200a | ||
|
|
a27ace74d9 | ||
|
|
865c8d55e2 | ||
|
|
7c075a9c8d | ||
|
|
b4a89c3cdf | ||
|
|
f62c68df3c | ||
|
|
a4fae93dce | ||
|
|
1684ababcd | ||
|
|
64430eb9b2 | ||
|
|
d8717a3d1c | ||
|
|
a21b6a917e | ||
|
|
4625b34f4e | ||
|
|
80320d21fe | ||
|
|
29523c9b68 | ||
|
|
fd3a445769 | ||
|
|
13296ae93b | ||
|
|
0e8ac43760 | ||
|
|
bc9252cc1b | ||
|
|
3b25de1f17 | ||
|
|
f0b07c52ab | ||
|
|
309c44bdf2 | ||
|
|
8387e0b95c | ||
|
|
5c50cdbb44 | ||
|
|
46ad3be059 | ||
|
|
abf2c44bc5 | ||
|
|
adb775c616 | ||
|
|
4fc917821a | ||
|
|
899f3454b6 | ||
|
|
b11c053b8f | ||
|
|
c46f08a87a | ||
|
|
0d9da0ea71 | ||
|
|
f501209c37 | ||
|
|
c8af252a44 | ||
|
|
7f984f4775 | ||
|
|
d33d5eccd1 | ||
|
|
7c61c0dfe0 | ||
|
|
26db64be17 | ||
|
|
629073cd9d | ||
|
|
06df0377f9 | ||
|
|
7f93e21f30 | ||
|
|
9f1892cc8e | ||
|
|
1a4f1ff0f1 | ||
|
|
00e12eed65 | ||
|
|
30008168e3 | ||
|
|
1481217eb2 | ||
|
|
61f7283167 | ||
|
|
2ba1cc7791 | ||
|
|
7482784f74 | ||
|
|
e8c15c7167 | ||
|
|
9fe8a47080 | ||
|
|
1f22a94cfe | ||
|
|
5e45df722d | ||
|
|
09c4710d1e | ||
|
|
dfe1ab6c50 | ||
|
|
b6e4194ea5 | ||
|
|
b5d1f1caea | ||
|
|
d6c3e6346e | ||
|
|
800d068e37 | ||
|
|
3817b65b45 | ||
|
|
a69884a209 | ||
|
|
cad182d29a | ||
|
|
a2daa87007 | ||
|
|
1bba7acd9a | ||
|
|
d6f7e2e20c | ||
|
|
9647f1e324 | ||
|
|
42fe22f5a2 | ||
|
|
ce2610d29b | ||
|
|
0886d976f1 | ||
|
|
542f980443 | ||
|
|
70403f6977 | ||
|
|
7b83d50dc0 | ||
|
|
a1a5627b13 | ||
|
|
ce37c08b9a | ||
|
|
5f9047c8cf | ||
|
|
fc772affbe | ||
|
|
653621de57 | ||
|
|
2c94d17f05 | ||
|
|
48e7da2d4a | ||
|
|
ba725a84e9 | ||
|
|
42a801514c | ||
|
|
6d7bec8a37 | ||
|
|
025cca699b | ||
|
|
6597631b90 | ||
|
|
bd16bd13ae | ||
|
|
98efbc3bb7 | ||
|
|
1aa2f00e85 | ||
|
|
3ed7606f88 | ||
|
|
3365cfadd7 | ||
|
|
44782dd790 | ||
|
|
aa36c48685 | ||
|
|
bb7bae5dff | ||
|
|
3ce23b7f16 | ||
|
|
733fdc09c6 | ||
|
|
6965a0178a | ||
|
|
16015635d2 | ||
|
|
60a76ebb72 | ||
|
|
a00b06bc97 | ||
|
|
7323ee1b9d | ||
|
|
c0caf33e3f | ||
|
|
d154e76c45 | ||
|
|
a9c5aa1f93 |
9
.ai/claude.prompt.md
Normal file
9
.ai/claude.prompt.md
Normal file
@@ -0,0 +1,9 @@
|
||||
## About This File
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## 1. Project Context
|
||||
Here is the essential context for our project. Please read and understand it thoroughly.
|
||||
|
||||
### Project Overview
|
||||
@./context/01-overview.md
|
||||
101
.ai/context/01-overview.md
Normal file
101
.ai/context/01-overview.md
Normal file
@@ -0,0 +1,101 @@
|
||||
This file provides the overview and guidance for developers working with the codebase, including setup instructions, architecture details, and common commands.
|
||||
|
||||
## Project Architecture
|
||||
|
||||
### Core Training Framework
|
||||
The codebase is built around a **strategy pattern architecture** that supports multiple diffusion model families:
|
||||
|
||||
- **`library/strategy_base.py`**: Base classes for tokenization, text encoding, latent caching, and training strategies
|
||||
- **`library/strategy_*.py`**: Model-specific implementations for SD, SDXL, SD3, FLUX, etc.
|
||||
- **`library/train_util.py`**: Core training utilities shared across all model types
|
||||
- **`library/config_util.py`**: Configuration management with TOML support
|
||||
|
||||
### Model Support Structure
|
||||
Each supported model family has a consistent structure:
|
||||
- **Training script**: `{model}_train.py` (full fine-tuning), `{model}_train_network.py` (LoRA/network training)
|
||||
- **Model utilities**: `library/{model}_models.py`, `library/{model}_train_utils.py`, `library/{model}_utils.py`
|
||||
- **Networks**: `networks/lora_{model}.py`, `networks/oft_{model}.py` for adapter training
|
||||
|
||||
### Supported Models
|
||||
- **Stable Diffusion 1.x**: `train*.py`, `library/train_util.py`, `train_db.py` (for DreamBooth)
|
||||
- **SDXL**: `sdxl_train*.py`, `library/sdxl_*`
|
||||
- **SD3**: `sd3_train*.py`, `library/sd3_*`
|
||||
- **FLUX.1**: `flux_train*.py`, `library/flux_*`
|
||||
|
||||
### Key Components
|
||||
|
||||
#### Memory Management
|
||||
- **Block swapping**: CPU-GPU memory optimization via `--blocks_to_swap` parameter, works with custom offloading. Only available for models with transformer architectures like SD3 and FLUX.1.
|
||||
- **Custom offloading**: `library/custom_offloading_utils.py` for advanced memory management
|
||||
- **Gradient checkpointing**: Memory reduction during training
|
||||
|
||||
#### Training Features
|
||||
- **LoRA training**: Low-rank adaptation networks in `networks/lora*.py`
|
||||
- **ControlNet training**: Conditional generation control
|
||||
- **Textual Inversion**: Custom embedding training
|
||||
- **Multi-resolution training**: Bucket-based aspect ratio handling
|
||||
- **Validation loss**: Real-time training monitoring, only for LoRA training
|
||||
|
||||
#### Configuration System
|
||||
Dataset configuration uses TOML files with structured validation:
|
||||
```toml
|
||||
[datasets.sample_dataset]
|
||||
resolution = 1024
|
||||
batch_size = 2
|
||||
|
||||
[[datasets.sample_dataset.subsets]]
|
||||
image_dir = "path/to/images"
|
||||
caption_extension = ".txt"
|
||||
```
|
||||
|
||||
## Common Development Commands
|
||||
|
||||
### Training Commands Pattern
|
||||
All training scripts follow this general pattern:
|
||||
```bash
|
||||
accelerate launch --mixed_precision bf16 {script_name}.py \
|
||||
--pretrained_model_name_or_path model.safetensors \
|
||||
--dataset_config config.toml \
|
||||
--output_dir output \
|
||||
--output_name model_name \
|
||||
[model-specific options]
|
||||
```
|
||||
|
||||
### Memory Optimization
|
||||
For low VRAM environments, use block swapping:
|
||||
```bash
|
||||
# Add to any training command for memory reduction
|
||||
--blocks_to_swap 10 # Swap 10 blocks to CPU (adjust number as needed)
|
||||
```
|
||||
|
||||
### Utility Scripts
|
||||
Located in `tools/` directory:
|
||||
- `tools/merge_lora.py`: Merge LoRA weights into base models
|
||||
- `tools/cache_latents.py`: Pre-cache VAE latents for faster training
|
||||
- `tools/cache_text_encoder_outputs.py`: Pre-cache text encoder outputs
|
||||
|
||||
## Development Notes
|
||||
|
||||
### Strategy Pattern Implementation
|
||||
When adding support for new models, implement the four core strategies:
|
||||
1. `TokenizeStrategy`: Text tokenization handling
|
||||
2. `TextEncodingStrategy`: Text encoder forward pass
|
||||
3. `LatentsCachingStrategy`: VAE encoding/caching
|
||||
4. `TextEncoderOutputsCachingStrategy`: Text encoder output caching
|
||||
|
||||
### Testing Approach
|
||||
- Unit tests focus on utility functions and model loading
|
||||
- Integration tests validate training script syntax and basic execution
|
||||
- Most tests use mocks to avoid requiring actual model files
|
||||
- Add tests for new model support in `tests/test_{model}_*.py`
|
||||
|
||||
### Configuration System
|
||||
- Use `config_util.py` dataclasses for type-safe configuration
|
||||
- Support both command-line arguments and TOML file configuration
|
||||
- Validate configuration early in training scripts to prevent runtime errors
|
||||
|
||||
### Memory Management
|
||||
- Always consider VRAM limitations when implementing features
|
||||
- Use gradient checkpointing for large models
|
||||
- Implement block swapping for models with transformer architectures
|
||||
- Cache intermediate results (latents, text embeddings) when possible
|
||||
9
.ai/gemini.prompt.md
Normal file
9
.ai/gemini.prompt.md
Normal file
@@ -0,0 +1,9 @@
|
||||
## About This File
|
||||
|
||||
This file provides guidance to Gemini CLI (https://github.com/google-gemini/gemini-cli) when working with code in this repository.
|
||||
|
||||
## 1. Project Context
|
||||
Here is the essential context for our project. Please read and understand it thoroughly.
|
||||
|
||||
### Project Overview
|
||||
@./context/01-overview.md
|
||||
3
.github/FUNDING.yml
vendored
Normal file
3
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: kohya-ss
|
||||
5
.github/workflows/tests.yml
vendored
5
.github/workflows/tests.yml
vendored
@@ -12,6 +12,9 @@ on:
|
||||
- dev
|
||||
- sd3
|
||||
|
||||
# CKV2_GHA_1: "Ensure top-level permissions are not set to write-all"
|
||||
permissions: read-all
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ${{ matrix.os }}
|
||||
@@ -40,7 +43,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
# Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch)
|
||||
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision==0.19.0 pytest==8.3.4
|
||||
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4
|
||||
pip install -r requirements.txt
|
||||
|
||||
- name: Test with pytest
|
||||
|
||||
3
.github/workflows/typos.yml
vendored
3
.github/workflows/typos.yml
vendored
@@ -12,6 +12,9 @@ on:
|
||||
- synchronize
|
||||
- reopened
|
||||
|
||||
# CKV2_GHA_1: "Ensure top-level permissions are not set to write-all"
|
||||
permissions: read-all
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -6,3 +6,8 @@ venv
|
||||
build
|
||||
.vscode
|
||||
wandb
|
||||
CLAUDE.md
|
||||
GEMINI.md
|
||||
.claude
|
||||
.gemini
|
||||
MagicMock
|
||||
|
||||
72
README.md
72
README.md
@@ -9,11 +9,25 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
|
||||
The command to install PyTorch is as follows:
|
||||
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||
|
||||
If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.
|
||||
|
||||
- [FLUX.1 training](#flux1-training)
|
||||
- [SD3 training](#sd3-training)
|
||||
|
||||
### Recent Updates
|
||||
|
||||
Jul 10, 2025:
|
||||
- [AI Coding Agents](#for-developers-using-ai-coding-agents) section is added to the README. This section provides instructions for developers using AI coding agents like Claude and Gemini to understand the project context and coding standards.
|
||||
|
||||
May 1, 2025:
|
||||
- The error when training FLUX.1 with mixed precision in flux_train.py with DeepSpeed enabled has been resolved. Thanks to sharlynxy for PR [#2060](https://github.com/kohya-ss/sd-scripts/pull/2060). Please refer to the PR for details.
|
||||
- If you enable DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.
|
||||
|
||||
Apr 27, 2025:
|
||||
- FLUX.1 training now supports CFG scale in the sample generation during training. Please use `--g` option, to specify the CFG scale (note that `--l` is used as the embedded guidance scale.) PR [#2064](https://github.com/kohya-ss/sd-scripts/pull/2064).
|
||||
- See [here](#sample-image-generation-during-training) for details.
|
||||
- If you have any issues with this, please let us know.
|
||||
|
||||
Apr 6, 2025:
|
||||
- IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details.
|
||||
- `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` are available.
|
||||
@@ -43,46 +57,30 @@ Jan 25, 2025:
|
||||
- It will be added to other scripts as well.
|
||||
- As a current limitation, validation loss is not supported when `--block_to_swap` is specified, or when schedule-free optimizer is used.
|
||||
|
||||
Dec 15, 2024:
|
||||
## For Developers Using AI Coding Agents
|
||||
|
||||
- RAdamScheduleFree optimizer is supported. PR [#1830](https://github.com/kohya-ss/sd-scripts/pull/1830) Thanks to nhamanasu!
|
||||
- Update to `schedulefree==1.4` is required. Please update individually or with `pip install --use-pep517 --upgrade -r requirements.txt`.
|
||||
- Available with `--optimizer_type=RAdamScheduleFree`. No need to specify warm up steps as well as learning rate scheduler.
|
||||
This repository provides recommended instructions to help AI agents like Claude and Gemini understand our project context and coding standards.
|
||||
|
||||
Dec 7, 2024:
|
||||
To use them, you need to opt-in by creating your own configuration file in the project root.
|
||||
|
||||
- The option to specify the model name during ControlNet training was different in each script. It has been unified. Please specify `--controlnet_model_name_or_path`. PR [#1821](https://github.com/kohya-ss/sd-scripts/pull/1821) Thanks to sdbds!
|
||||
<!--
|
||||
Also, the ControlNet training script for SD has been changed from `train_controlnet.py` to `train_control_net.py`.
|
||||
- `train_controlnet.py` is still available, but it will be removed in the future.
|
||||
-->
|
||||
**Quick Setup:**
|
||||
|
||||
- Fixed an issue where the saved model would be corrupted (pos_embed would not be saved) when `--enable_scaled_pos_embed` was specified in `sd3_train.py`.
|
||||
1. Create a `CLAUDE.md` and/or `GEMINI.md` file in the project root.
|
||||
2. Add the following line to your `CLAUDE.md` to import the repository's recommended prompt:
|
||||
|
||||
Dec 3, 2024:
|
||||
```markdown
|
||||
@./.ai/claude.prompt.md
|
||||
```
|
||||
|
||||
-`--blocks_to_swap` now works in FLUX.1 ControlNet training. Sample commands for 24GB VRAM and 16GB VRAM are added [here](#flux1-controlnet-training).
|
||||
or for Gemini:
|
||||
|
||||
Dec 2, 2024:
|
||||
```markdown
|
||||
@./.ai/gemini.prompt.md
|
||||
```
|
||||
|
||||
- FLUX.1 ControlNet training is supported. PR [#1813](https://github.com/kohya-ss/sd-scripts/pull/1813). Thanks to minux302! See PR and [here](#flux1-controlnet-training) for details.
|
||||
- Not fully tested. Feedback is welcome.
|
||||
- 80GB VRAM is required for 1024x1024 resolution, and 48GB VRAM is required for 512x512 resolution.
|
||||
- Currently, it only works in Linux environment (or Windows WSL2) because DeepSpeed is required.
|
||||
- Multi-GPU training is not tested.
|
||||
3. You can now add your own personal instructions below the import line (e.g., `Always respond in Japanese.`).
|
||||
|
||||
Dec 1, 2024:
|
||||
|
||||
- Pseudo Huber loss is now available for FLUX.1 and SD3.5 training. See PR [#1808](https://github.com/kohya-ss/sd-scripts/pull/1808) for details. Thanks to recris!
|
||||
- Specify `--loss_type huber` or `--loss_type smooth_l1` to use it. `--huber_c` and `--huber_scale` are also available.
|
||||
|
||||
- [Prodigy + ScheduleFree](https://github.com/LoganBooker/prodigy-plus-schedule-free) is supported. See PR [#1811](https://github.com/kohya-ss/sd-scripts/pull/1811) for details. Thanks to rockerBOO!
|
||||
|
||||
Nov 14, 2024:
|
||||
|
||||
- Improved the implementation of block swap and made it available for both FLUX.1 and SD3 LoRA training. See [FLUX.1 LoRA training](#flux1-lora-training) etc. for how to use the new options. Training is possible with about 8-10GB of VRAM.
|
||||
- During fine-tuning, the memory usage when specifying the same number of blocks has increased slightly, but the training speed when specifying block swap has been significantly improved.
|
||||
- There may be bugs due to the significant changes. Feedback is welcome.
|
||||
This approach ensures that you have full control over the instructions given to your agent while benefiting from the shared project context. Your `CLAUDE.md` and `GEMINI.md` are already listed in `.gitignore`, so it won't be committed to the repository.
|
||||
|
||||
## FLUX.1 training
|
||||
|
||||
@@ -870,6 +868,14 @@ Note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o
|
||||
|
||||
(Single GPU with id `0` will be used.)
|
||||
|
||||
## DeepSpeed installation (experimental, Linux or WSL2 only)
|
||||
|
||||
To install DeepSpeed, run the following command in your activated virtual environment:
|
||||
|
||||
```bash
|
||||
pip install deepspeed==0.16.7
|
||||
```
|
||||
|
||||
## Upgrade
|
||||
|
||||
When a new release comes out you can upgrade your repo with the following command:
|
||||
@@ -1344,11 +1350,13 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
|
||||
Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.
|
||||
|
||||
* `--n` Negative prompt up to the next option.
|
||||
* `--n` Negative prompt up to the next option. Ignored when CFG scale is `1.0`.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* In guidance distillation models like FLUX.1, this value is used as the embedded guidance scale for backward compatibility.
|
||||
* `--g` Specifies the CFG scale for the models with embedded guidance scale. The default is `1.0`, `1.0` means no CFG. In general, should not be changed unless you train the un-distilled FLUX.1 models.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
|
||||
The prompt weighting such as `( )` and `[ ]` are working.
|
||||
|
||||
302
docs/lumina_train_network.md
Normal file
302
docs/lumina_train_network.md
Normal file
@@ -0,0 +1,302 @@
|
||||
Status: reviewed
|
||||
|
||||
# LoRA Training Guide for Lumina Image 2.0 using `lumina_train_network.py` / `lumina_train_network.py` を用いたLumina Image 2.0モデルのLoRA学習ガイド
|
||||
|
||||
This document explains how to train LoRA (Low-Rank Adaptation) models for Lumina Image 2.0 using `lumina_train_network.py` in the `sd-scripts` repository.
|
||||
|
||||
## 1. Introduction / はじめに
|
||||
|
||||
`lumina_train_network.py` trains additional networks such as LoRA for Lumina Image 2.0 models. Lumina Image 2.0 adopts a Next-DiT (Next-generation Diffusion Transformer) architecture, which differs from previous Stable Diffusion models. It uses a single text encoder (Gemma2) and a dedicated AutoEncoder (AE).
|
||||
|
||||
This guide assumes you already understand the basics of LoRA training. For common usage and options, see the train_network.py guide (to be documented). Some parameters are similar to those in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md).
|
||||
|
||||
**Prerequisites:**
|
||||
|
||||
* The `sd-scripts` repository has been cloned and the Python environment is ready.
|
||||
* A training dataset has been prepared. See the [Dataset Configuration Guide](./config_README-en.md).
|
||||
* Lumina Image 2.0 model files for training are available.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`lumina_train_network.py`は、Lumina Image 2.0モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。Lumina Image 2.0は、Next-DiT (Next-generation Diffusion Transformer) と呼ばれる新しいアーキテクチャを採用しており、従来のStable Diffusionモデルとは構造が異なります。テキストエンコーダーとしてGemma2を単体で使用し、専用のAutoEncoder (AE) を使用します。
|
||||
|
||||
このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、`train_network.py`のガイド(作成中)を参照してください。また一部のパラメータは [`sd3_train_network.py`](sd3_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。
|
||||
|
||||
**前提条件:**
|
||||
|
||||
* `sd-scripts`リポジトリのクローンとPython環境のセットアップが完了していること。
|
||||
* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](./config_README-en.md)を参照してください)
|
||||
* 学習対象のLumina Image 2.0モデルファイルが準備できていること。
|
||||
</details>
|
||||
|
||||
## 2. Differences from `train_network.py` / `train_network.py` との違い
|
||||
|
||||
`lumina_train_network.py` is based on `train_network.py` but modified for Lumina Image 2.0. Main differences are:
|
||||
|
||||
* **Target models:** Lumina Image 2.0 models.
|
||||
* **Model structure:** Uses Next-DiT (Transformer based) instead of U-Net and employs a single text encoder (Gemma2). The AutoEncoder (AE) is not compatible with SDXL/SD3/FLUX.
|
||||
* **Arguments:** Options exist to specify the Lumina Image 2.0 model, Gemma2 text encoder and AE. With a single `.safetensors` file, these components are typically provided separately.
|
||||
* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used.
|
||||
* **Lumina specific options:** Additional parameters for timestep sampling, model prediction type, discrete flow shift, and system prompt.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
`lumina_train_network.py`は`train_network.py`をベースに、Lumina Image 2.0モデルに対応するための変更が加えられています。主な違いは以下の通りです。
|
||||
|
||||
* **対象モデル:** Lumina Image 2.0モデルを対象とします。
|
||||
* **モデル構造:** U-Netの代わりにNext-DiT (Transformerベース) を使用します。Text EncoderとしてGemma2を単体で使用し、専用のAutoEncoder (AE) を使用します。
|
||||
* **引数:** Lumina Image 2.0モデル、Gemma2 Text Encoder、AEを指定する引数があります。通常、これらのコンポーネントは個別に提供されます。
|
||||
* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)はLumina Image 2.0の学習では使用されません。
|
||||
* **Lumina特有の引数:** タイムステップのサンプリング、モデル予測タイプ、離散フローシフト、システムプロンプトに関する引数が追加されています。
|
||||
</details>
|
||||
|
||||
## 3. Preparation / 準備
|
||||
|
||||
The following files are required before starting training:
|
||||
|
||||
1. **Training script:** `lumina_train_network.py`
|
||||
2. **Lumina Image 2.0 model file:** `.safetensors` file for the base model.
|
||||
3. **Gemma2 text encoder file:** `.safetensors` file for the text encoder.
|
||||
4. **AutoEncoder (AE) file:** `.safetensors` file for the AE.
|
||||
5. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](./config_README-en.md). In this document we use `my_lumina_dataset_config.toml` as an example.
|
||||
|
||||
|
||||
**Model Files:**
|
||||
* Lumina Image 2.0: `lumina-image-2.safetensors` ([full precision link](https://huggingface.co/rockerBOO/lumina-image-2/blob/main/lumina-image-2.safetensors)) or `lumina_2_model_bf16.safetensors` ([bf16 link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors))
|
||||
* Gemma2 2B (fp16): `gemma-2-2b.safetensors` ([link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/text_encoders/gemma_2_2b_fp16.safetensors))
|
||||
* AutoEncoder: `ae.safetensors` ([link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/vae/ae.safetensors)) (same as FLUX)
|
||||
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
学習を開始する前に、以下のファイルが必要です。
|
||||
|
||||
1. **学習スクリプト:** `lumina_train_network.py`
|
||||
2. **Lumina Image 2.0モデルファイル:** 学習のベースとなるLumina Image 2.0モデルの`.safetensors`ファイル。
|
||||
3. **Gemma2テキストエンコーダーファイル:** Gemma2テキストエンコーダーの`.safetensors`ファイル。
|
||||
4. **AutoEncoder (AE) ファイル:** AEの`.safetensors`ファイル。
|
||||
5. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](./config_README-en.md)を参照してください)。
|
||||
* 例として`my_lumina_dataset_config.toml`を使用します。
|
||||
|
||||
**モデルファイル** は英語ドキュメントの通りです。
|
||||
|
||||
</details>
|
||||
|
||||
## 4. Running the Training / 学習の実行
|
||||
|
||||
Execute `lumina_train_network.py` from the terminal to start training. The overall command-line format is the same as `train_network.py`, but Lumina Image 2.0 specific options must be supplied.
|
||||
|
||||
Example command:
|
||||
|
||||
```bash
|
||||
accelerate launch --num_cpu_threads_per_process 1 lumina_train_network.py \
|
||||
--pretrained_model_name_or_path="lumina-image-2.safetensors" \
|
||||
--gemma2="gemma-2-2b.safetensors" \
|
||||
--ae="ae.safetensors" \
|
||||
--dataset_config="my_lumina_dataset_config.toml" \
|
||||
--output_dir="./output" \
|
||||
--output_name="my_lumina_lora" \
|
||||
--save_model_as=safetensors \
|
||||
--network_module=networks.lora_lumina \
|
||||
--network_dim=8 \
|
||||
--network_alpha=8 \
|
||||
--learning_rate=1e-4 \
|
||||
--optimizer_type="AdamW" \
|
||||
--lr_scheduler="constant" \
|
||||
--timestep_sampling="nextdit_shift" \
|
||||
--discrete_flow_shift=6.0 \
|
||||
--model_prediction_type="raw" \
|
||||
--system_prompt="You are an assistant designed to generate high-quality images based on user prompts." \
|
||||
--max_train_epochs=10 \
|
||||
--save_every_n_epochs=1 \
|
||||
--mixed_precision="bf16" \
|
||||
--gradient_checkpointing \
|
||||
--cache_latents \
|
||||
--cache_text_encoder_outputs
|
||||
```
|
||||
|
||||
*(Write the command on one line or use `\` or `^` for line breaks.)*
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
学習は、ターミナルから`lumina_train_network.py`を実行することで開始します。基本的なコマンドラインの構造は`train_network.py`と同様ですが、Lumina Image 2.0特有の引数を指定する必要があります。
|
||||
|
||||
以下に、基本的なコマンドライン実行例を示します。
|
||||
|
||||
```bash
|
||||
accelerate launch --num_cpu_threads_per_process 1 lumina_train_network.py \
|
||||
--pretrained_model_name_or_path="lumina-image-2.safetensors" \
|
||||
--gemma2="gemma-2-2b.safetensors" \
|
||||
--ae="ae.safetensors" \
|
||||
--dataset_config="my_lumina_dataset_config.toml" \
|
||||
--output_dir="./output" \
|
||||
--output_name="my_lumina_lora" \
|
||||
--save_model_as=safetensors \
|
||||
--network_module=networks.lora_lumina \
|
||||
--network_dim=8 \
|
||||
--network_alpha=8 \
|
||||
--learning_rate=1e-4 \
|
||||
--optimizer_type="AdamW" \
|
||||
--lr_scheduler="constant" \
|
||||
--timestep_sampling="nextdit_shift" \
|
||||
--discrete_flow_shift=6.0 \
|
||||
--model_prediction_type="raw" \
|
||||
--system_prompt="You are an assistant designed to generate high-quality images based on user prompts." \
|
||||
--max_train_epochs=10 \
|
||||
--save_every_n_epochs=1 \
|
||||
--mixed_precision="bf16" \
|
||||
--gradient_checkpointing \
|
||||
--cache_latents \
|
||||
--cache_text_encoder_outputs
|
||||
```
|
||||
|
||||
※実際には1行で書くか、適切な改行文字(`\` または `^`)を使用してください。
|
||||
</details>
|
||||
|
||||
### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説
|
||||
|
||||
Besides the arguments explained in the [train_network.py guide](train_network.md), specify the following Lumina Image 2.0 options. For shared options (`--output_dir`, `--output_name`, etc.), see that guide.
|
||||
|
||||
#### Model Options / モデル関連
|
||||
|
||||
* `--pretrained_model_name_or_path="<path to Lumina model>"` **required** – Path to the Lumina Image 2.0 model.
|
||||
* `--gemma2="<path to Gemma2 model>"` **required** – Path to the Gemma2 text encoder `.safetensors` file.
|
||||
* `--ae="<path to AE model>"` **required** – Path to the AutoEncoder `.safetensors` file.
|
||||
|
||||
#### Lumina Image 2.0 Training Parameters / Lumina Image 2.0 学習パラメータ
|
||||
|
||||
* `--gemma2_max_token_length=<integer>` – Max token length for Gemma2. Default is 256.
|
||||
* `--timestep_sampling=<choice>` – Timestep sampling method. Options: `sigma`, `uniform`, `sigmoid`, `shift`, `nextdit_shift`. Default `shift`. **Recommended: `nextdit_shift`**
|
||||
* `--discrete_flow_shift=<float>` – Discrete flow shift for the Euler Discrete Scheduler. Default `6.0`.
|
||||
* `--model_prediction_type=<choice>` – Model prediction processing method. Options: `raw`, `additive`, `sigma_scaled`. Default `raw`. **Recommended: `raw`**
|
||||
* `--system_prompt=<string>` – System prompt to prepend to all prompts. Recommended: `"You are an assistant designed to generate high-quality images based on user prompts."` or `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`
|
||||
* `--use_flash_attn` – Use Flash Attention. Requires `pip install flash-attn` (may not be supported in all environments). If installed correctly, it speeds up training.
|
||||
* `--sigmoid_scale=<float>` – Scale factor for sigmoid timestep sampling. Default `1.0`.
|
||||
|
||||
#### Memory and Speed / メモリ・速度関連
|
||||
|
||||
* `--blocks_to_swap=<integer>` **[experimental]** – Swap a number of Transformer blocks between CPU and GPU. More blocks reduce VRAM but slow training. Cannot be used with `--cpu_offload_checkpointing`.
|
||||
* `--cache_text_encoder_outputs` – Cache Gemma2 outputs to reduce memory usage.
|
||||
* `--cache_latents`, `--cache_latents_to_disk` – Cache AE outputs.
|
||||
* `--fp8_base` – Use FP8 precision for the base model.
|
||||
|
||||
#### Network Arguments / ネットワーク引数
|
||||
|
||||
For Lumina Image 2.0, you can specify different dimensions for various components:
|
||||
|
||||
* `--network_args` can include:
|
||||
* `"attn_dim=4"` – Attention dimension
|
||||
* `"mlp_dim=4"` – MLP dimension
|
||||
* `"mod_dim=4"` – Modulation dimension
|
||||
* `"refiner_dim=4"` – Refiner blocks dimension
|
||||
* `"embedder_dims=[4,4,4]"` – Embedder dimensions for x, t, and caption embedders
|
||||
|
||||
#### Incompatible or Deprecated Options / 非互換・非推奨の引数
|
||||
|
||||
* `--v2`, `--v_parameterization`, `--clip_skip` – Options for Stable Diffusion v1/v2 that are not used for Lumina Image 2.0.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のLumina Image 2.0特有の引数を指定します。共通の引数については、上記ガイドを参照してください。
|
||||
|
||||
#### モデル関連
|
||||
|
||||
* `--pretrained_model_name_or_path="<path to Lumina model>"` **[必須]**
|
||||
* 学習のベースとなるLumina Image 2.0モデルの`.safetensors`ファイルのパスを指定します。
|
||||
* `--gemma2="<path to Gemma2 model>"` **[必須]**
|
||||
* Gemma2テキストエンコーダーの`.safetensors`ファイルのパスを指定します。
|
||||
* `--ae="<path to AE model>"` **[必須]**
|
||||
* AutoEncoderの`.safetensors`ファイルのパスを指定します。
|
||||
|
||||
#### Lumina Image 2.0 学習パラメータ
|
||||
|
||||
* `--gemma2_max_token_length=<integer>` – Gemma2で使用するトークンの最大長を指定します。デフォルトは256です。
|
||||
* `--timestep_sampling=<choice>` – タイムステップのサンプリング方法を指定します。`sigma`, `uniform`, `sigmoid`, `shift`, `nextdit_shift`から選択します。デフォルトは`shift`です。**推奨: `nextdit_shift`**
|
||||
* `--discrete_flow_shift=<float>` – Euler Discrete Schedulerの離散フローシフトを指定します。デフォルトは`6.0`です。
|
||||
* `--model_prediction_type=<choice>` – モデル予測の処理方法を指定します。`raw`, `additive`, `sigma_scaled`から選択します。デフォルトは`raw`です。**推奨: `raw`**
|
||||
* `--system_prompt=<string>` – 全てのプロンプトに前置するシステムプロンプトを指定します。推奨: `"You are an assistant designed to generate high-quality images based on user prompts."` または `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`
|
||||
* `--use_flash_attn` – Flash Attentionを使用します。`pip install flash-attn`でインストールが必要です(環境によってはサポートされていません)。正しくインストールされている場合は、指定すると学習が高速化されます。
|
||||
* `--sigmoid_scale=<float>` – sigmoidタイムステップサンプリングのスケール係数を指定します。デフォルトは`1.0`です。
|
||||
|
||||
#### メモリ・速度関連
|
||||
|
||||
* `--blocks_to_swap=<integer>` **[実験的機能]** – TransformerブロックをCPUとGPUでスワップしてVRAMを節約します。`--cpu_offload_checkpointing`とは併用できません。
|
||||
* `--cache_text_encoder_outputs` – Gemma2の出力をキャッシュしてメモリ使用量を削減します。
|
||||
* `--cache_latents`, `--cache_latents_to_disk` – AEの出力をキャッシュします。
|
||||
* `--fp8_base` – ベースモデルにFP8精度を使用します。
|
||||
|
||||
#### ネットワーク引数
|
||||
|
||||
Lumina Image 2.0では、各コンポーネントに対して異なる次元を指定できます:
|
||||
|
||||
* `--network_args` には以下を含めることができます:
|
||||
* `"attn_dim=4"` – アテンション次元
|
||||
* `"mlp_dim=4"` – MLP次元
|
||||
* `"mod_dim=4"` – モジュレーション次元
|
||||
* `"refiner_dim=4"` – リファイナーブロック次元
|
||||
* `"embedder_dims=[4,4,4]"` – x、t、キャプションエンベッダーのエンベッダー次元
|
||||
|
||||
#### 非互換・非推奨の引数
|
||||
|
||||
* `--v2`, `--v_parameterization`, `--clip_skip` – Stable Diffusion v1/v2向けの引数のため、Lumina Image 2.0学習では使用されません。
|
||||
</details>
|
||||
|
||||
### 4.2. Starting Training / 学習の開始
|
||||
|
||||
After setting the required arguments, run the command to begin training. The overall flow and how to check logs are the same as in the [train_network.py guide](train_network.md#32-starting-the-training--学習の開始).
|
||||
|
||||
## 5. Using the Trained Model / 学習済みモデルの利用
|
||||
|
||||
When training finishes, a LoRA model file (e.g. `my_lumina_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support Lumina Image 2.0, such as ComfyUI with appropriate nodes.
|
||||
|
||||
## 6. Others / その他
|
||||
|
||||
`lumina_train_network.py` shares many features with `train_network.py`, such as sample image generation (`--sample_prompts`, etc.) and detailed optimizer settings. For these, see the [train_network.py guide](train_network.md#5-other-features--その他の機能) or run `python lumina_train_network.py --help`.
|
||||
|
||||
### 6.1. Recommended Settings / 推奨設定
|
||||
|
||||
Based on the contributor's recommendations, here are the suggested settings for optimal training:
|
||||
|
||||
**Key Parameters:**
|
||||
* `--timestep_sampling="nextdit_shift"`
|
||||
* `--discrete_flow_shift=6.0`
|
||||
* `--model_prediction_type="raw"`
|
||||
* `--mixed_precision="bf16"`
|
||||
|
||||
**System Prompts:**
|
||||
* General purpose: `"You are an assistant designed to generate high-quality images based on user prompts."`
|
||||
* High image-text alignment: `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`
|
||||
|
||||
**Sample Prompts:**
|
||||
Sample prompts can include CFG truncate (`--ctr`) and Renorm CFG (`-rcfg`) parameters:
|
||||
* `--ctr 0.25 --rcfg 1.0` (default values)
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
必要な引数を設定し、コマンドを実行すると学習が開始されます。基本的な流れやログの確認方法は[`train_network.py`のガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。
|
||||
|
||||
学習が完了すると、指定した`output_dir`にLoRAモデルファイル(例: `my_lumina_lora.safetensors`)が保存されます。このファイルは、Lumina Image 2.0モデルに対応した推論環境(例: ComfyUI + 適切なノード)で使用できます。
|
||||
|
||||
`lumina_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python lumina_train_network.py --help`) を参照してください。
|
||||
|
||||
### 6.1. 推奨設定
|
||||
|
||||
コントリビューターの推奨に基づく、最適な学習のための推奨設定:
|
||||
|
||||
**主要パラメータ:**
|
||||
* `--timestep_sampling="nextdit_shift"`
|
||||
* `--discrete_flow_shift=6.0`
|
||||
* `--model_prediction_type="raw"`
|
||||
* `--mixed_precision="bf16"`
|
||||
|
||||
**システムプロンプト:**
|
||||
* 汎用目的: `"You are an assistant designed to generate high-quality images based on user prompts."`
|
||||
* 高い画像-テキスト整合性: `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`
|
||||
|
||||
**サンプルプロンプト:**
|
||||
サンプルプロンプトには CFG truncate (`--ctr`) と Renorm CFG (`--rcfg`) パラメータを含めることができます:
|
||||
* `--ctr 0.25 --rcfg 1.0` (デフォルト値)
|
||||
|
||||
</details>
|
||||
@@ -97,15 +97,19 @@ def main(args):
|
||||
else:
|
||||
for file in SUB_DIR_FILES:
|
||||
hf_hub_download(
|
||||
args.repo_id,
|
||||
file,
|
||||
repo_id=args.repo_id,
|
||||
filename=file,
|
||||
subfolder=SUB_DIR,
|
||||
cache_dir=os.path.join(model_location, SUB_DIR),
|
||||
local_dir=os.path.join(model_location, SUB_DIR),
|
||||
force_download=True,
|
||||
force_filename=file,
|
||||
)
|
||||
for file in files:
|
||||
hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file)
|
||||
hf_hub_download(
|
||||
repo_id=args.repo_id,
|
||||
filename=file,
|
||||
local_dir=model_location,
|
||||
force_download=True,
|
||||
)
|
||||
else:
|
||||
logger.info("using existing wd14 tagger model")
|
||||
|
||||
@@ -146,7 +150,7 @@ def main(args):
|
||||
ort_sess = ort.InferenceSession(
|
||||
onnx_path,
|
||||
providers=(["OpenVINOExecutionProvider"]),
|
||||
provider_options=[{'device_type' : "GPU_FP32"}],
|
||||
provider_options=[{'device_type' : "GPU", "precision": "FP32"}],
|
||||
)
|
||||
else:
|
||||
ort_sess = ort.InferenceSession(
|
||||
|
||||
614
library/chroma_models.py
Normal file
614
library/chroma_models.py
Normal file
@@ -0,0 +1,614 @@
|
||||
# copy from the official repo: https://github.com/lodestone-rock/flow/blob/master/src/models/chroma/model.py
|
||||
# and modified
|
||||
# licensed under Apache License 2.0
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as ckpt
|
||||
|
||||
from .flux_models import (
|
||||
attention,
|
||||
rope,
|
||||
apply_rope,
|
||||
EmbedND,
|
||||
timestep_embedding,
|
||||
MLPEmbedder,
|
||||
RMSNorm,
|
||||
QKNorm,
|
||||
)
|
||||
|
||||
|
||||
def distribute_modulations(tensor: torch.Tensor, depth_single_blocks, depth_double_blocks):
|
||||
"""
|
||||
Distributes slices of the tensor into the block_dict as ModulationOut objects.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim].
|
||||
"""
|
||||
batch_size, vectors, dim = tensor.shape
|
||||
|
||||
block_dict = {}
|
||||
|
||||
# HARD CODED VALUES! lookup table for the generated vectors
|
||||
# TODO: move this into chroma config!
|
||||
# Add 38 single mod blocks
|
||||
for i in range(depth_single_blocks):
|
||||
key = f"single_blocks.{i}.modulation.lin"
|
||||
block_dict[key] = None
|
||||
|
||||
# Add 19 image double blocks
|
||||
for i in range(depth_double_blocks):
|
||||
key = f"double_blocks.{i}.img_mod.lin"
|
||||
block_dict[key] = None
|
||||
|
||||
# Add 19 text double blocks
|
||||
for i in range(depth_double_blocks):
|
||||
key = f"double_blocks.{i}.txt_mod.lin"
|
||||
block_dict[key] = None
|
||||
|
||||
# Add the final layer
|
||||
block_dict["final_layer.adaLN_modulation.1"] = None
|
||||
# 6.2b version
|
||||
# block_dict["lite_double_blocks.4.img_mod.lin"] = None
|
||||
# block_dict["lite_double_blocks.4.txt_mod.lin"] = None
|
||||
|
||||
idx = 0 # Index to keep track of the vector slices
|
||||
|
||||
for key in block_dict.keys():
|
||||
if "single_blocks" in key:
|
||||
# Single block: 1 ModulationOut
|
||||
block_dict[key] = ModulationOut(
|
||||
shift=tensor[:, idx : idx + 1, :],
|
||||
scale=tensor[:, idx + 1 : idx + 2, :],
|
||||
gate=tensor[:, idx + 2 : idx + 3, :],
|
||||
)
|
||||
idx += 3 # Advance by 3 vectors
|
||||
|
||||
elif "img_mod" in key:
|
||||
# Double block: List of 2 ModulationOut
|
||||
double_block = []
|
||||
for _ in range(2): # Create 2 ModulationOut objects
|
||||
double_block.append(
|
||||
ModulationOut(
|
||||
shift=tensor[:, idx : idx + 1, :],
|
||||
scale=tensor[:, idx + 1 : idx + 2, :],
|
||||
gate=tensor[:, idx + 2 : idx + 3, :],
|
||||
)
|
||||
)
|
||||
idx += 3 # Advance by 3 vectors per ModulationOut
|
||||
block_dict[key] = double_block
|
||||
|
||||
elif "txt_mod" in key:
|
||||
# Double block: List of 2 ModulationOut
|
||||
double_block = []
|
||||
for _ in range(2): # Create 2 ModulationOut objects
|
||||
double_block.append(
|
||||
ModulationOut(
|
||||
shift=tensor[:, idx : idx + 1, :],
|
||||
scale=tensor[:, idx + 1 : idx + 2, :],
|
||||
gate=tensor[:, idx + 2 : idx + 3, :],
|
||||
)
|
||||
)
|
||||
idx += 3 # Advance by 3 vectors per ModulationOut
|
||||
block_dict[key] = double_block
|
||||
|
||||
elif "final_layer" in key:
|
||||
# Final layer: 1 ModulationOut
|
||||
block_dict[key] = [
|
||||
tensor[:, idx : idx + 1, :],
|
||||
tensor[:, idx + 1 : idx + 2, :],
|
||||
]
|
||||
idx += 2 # Advance by 3 vectors
|
||||
|
||||
return block_dict
|
||||
|
||||
|
||||
class Approximator(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=4):
|
||||
super().__init__()
|
||||
self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
|
||||
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim) for x in range(n_layers)])
|
||||
self.norms = nn.ModuleList([RMSNorm(hidden_dim) for x in range(n_layers)])
|
||||
self.out_proj = nn.Linear(hidden_dim, out_dim)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
# Get the device of the module (assumes all parameters are on the same device)
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = self.in_proj(x)
|
||||
|
||||
for layer, norms in zip(self.layers, self.norms):
|
||||
x = x + layer(norms(x))
|
||||
|
||||
x = self.out_proj(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.norm = QKNorm(head_dim)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = self.norm(q, k, v)
|
||||
x = attention(q, k, v, pe=pe)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModulationOut:
|
||||
shift: Tensor
|
||||
scale: Tensor
|
||||
gate: Tensor
|
||||
|
||||
|
||||
def _modulation_shift_scale_fn(x, scale, shift):
|
||||
return (1 + scale) * x + shift
|
||||
|
||||
|
||||
def _modulation_gate_fn(x, gate, gate_params):
|
||||
return x + gate * gate_params
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float,
|
||||
qkv_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_attn = SelfAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
)
|
||||
|
||||
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_attn = SelfAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
)
|
||||
|
||||
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
# Get the device of the module (assumes all parameters are on the same device)
|
||||
return next(self.parameters()).device
|
||||
|
||||
def modulation_shift_scale_fn(self, x, scale, shift):
|
||||
return _modulation_shift_scale_fn(x, scale, shift)
|
||||
|
||||
def modulation_gate_fn(self, x, gate, gate_params):
|
||||
return _modulation_gate_fn(x, gate, gate_params)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
txt: Tensor,
|
||||
pe: Tensor,
|
||||
distill_vec: list[ModulationOut],
|
||||
mask: Tensor,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = distill_vec
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
# replaced with compiled fn
|
||||
# img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_modulated = self.modulation_shift_scale_fn(img_modulated, img_mod1.scale, img_mod1.shift)
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
# replaced with compiled fn
|
||||
# txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_modulated = self.modulation_shift_scale_fn(txt_modulated, txt_mod1.scale, txt_mod1.shift)
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
|
||||
attn = attention(q, k, v, pe=pe, mask=mask)
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
# replaced with compiled fn
|
||||
# img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
# img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
img = self.modulation_gate_fn(img, img_mod1.gate, self.img_attn.proj(img_attn))
|
||||
img = self.modulation_gate_fn(
|
||||
img,
|
||||
img_mod2.gate,
|
||||
self.img_mlp(self.modulation_shift_scale_fn(self.img_norm2(img), img_mod2.scale, img_mod2.shift)),
|
||||
)
|
||||
|
||||
# calculate the txt bloks
|
||||
# replaced with compiled fn
|
||||
# txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
# txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
txt = self.modulation_gate_fn(txt, txt_mod1.gate, self.txt_attn.proj(txt_attn))
|
||||
txt = self.modulation_gate_fn(
|
||||
txt,
|
||||
txt_mod2.gate,
|
||||
self.txt_mlp(self.modulation_shift_scale_fn(self.txt_norm2(txt), txt_mod2.scale, txt_mod2.shift)),
|
||||
)
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with parallel linear layers as described in
|
||||
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
# qkv and mlp_in
|
||||
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
||||
# proj and mlp_out
|
||||
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
||||
|
||||
self.norm = QKNorm(head_dim)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
# Get the device of the module (assumes all parameters are on the same device)
|
||||
return next(self.parameters()).device
|
||||
|
||||
def modulation_shift_scale_fn(self, x, scale, shift):
|
||||
return _modulation_shift_scale_fn(x, scale, shift)
|
||||
|
||||
def modulation_gate_fn(self, x, gate, gate_params):
|
||||
return _modulation_gate_fn(x, gate, gate_params)
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor, distill_vec: list[ModulationOut], mask: Tensor) -> Tensor:
|
||||
mod = distill_vec
|
||||
# replaced with compiled fn
|
||||
# x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
x_mod = self.modulation_shift_scale_fn(self.pre_norm(x), mod.scale, mod.shift)
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
# replaced with compiled fn
|
||||
# return x + mod.gate * output
|
||||
return self.modulation_gate_fn(x, mod.gate, output)
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
patch_size: int,
|
||||
out_channels: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
# Get the device of the module (assumes all parameters are on the same device)
|
||||
return next(self.parameters()).device
|
||||
|
||||
def modulation_shift_scale_fn(self, x, scale, shift):
|
||||
return _modulation_shift_scale_fn(x, scale, shift)
|
||||
|
||||
def forward(self, x: Tensor, distill_vec: list[Tensor]) -> Tensor:
|
||||
shift, scale = distill_vec
|
||||
shift = shift.squeeze(1)
|
||||
scale = scale.squeeze(1)
|
||||
# replaced with compiled fn
|
||||
# x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.modulation_shift_scale_fn(self.norm_final(x), scale[:, None, :], shift[:, None, :])
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChromaParams:
|
||||
in_channels: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
mlp_ratio: float
|
||||
num_heads: int
|
||||
depth: int
|
||||
depth_single_blocks: int
|
||||
axes_dim: list[int]
|
||||
theta: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
approximator_in_dim: int
|
||||
approximator_depth: int
|
||||
approximator_hidden_size: int
|
||||
_use_compiled: bool
|
||||
|
||||
|
||||
chroma_params = ChromaParams(
|
||||
in_channels=64,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
approximator_in_dim=64,
|
||||
approximator_depth=5,
|
||||
approximator_hidden_size=5120,
|
||||
_use_compiled=False,
|
||||
)
|
||||
|
||||
|
||||
def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8):
|
||||
"""
|
||||
Modifies attention mask to allow attention to a few extra padding tokens.
|
||||
|
||||
Args:
|
||||
mask: Original attention mask (1 for tokens to attend to, 0 for masked tokens)
|
||||
max_seq_length: Maximum sequence length of the model
|
||||
num_extra_padding: Number of padding tokens to unmask
|
||||
|
||||
Returns:
|
||||
Modified mask
|
||||
"""
|
||||
# Get the actual sequence length from the mask
|
||||
seq_length = mask.sum(dim=-1)
|
||||
batch_size = mask.shape[0]
|
||||
|
||||
modified_mask = mask.clone()
|
||||
|
||||
for i in range(batch_size):
|
||||
current_seq_len = int(seq_length[i].item())
|
||||
|
||||
# Only add extra padding tokens if there's room
|
||||
if current_seq_len < max_seq_length:
|
||||
# Calculate how many padding tokens we can unmask
|
||||
available_padding = max_seq_length - current_seq_len
|
||||
tokens_to_unmask = min(num_extra_padding, available_padding)
|
||||
|
||||
# Unmask the specified number of padding tokens right after the sequence
|
||||
modified_mask[i, current_seq_len : current_seq_len + tokens_to_unmask] = 1
|
||||
|
||||
return modified_mask
|
||||
|
||||
|
||||
class Chroma(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, params: ChromaParams):
|
||||
super().__init__()
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
|
||||
# TODO: need proper mapping for this approximator output!
|
||||
# currently the mapping is hardcoded in distribute_modulations function
|
||||
self.distilled_guidance_layer = Approximator(
|
||||
params.approximator_in_dim,
|
||||
self.hidden_size,
|
||||
params.approximator_hidden_size,
|
||||
params.approximator_depth,
|
||||
)
|
||||
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = LastLayer(
|
||||
self.hidden_size,
|
||||
1,
|
||||
self.out_channels,
|
||||
)
|
||||
|
||||
# TODO: move this hardcoded value to config
|
||||
# single layer has 3 modulation vectors
|
||||
# double layer has 6 modulation vectors for each expert
|
||||
# final layer has 2 modulation vectors
|
||||
self.mod_index_length = 3 * params.depth_single_blocks + 2 * 6 * params.depth + 2
|
||||
self.depth_single_blocks = params.depth_single_blocks
|
||||
self.depth_double_blocks = params.depth
|
||||
# self.mod_index = torch.tensor(list(range(self.mod_index_length)), device=0)
|
||||
self.register_buffer(
|
||||
"mod_index",
|
||||
torch.tensor(list(range(self.mod_index_length)), device="cpu"),
|
||||
persistent=False,
|
||||
)
|
||||
self.approximator_in_dim = params.approximator_in_dim
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
# Get the device of the module (assumes all parameters are on the same device)
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
txt_mask: Tensor,
|
||||
timesteps: Tensor,
|
||||
guidance: Tensor,
|
||||
attn_padding: int = 1,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
# TODO:
|
||||
# need to fix grad accumulation issue here for now it's in no grad mode
|
||||
# besides, i don't want to wash out the PFP that's trained on this model weights anyway
|
||||
# the fan out operation here is deleting the backward graph
|
||||
# alternatively doing forward pass for every block manually is doable but slow
|
||||
# custom backward probably be better
|
||||
with torch.no_grad():
|
||||
distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4)
|
||||
# TODO: need to add toggle to omit this from schnell but that's not a priority
|
||||
distil_guidance = timestep_embedding(guidance, self.approximator_in_dim // 4)
|
||||
# get all modulation index
|
||||
modulation_index = timestep_embedding(self.mod_index, self.approximator_in_dim // 2)
|
||||
# we need to broadcast the modulation index here so each batch has all of the index
|
||||
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1)
|
||||
# and we need to broadcast timestep and guidance along too
|
||||
timestep_guidance = (
|
||||
torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, self.mod_index_length, 1)
|
||||
)
|
||||
# then and only then we could concatenate it together
|
||||
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)
|
||||
mod_vectors = self.distilled_guidance_layer(input_vec.requires_grad_(True))
|
||||
mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
# compute mask
|
||||
# assume max seq length from the batched input
|
||||
|
||||
max_len = txt.shape[1]
|
||||
|
||||
# mask
|
||||
with torch.no_grad():
|
||||
txt_mask_w_padding = modify_mask_to_attend_padding(txt_mask, max_len, attn_padding)
|
||||
txt_img_mask = torch.cat(
|
||||
[
|
||||
txt_mask_w_padding,
|
||||
torch.ones([img.shape[0], img.shape[1]], device=txt_mask.device),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
txt_img_mask = txt_img_mask.float().T @ txt_img_mask.float()
|
||||
txt_img_mask = txt_img_mask[None, None, ...].repeat(txt.shape[0], self.num_heads, 1, 1).int().bool()
|
||||
# txt_mask_w_padding[txt_mask_w_padding==False] = True
|
||||
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
# the guidance replaced by FFN output
|
||||
img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"]
|
||||
txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"]
|
||||
double_mod = [img_mod, txt_mod]
|
||||
|
||||
# just in case in different GPU for simple pipeline parallel
|
||||
if self.training:
|
||||
img, txt = ckpt.checkpoint(block, img, txt, pe, double_mod, txt_img_mask)
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, pe=pe, distill_vec=double_mod, mask=txt_img_mask)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
|
||||
if self.training:
|
||||
img = ckpt.checkpoint(block, img, pe, single_mod, txt_img_mask)
|
||||
else:
|
||||
img = block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask)
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"]
|
||||
img = self.final_layer(img, distill_vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
@@ -1,6 +1,6 @@
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import Optional, Union, Callable, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -19,7 +19,7 @@ def synchronize_device(device: torch.device):
|
||||
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
||||
|
||||
weight_swap_jobs = []
|
||||
weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
|
||||
|
||||
# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
|
||||
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
||||
@@ -42,7 +42,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
|
||||
|
||||
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
stream = torch.Stream(device="cuda")
|
||||
with torch.cuda.stream(stream):
|
||||
# cuda to cpu
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
@@ -66,23 +66,24 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
|
||||
"""
|
||||
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
||||
|
||||
weight_swap_jobs = []
|
||||
weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
|
||||
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
||||
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
||||
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
||||
|
||||
|
||||
# device to cpu
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||||
|
||||
synchronize_device()
|
||||
synchronize_device(device)
|
||||
|
||||
# cpu to device
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||
module_to_cuda.weight.data = cuda_data_view
|
||||
|
||||
synchronize_device()
|
||||
synchronize_device(device)
|
||||
|
||||
|
||||
def weighs_to_device(layer: nn.Module, device: torch.device):
|
||||
@@ -148,13 +149,16 @@ class Offloader:
|
||||
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
|
||||
|
||||
|
||||
# Gradient tensors
|
||||
_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]
|
||||
|
||||
class ModelOffloader(Offloader):
|
||||
"""
|
||||
supports forward offloading
|
||||
"""
|
||||
|
||||
def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
|
||||
super().__init__(num_blocks, blocks_to_swap, device, debug)
|
||||
def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False):
|
||||
super().__init__(len(blocks), blocks_to_swap, device, debug)
|
||||
|
||||
# register backward hooks
|
||||
self.remove_handles = []
|
||||
@@ -168,7 +172,7 @@ class ModelOffloader(Offloader):
|
||||
for handle in self.remove_handles:
|
||||
handle.remove()
|
||||
|
||||
def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
|
||||
def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
|
||||
# -1 for 0-based index
|
||||
num_blocks_propagated = self.num_blocks - block_index - 1
|
||||
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
|
||||
@@ -182,7 +186,7 @@ class ModelOffloader(Offloader):
|
||||
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
|
||||
block_idx_to_wait = block_index - 1
|
||||
|
||||
def backward_hook(module, grad_input, grad_output):
|
||||
def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t):
|
||||
if self.debug:
|
||||
print(f"Backward hook for block {block_index}")
|
||||
|
||||
@@ -194,7 +198,7 @@ class ModelOffloader(Offloader):
|
||||
|
||||
return backward_hook
|
||||
|
||||
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
|
||||
def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn.ModuleList]):
|
||||
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||
return
|
||||
|
||||
@@ -207,7 +211,7 @@ class ModelOffloader(Offloader):
|
||||
|
||||
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
|
||||
b.to(self.device) # move block to device first
|
||||
weighs_to_device(b, "cpu") # make sure weights are on cpu
|
||||
weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu
|
||||
|
||||
synchronize_device(self.device)
|
||||
clean_memory_on_device(self.device)
|
||||
@@ -217,7 +221,7 @@ class ModelOffloader(Offloader):
|
||||
return
|
||||
self._wait_blocks_move(block_idx)
|
||||
|
||||
def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
|
||||
def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int):
|
||||
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||
return
|
||||
if block_idx >= self.blocks_to_swap:
|
||||
|
||||
@@ -5,6 +5,8 @@ from accelerate import DeepSpeedPlugin, Accelerator
|
||||
|
||||
from .utils import setup_logging
|
||||
|
||||
from .device_utils import get_preferred_device
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
@@ -94,6 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
|
||||
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
|
||||
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
|
||||
)
|
||||
|
||||
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
|
||||
if args.mixed_precision.lower() == "fp16":
|
||||
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
|
||||
@@ -122,18 +125,56 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
||||
class DeepSpeedWrapper(torch.nn.Module):
|
||||
def __init__(self, **kw_models) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.models = torch.nn.ModuleDict()
|
||||
|
||||
wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no"
|
||||
|
||||
for key, model in kw_models.items():
|
||||
if isinstance(model, list):
|
||||
model = torch.nn.ModuleList(model)
|
||||
|
||||
if wrap_model_forward_with_torch_autocast:
|
||||
model = self.__wrap_model_with_torch_autocast(model)
|
||||
|
||||
assert isinstance(
|
||||
model, torch.nn.Module
|
||||
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
|
||||
|
||||
self.models.update(torch.nn.ModuleDict({key: model}))
|
||||
|
||||
def __wrap_model_with_torch_autocast(self, model):
|
||||
if isinstance(model, torch.nn.ModuleList):
|
||||
model = torch.nn.ModuleList([self.__wrap_model_forward_with_torch_autocast(m) for m in model])
|
||||
else:
|
||||
model = self.__wrap_model_forward_with_torch_autocast(model)
|
||||
return model
|
||||
|
||||
def __wrap_model_forward_with_torch_autocast(self, model):
|
||||
|
||||
assert hasattr(model, "forward"), f"model must have a forward method."
|
||||
|
||||
forward_fn = model.forward
|
||||
|
||||
def forward(*args, **kwargs):
|
||||
try:
|
||||
device_type = model.device.type
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
"[DeepSpeed] model.device is not available. Using get_preferred_device() "
|
||||
"to determine the device_type for torch.autocast()."
|
||||
)
|
||||
device_type = get_preferred_device().type
|
||||
|
||||
with torch.autocast(device_type = device_type):
|
||||
return forward_fn(*args, **kwargs)
|
||||
|
||||
model.forward = forward
|
||||
return model
|
||||
|
||||
def get_models(self):
|
||||
return self.models
|
||||
|
||||
|
||||
ds_model = DeepSpeedWrapper(**models)
|
||||
return ds_model
|
||||
|
||||
@@ -977,10 +977,10 @@ class Flux(nn.Module):
|
||||
)
|
||||
|
||||
self.offloader_double = custom_offloading_utils.ModelOffloader(
|
||||
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
|
||||
self.double_blocks, double_blocks_to_swap, device # , debug=True
|
||||
)
|
||||
self.offloader_single = custom_offloading_utils.ModelOffloader(
|
||||
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
|
||||
self.single_blocks, single_blocks_to_swap, device # , debug=True
|
||||
)
|
||||
print(
|
||||
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
|
||||
@@ -1219,10 +1219,10 @@ class ControlNetFlux(nn.Module):
|
||||
)
|
||||
|
||||
self.offloader_double = custom_offloading_utils.ModelOffloader(
|
||||
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
|
||||
self.double_blocks, double_blocks_to_swap, device # , debug=True
|
||||
)
|
||||
self.offloader_single = custom_offloading_utils.ModelOffloader(
|
||||
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
|
||||
self.single_blocks, single_blocks_to_swap, device # , debug=True
|
||||
)
|
||||
print(
|
||||
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
|
||||
@@ -1233,8 +1233,8 @@ class ControlNetFlux(nn.Module):
|
||||
if self.blocks_to_swap:
|
||||
save_double_blocks = self.double_blocks
|
||||
save_single_blocks = self.single_blocks
|
||||
self.double_blocks = None
|
||||
self.single_blocks = None
|
||||
self.double_blocks = nn.ModuleList()
|
||||
self.single_blocks = nn.ModuleList()
|
||||
|
||||
self.to(device)
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ def sample_images(
|
||||
text_encoders,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement=None,
|
||||
controlnet=None
|
||||
controlnet=None,
|
||||
):
|
||||
if steps == 0:
|
||||
if not args.sample_at_first:
|
||||
@@ -67,7 +67,7 @@ def sample_images(
|
||||
# unwrap unet and text_encoder(s)
|
||||
flux = accelerator.unwrap_model(flux)
|
||||
if text_encoders is not None:
|
||||
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
|
||||
text_encoders = [(accelerator.unwrap_model(te) if te is not None else None) for te in text_encoders]
|
||||
if controlnet is not None:
|
||||
controlnet = accelerator.unwrap_model(controlnet)
|
||||
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
|
||||
@@ -101,7 +101,7 @@ def sample_images(
|
||||
steps,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement,
|
||||
controlnet
|
||||
controlnet,
|
||||
)
|
||||
else:
|
||||
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
||||
@@ -125,7 +125,7 @@ def sample_images(
|
||||
steps,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement,
|
||||
controlnet
|
||||
controlnet,
|
||||
)
|
||||
|
||||
torch.set_rng_state(rng_state)
|
||||
@@ -147,14 +147,16 @@ def sample_image_inference(
|
||||
steps,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement,
|
||||
controlnet
|
||||
controlnet,
|
||||
):
|
||||
assert isinstance(prompt_dict, dict)
|
||||
# negative_prompt = prompt_dict.get("negative_prompt")
|
||||
negative_prompt = prompt_dict.get("negative_prompt")
|
||||
sample_steps = prompt_dict.get("sample_steps", 20)
|
||||
width = prompt_dict.get("width", 512)
|
||||
height = prompt_dict.get("height", 512)
|
||||
scale = prompt_dict.get("scale", 3.5)
|
||||
# TODO refactor variable names
|
||||
cfg_scale = prompt_dict.get("guidance_scale", 1.0)
|
||||
emb_guidance_scale = prompt_dict.get("scale", 3.5)
|
||||
seed = prompt_dict.get("seed")
|
||||
controlnet_image = prompt_dict.get("controlnet_image")
|
||||
prompt: str = prompt_dict.get("prompt", "")
|
||||
@@ -162,8 +164,8 @@ def sample_image_inference(
|
||||
|
||||
if prompt_replacement is not None:
|
||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
# if negative_prompt is not None:
|
||||
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
if negative_prompt is not None:
|
||||
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
@@ -173,16 +175,21 @@ def sample_image_inference(
|
||||
torch.seed()
|
||||
torch.cuda.seed()
|
||||
|
||||
# if negative_prompt is None:
|
||||
# negative_prompt = ""
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ""
|
||||
height = max(64, height - height % 16) # round to divisible by 16
|
||||
width = max(64, width - width % 16) # round to divisible by 16
|
||||
logger.info(f"prompt: {prompt}")
|
||||
# logger.info(f"negative_prompt: {negative_prompt}")
|
||||
if cfg_scale != 1.0:
|
||||
logger.info(f"negative_prompt: {negative_prompt}")
|
||||
elif negative_prompt != "":
|
||||
logger.info(f"negative prompt is ignored because scale is 1.0")
|
||||
logger.info(f"height: {height}")
|
||||
logger.info(f"width: {width}")
|
||||
logger.info(f"sample_steps: {sample_steps}")
|
||||
logger.info(f"scale: {scale}")
|
||||
logger.info(f"embedded guidance scale: {emb_guidance_scale}")
|
||||
if cfg_scale != 1.0:
|
||||
logger.info(f"CFG scale: {cfg_scale}")
|
||||
# logger.info(f"sample_sampler: {sampler_name}")
|
||||
if seed is not None:
|
||||
logger.info(f"seed: {seed}")
|
||||
@@ -191,26 +198,37 @@ def sample_image_inference(
|
||||
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
|
||||
text_encoder_conds = []
|
||||
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
|
||||
text_encoder_conds = sample_prompts_te_outputs[prompt]
|
||||
print(f"Using cached text encoder outputs for prompt: {prompt}")
|
||||
if text_encoders is not None:
|
||||
print(f"Encoding prompt: {prompt}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||
# strategy has apply_t5_attn_mask option
|
||||
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
||||
def encode_prompt(prpt):
|
||||
text_encoder_conds = []
|
||||
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
|
||||
text_encoder_conds = sample_prompts_te_outputs[prpt]
|
||||
print(f"Using cached text encoder outputs for prompt: {prpt}")
|
||||
if text_encoders is not None:
|
||||
print(f"Encoding prompt: {prpt}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prpt)
|
||||
# strategy has apply_t5_attn_mask option
|
||||
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
||||
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
text_encoder_conds = encoded_text_encoder_conds
|
||||
else:
|
||||
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
||||
for i in range(len(encoded_text_encoder_conds)):
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
text_encoder_conds = encoded_text_encoder_conds
|
||||
else:
|
||||
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
||||
for i in range(len(encoded_text_encoder_conds)):
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
return text_encoder_conds
|
||||
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = encode_prompt(prompt)
|
||||
# encode negative prompts
|
||||
if cfg_scale != 1.0:
|
||||
neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode_prompt(negative_prompt)
|
||||
neg_t5_attn_mask = (
|
||||
neg_t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask and neg_t5_attn_mask is not None else None
|
||||
)
|
||||
neg_cond = (cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask)
|
||||
else:
|
||||
neg_cond = None
|
||||
|
||||
# sample image
|
||||
weight_dtype = ae.dtype # TOFO give dtype as argument
|
||||
@@ -235,7 +253,20 @@ def sample_image_inference(
|
||||
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
|
||||
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
|
||||
x = denoise(
|
||||
flux,
|
||||
noise,
|
||||
img_ids,
|
||||
t5_out,
|
||||
txt_ids,
|
||||
l_pooled,
|
||||
timesteps=timesteps,
|
||||
guidance=emb_guidance_scale,
|
||||
t5_attn_mask=t5_attn_mask,
|
||||
controlnet=controlnet,
|
||||
controlnet_img=controlnet_image,
|
||||
neg_cond=neg_cond,
|
||||
)
|
||||
|
||||
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
||||
|
||||
@@ -305,22 +336,24 @@ def denoise(
|
||||
model: flux_models.Flux,
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
txt: torch.Tensor, # t5_out
|
||||
txt_ids: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
vec: torch.Tensor, # l_pooled
|
||||
timesteps: list[float],
|
||||
guidance: float = 4.0,
|
||||
t5_attn_mask: Optional[torch.Tensor] = None,
|
||||
controlnet: Optional[flux_models.ControlNetFlux] = None,
|
||||
controlnet_img: Optional[torch.Tensor] = None,
|
||||
neg_cond: Optional[Tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]] = None,
|
||||
):
|
||||
# this is ignored for schnell
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
|
||||
do_cfg = neg_cond is not None
|
||||
|
||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
model.prepare_block_swap_before_forward()
|
||||
|
||||
if controlnet is not None:
|
||||
block_samples, block_single_samples = controlnet(
|
||||
img=img,
|
||||
@@ -336,20 +369,48 @@ def denoise(
|
||||
else:
|
||||
block_samples = None
|
||||
block_single_samples = None
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
block_controlnet_hidden_states=block_samples,
|
||||
block_controlnet_single_hidden_states=block_single_samples,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
if not do_cfg:
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
block_controlnet_hidden_states=block_samples,
|
||||
block_controlnet_single_hidden_states=block_single_samples,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
else:
|
||||
cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask = neg_cond
|
||||
nc_c_t5_attn_mask = None if t5_attn_mask is None else torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
|
||||
|
||||
# TODO is it ok to use the same block samples for both cond and uncond?
|
||||
block_samples = None if block_samples is None else torch.cat([block_samples, block_samples], dim=0)
|
||||
block_single_samples = (
|
||||
None if block_single_samples is None else torch.cat([block_single_samples, block_single_samples], dim=0)
|
||||
)
|
||||
|
||||
nc_c_pred = model(
|
||||
img=torch.cat([img, img], dim=0),
|
||||
img_ids=torch.cat([img_ids, img_ids], dim=0),
|
||||
txt=torch.cat([neg_t5_out, txt], dim=0),
|
||||
txt_ids=torch.cat([txt_ids, txt_ids], dim=0),
|
||||
y=torch.cat([neg_l_pooled, vec], dim=0),
|
||||
block_controlnet_hidden_states=block_samples,
|
||||
block_controlnet_single_hidden_states=block_single_samples,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=nc_c_t5_attn_mask,
|
||||
)
|
||||
neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0)
|
||||
pred = neg_pred + (pred - neg_pred) * cfg_scale
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
model.prepare_block_swap_before_forward()
|
||||
return img
|
||||
@@ -433,7 +494,7 @@ def get_noisy_model_input_and_timesteps(
|
||||
sigmas = torch.randn(bsz, device=device)
|
||||
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigmas = sigmas.sigmoid()
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
|
||||
sigmas = time_shift(mu, 1.0, sigmas)
|
||||
timesteps = sigmas * num_timesteps
|
||||
else:
|
||||
@@ -458,7 +519,7 @@ def get_noisy_model_input_and_timesteps(
|
||||
if args.ip_noise_gamma:
|
||||
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
|
||||
if args.ip_noise_gamma_random_strength:
|
||||
ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma)
|
||||
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
|
||||
else:
|
||||
ip_noise_gamma = args.ip_noise_gamma
|
||||
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
|
||||
@@ -569,7 +630,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
||||
"--controlnet_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)"
|
||||
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--t5xxl_max_token_length",
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import os
|
||||
import sys
|
||||
import contextlib
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
legacy = True
|
||||
has_ipex = True
|
||||
except Exception:
|
||||
legacy = False
|
||||
has_ipex = False
|
||||
from .hijacks import ipex_hijacks
|
||||
|
||||
torch_version = float(torch.__version__[:3])
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
def ipex_init(): # pylint: disable=too-many-statements
|
||||
@@ -16,7 +17,10 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
|
||||
return True, "Skipping IPEX hijack"
|
||||
else:
|
||||
try: # force xpu device on torch compile and triton
|
||||
try:
|
||||
# force xpu device on torch compile and triton
|
||||
# import inductor utils to get around lazy import
|
||||
from torch._inductor import utils as torch_inductor_utils # pylint: disable=import-error, unused-import # noqa: F401
|
||||
torch._inductor.utils.GPU_TYPES = ["xpu"]
|
||||
torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu"
|
||||
from triton import backends as triton_backends # pylint: disable=import-error
|
||||
@@ -35,7 +39,6 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.is_available = torch.xpu.is_available
|
||||
torch.cuda.is_initialized = torch.xpu.is_initialized
|
||||
torch.cuda.is_current_stream_capturing = lambda: False
|
||||
torch.cuda.set_device = torch.xpu.set_device
|
||||
torch.cuda.stream = torch.xpu.stream
|
||||
torch.cuda.Event = torch.xpu.Event
|
||||
torch.cuda.Stream = torch.xpu.Stream
|
||||
@@ -45,7 +48,6 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.Optional = torch.xpu.Optional
|
||||
torch.cuda.__cached__ = torch.xpu.__cached__
|
||||
torch.cuda.__loader__ = torch.xpu.__loader__
|
||||
torch.cuda.Tuple = torch.xpu.Tuple
|
||||
torch.cuda.streams = torch.xpu.streams
|
||||
torch.cuda.Any = torch.xpu.Any
|
||||
torch.cuda.__doc__ = torch.xpu.__doc__
|
||||
@@ -58,7 +60,6 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.__annotations__ = torch.xpu.__annotations__
|
||||
torch.cuda.__package__ = torch.xpu.__package__
|
||||
torch.cuda.__builtins__ = torch.xpu.__builtins__
|
||||
torch.cuda.List = torch.xpu.List
|
||||
torch.cuda._lazy_init = torch.xpu._lazy_init
|
||||
torch.cuda.StreamContext = torch.xpu.StreamContext
|
||||
torch.cuda._lazy_call = torch.xpu._lazy_call
|
||||
@@ -70,47 +71,40 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.__file__ = torch.xpu.__file__
|
||||
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
||||
|
||||
if legacy:
|
||||
torch.cuda.os = torch.xpu.os
|
||||
torch.cuda.Device = torch.xpu.Device
|
||||
torch.cuda.warnings = torch.xpu.warnings
|
||||
torch.cuda.classproperty = torch.xpu.classproperty
|
||||
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
||||
if float(ipex.__version__[:3]) < 2.3:
|
||||
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
||||
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
||||
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
||||
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
|
||||
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
|
||||
torch.cuda._tls = torch.xpu.lazy_init._tls
|
||||
torch.cuda.threading = torch.xpu.lazy_init.threading
|
||||
torch.cuda.traceback = torch.xpu.lazy_init.traceback
|
||||
torch.cuda._lazy_new = torch.xpu._lazy_new
|
||||
if torch_version < 2.3:
|
||||
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
||||
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
||||
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
||||
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
|
||||
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
|
||||
torch.cuda._tls = torch.xpu.lazy_init._tls
|
||||
torch.cuda.threading = torch.xpu.lazy_init.threading
|
||||
torch.cuda.traceback = torch.xpu.lazy_init.traceback
|
||||
torch.cuda._lazy_new = torch.xpu._lazy_new
|
||||
|
||||
torch.cuda.FloatTensor = torch.xpu.FloatTensor
|
||||
torch.cuda.FloatStorage = torch.xpu.FloatStorage
|
||||
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
|
||||
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
|
||||
torch.cuda.HalfTensor = torch.xpu.HalfTensor
|
||||
torch.cuda.HalfStorage = torch.xpu.HalfStorage
|
||||
torch.cuda.ByteTensor = torch.xpu.ByteTensor
|
||||
torch.cuda.ByteStorage = torch.xpu.ByteStorage
|
||||
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
|
||||
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
|
||||
torch.cuda.ShortTensor = torch.xpu.ShortTensor
|
||||
torch.cuda.ShortStorage = torch.xpu.ShortStorage
|
||||
torch.cuda.LongTensor = torch.xpu.LongTensor
|
||||
torch.cuda.LongStorage = torch.xpu.LongStorage
|
||||
torch.cuda.IntTensor = torch.xpu.IntTensor
|
||||
torch.cuda.IntStorage = torch.xpu.IntStorage
|
||||
torch.cuda.CharTensor = torch.xpu.CharTensor
|
||||
torch.cuda.CharStorage = torch.xpu.CharStorage
|
||||
torch.cuda.BoolTensor = torch.xpu.BoolTensor
|
||||
torch.cuda.BoolStorage = torch.xpu.BoolStorage
|
||||
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
|
||||
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
|
||||
|
||||
if not legacy or float(ipex.__version__[:3]) >= 2.3:
|
||||
torch.cuda.FloatTensor = torch.xpu.FloatTensor
|
||||
torch.cuda.FloatStorage = torch.xpu.FloatStorage
|
||||
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
|
||||
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
|
||||
torch.cuda.HalfTensor = torch.xpu.HalfTensor
|
||||
torch.cuda.HalfStorage = torch.xpu.HalfStorage
|
||||
torch.cuda.ByteTensor = torch.xpu.ByteTensor
|
||||
torch.cuda.ByteStorage = torch.xpu.ByteStorage
|
||||
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
|
||||
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
|
||||
torch.cuda.ShortTensor = torch.xpu.ShortTensor
|
||||
torch.cuda.ShortStorage = torch.xpu.ShortStorage
|
||||
torch.cuda.LongTensor = torch.xpu.LongTensor
|
||||
torch.cuda.LongStorage = torch.xpu.LongStorage
|
||||
torch.cuda.IntTensor = torch.xpu.IntTensor
|
||||
torch.cuda.IntStorage = torch.xpu.IntStorage
|
||||
torch.cuda.CharTensor = torch.xpu.CharTensor
|
||||
torch.cuda.CharStorage = torch.xpu.CharStorage
|
||||
torch.cuda.BoolTensor = torch.xpu.BoolTensor
|
||||
torch.cuda.BoolStorage = torch.xpu.BoolStorage
|
||||
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
|
||||
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
|
||||
else:
|
||||
torch.cuda._initialization_lock = torch.xpu._initialization_lock
|
||||
torch.cuda._initialized = torch.xpu._initialized
|
||||
torch.cuda._is_in_bad_fork = torch.xpu._is_in_bad_fork
|
||||
@@ -120,12 +114,24 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.threading = torch.xpu.threading
|
||||
torch.cuda.traceback = torch.xpu.traceback
|
||||
|
||||
if torch_version < 2.5:
|
||||
torch.cuda.os = torch.xpu.os
|
||||
torch.cuda.Device = torch.xpu.Device
|
||||
torch.cuda.warnings = torch.xpu.warnings
|
||||
torch.cuda.classproperty = torch.xpu.classproperty
|
||||
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
||||
|
||||
if torch_version < 2.7:
|
||||
torch.cuda.Tuple = torch.xpu.Tuple
|
||||
torch.cuda.List = torch.xpu.List
|
||||
|
||||
|
||||
# Memory:
|
||||
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
|
||||
torch.xpu.empty_cache = lambda: None
|
||||
torch.cuda.empty_cache = torch.xpu.empty_cache
|
||||
|
||||
if legacy:
|
||||
if has_ipex:
|
||||
torch.cuda.memory_summary = torch.xpu.memory_summary
|
||||
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
|
||||
torch.cuda.memory = torch.xpu.memory
|
||||
@@ -153,40 +159,19 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.seed_all = torch.xpu.seed_all
|
||||
torch.cuda.initial_seed = torch.xpu.initial_seed
|
||||
|
||||
# AMP:
|
||||
if legacy:
|
||||
torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd
|
||||
torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd
|
||||
torch.cuda.amp = torch.xpu.amp
|
||||
if float(ipex.__version__[:3]) < 2.3:
|
||||
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
|
||||
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
|
||||
|
||||
if not hasattr(torch.cuda.amp, "common"):
|
||||
torch.cuda.amp.common = contextlib.nullcontext()
|
||||
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
|
||||
|
||||
try:
|
||||
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
try:
|
||||
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
|
||||
gradscaler_init()
|
||||
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
||||
|
||||
# C
|
||||
if legacy and float(ipex.__version__[:3]) < 2.3:
|
||||
if torch_version < 2.3:
|
||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
|
||||
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
|
||||
ipex._C._DeviceProperties.major = 12
|
||||
ipex._C._DeviceProperties.minor = 1
|
||||
ipex._C._DeviceProperties.L2_cache_size = 16*1024*1024 # A770 and A750
|
||||
else:
|
||||
torch._C._cuda_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
|
||||
torch._C._XpuDeviceProperties.multi_processor_count = torch._C._XpuDeviceProperties.gpu_subslice_count
|
||||
torch._C._XpuDeviceProperties.major = 12
|
||||
torch._C._XpuDeviceProperties.minor = 1
|
||||
torch._C._XpuDeviceProperties.L2_cache_size = 16*1024*1024 # A770 and A750
|
||||
|
||||
# Fix functions with ipex:
|
||||
# torch.xpu.mem_get_info always returns the total memory as free memory
|
||||
@@ -195,21 +180,22 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch._utils._get_available_device_type = lambda: "xpu"
|
||||
torch.has_cuda = True
|
||||
torch.cuda.has_half = True
|
||||
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
|
||||
torch.cuda.is_bf16_supported = getattr(torch.xpu, "is_bf16_supported", lambda *args, **kwargs: True)
|
||||
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
|
||||
torch.backends.cuda.is_built = lambda *args, **kwargs: True
|
||||
torch.version.cuda = "12.1"
|
||||
torch.cuda.get_arch_list = lambda: ["ats-m150", "pvc"]
|
||||
torch.cuda.get_arch_list = getattr(torch.xpu, "get_arch_list", lambda: ["pvc", "dg2", "ats-m150"])
|
||||
torch.cuda.get_device_capability = lambda *args, **kwargs: (12,1)
|
||||
torch.cuda.get_device_properties.major = 12
|
||||
torch.cuda.get_device_properties.minor = 1
|
||||
torch.cuda.get_device_properties.L2_cache_size = 16*1024*1024 # A770 and A750
|
||||
torch.cuda.ipc_collect = lambda *args, **kwargs: None
|
||||
torch.cuda.utilization = lambda *args, **kwargs: 0
|
||||
|
||||
device_supports_fp64, can_allocate_plus_4gb = ipex_hijacks(legacy=legacy)
|
||||
device_supports_fp64 = ipex_hijacks()
|
||||
try:
|
||||
from .diffusers import ipex_diffusers
|
||||
ipex_diffusers(device_supports_fp64=device_supports_fp64, can_allocate_plus_4gb=can_allocate_plus_4gb)
|
||||
ipex_diffusers(device_supports_fp64=device_supports_fp64)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
pass
|
||||
torch.cuda.is_xpu_hijacked = True
|
||||
|
||||
@@ -61,13 +61,13 @@ def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, drop
|
||||
if query.device.type != "xpu":
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
||||
is_unsqueezed = False
|
||||
if len(query.shape) == 3:
|
||||
if query.dim() == 3:
|
||||
query = query.unsqueeze(0)
|
||||
is_unsqueezed = True
|
||||
if len(key.shape) == 3:
|
||||
key = key.unsqueeze(0)
|
||||
if len(value.shape) == 3:
|
||||
value = value.unsqueeze(0)
|
||||
if key.dim() == 3:
|
||||
key = key.unsqueeze(0)
|
||||
if value.dim() == 3:
|
||||
value = value.unsqueeze(0)
|
||||
do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size = find_sdpa_slice_sizes(query.shape, key.shape, query.element_size(), slice_rate=attention_slice_rate, trigger_rate=sdpa_slice_trigger_rate)
|
||||
|
||||
# Slice SDPA
|
||||
@@ -115,5 +115,5 @@ def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, drop
|
||||
else:
|
||||
hidden_states = original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
||||
if is_unsqueezed:
|
||||
hidden_states.squeeze(0)
|
||||
hidden_states = hidden_states.squeeze(0)
|
||||
return hidden_states
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from functools import wraps
|
||||
import torch
|
||||
import diffusers # pylint: disable=import-error
|
||||
from diffusers.utils import torch_utils # pylint: disable=import-error, unused-import # noqa: F401
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
|
||||
# Diffusers FreeU
|
||||
# Diffusers is imported before ipex hijacks so fourier_filter needs hijacking too
|
||||
original_fourier_filter = diffusers.utils.torch_utils.fourier_filter
|
||||
@wraps(diffusers.utils.torch_utils.fourier_filter)
|
||||
def fourier_filter(x_in, threshold, scale):
|
||||
@@ -41,7 +43,84 @@ class FluxPosEmbed(torch.nn.Module):
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
def ipex_diffusers(device_supports_fp64=False, can_allocate_plus_4gb=False):
|
||||
def hidream_rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0, "The dimension must be even."
|
||||
return_device = pos.device
|
||||
pos = pos.to("cpu")
|
||||
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
|
||||
batch_size, seq_length = pos.shape
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
cos_out = torch.cos(out)
|
||||
sin_out = torch.sin(out)
|
||||
|
||||
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
|
||||
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
|
||||
return out.to(return_device, dtype=torch.float32)
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
|
||||
if output_type == "np":
|
||||
return diffusers.models.embeddings.get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
|
||||
if embed_dim % 2 != 0:
|
||||
raise ValueError("embed_dim must be divisible by 2")
|
||||
|
||||
omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float32)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = torch.outer(pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = torch.sin(out) # (M, D/2)
|
||||
emb_cos = torch.cos(out) # (M, D/2)
|
||||
|
||||
emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
def apply_rotary_emb(x, freqs_cis, use_real: bool = True, use_real_unbind_dim: int = -1):
|
||||
if use_real:
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
|
||||
if use_real_unbind_dim == -1:
|
||||
# Used for flux, cogvideox, hunyuan-dit
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
elif use_real_unbind_dim == -2:
|
||||
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
||||
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
||||
else:
|
||||
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
||||
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
return out
|
||||
else:
|
||||
# used for lumina
|
||||
# force cpu with Alchemist
|
||||
x_rotated = torch.view_as_complex(x.to("cpu").float().reshape(*x.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.to("cpu").unsqueeze(2)
|
||||
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
||||
return x_out.type_as(x).to(x.device)
|
||||
|
||||
|
||||
def ipex_diffusers(device_supports_fp64=False):
|
||||
diffusers.utils.torch_utils.fourier_filter = fourier_filter
|
||||
if not device_supports_fp64:
|
||||
# get around lazy imports
|
||||
from diffusers.models import embeddings as diffusers_embeddings # pylint: disable=import-error, unused-import # noqa: F401
|
||||
from diffusers.models import transformers as diffusers_transformers # pylint: disable=import-error, unused-import # noqa: F401
|
||||
from diffusers.models import controlnets as diffusers_controlnets # pylint: disable=import-error, unused-import # noqa: F401
|
||||
diffusers.models.embeddings.get_1d_sincos_pos_embed_from_grid = get_1d_sincos_pos_embed_from_grid
|
||||
diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed
|
||||
diffusers.models.embeddings.apply_rotary_emb = apply_rotary_emb
|
||||
diffusers.models.transformers.transformer_flux.FluxPosEmbed = FluxPosEmbed
|
||||
diffusers.models.transformers.transformer_lumina2.apply_rotary_emb = apply_rotary_emb
|
||||
diffusers.models.controlnets.controlnet_flux.FluxPosEmbed = FluxPosEmbed
|
||||
diffusers.models.transformers.transformer_hidream_image.rope = hidream_rope
|
||||
|
||||
@@ -1,183 +0,0 @@
|
||||
from collections import defaultdict
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
|
||||
OptState = ipex.cpu.autocast._grad_scaler.OptState
|
||||
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
|
||||
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
|
||||
|
||||
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
|
||||
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
|
||||
per_device_found_inf = _MultiDeviceReplicator(found_inf)
|
||||
|
||||
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
|
||||
# There could be hundreds of grads, so we'd like to iterate through them just once.
|
||||
# However, we don't know their devices or dtypes in advance.
|
||||
|
||||
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
|
||||
# Google says mypy struggles with defaultdicts type annotations.
|
||||
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
|
||||
# sync grad to master weight
|
||||
if hasattr(optimizer, "sync_grad"):
|
||||
optimizer.sync_grad()
|
||||
with torch.no_grad():
|
||||
for group in optimizer.param_groups:
|
||||
for param in group["params"]:
|
||||
if param.grad is None:
|
||||
continue
|
||||
if (not allow_fp16) and param.grad.dtype == torch.float16:
|
||||
raise ValueError("Attempting to unscale FP16 gradients.")
|
||||
if param.grad.is_sparse:
|
||||
# is_coalesced() == False means the sparse grad has values with duplicate indices.
|
||||
# coalesce() deduplicates indices and adds all values that have the same index.
|
||||
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
|
||||
# so we should check the coalesced _values().
|
||||
if param.grad.dtype is torch.float16:
|
||||
param.grad = param.grad.coalesce()
|
||||
to_unscale = param.grad._values()
|
||||
else:
|
||||
to_unscale = param.grad
|
||||
|
||||
# -: is there a way to split by device and dtype without appending in the inner loop?
|
||||
to_unscale = to_unscale.to("cpu")
|
||||
per_device_and_dtype_grads[to_unscale.device][
|
||||
to_unscale.dtype
|
||||
].append(to_unscale)
|
||||
|
||||
for _, per_dtype_grads in per_device_and_dtype_grads.items():
|
||||
for grads in per_dtype_grads.values():
|
||||
core._amp_foreach_non_finite_check_and_unscale_(
|
||||
grads,
|
||||
per_device_found_inf.get("cpu"),
|
||||
per_device_inv_scale.get("cpu"),
|
||||
)
|
||||
|
||||
return per_device_found_inf._per_device_tensors
|
||||
|
||||
def unscale_(self, optimizer):
|
||||
"""
|
||||
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
|
||||
:meth:`unscale_` is optional, serving cases where you need to
|
||||
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
|
||||
between the backward pass(es) and :meth:`step`.
|
||||
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
|
||||
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
|
||||
...
|
||||
scaler.scale(loss).backward()
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
Args:
|
||||
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
|
||||
.. warning::
|
||||
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
|
||||
and only after all gradients for that optimizer's assigned parameters have been accumulated.
|
||||
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
|
||||
.. warning::
|
||||
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return
|
||||
|
||||
self._check_scale_growth_tracker("unscale_")
|
||||
|
||||
optimizer_state = self._per_optimizer_states[id(optimizer)]
|
||||
|
||||
if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
|
||||
raise RuntimeError(
|
||||
"unscale_() has already been called on this optimizer since the last update()."
|
||||
)
|
||||
elif optimizer_state["stage"] is OptState.STEPPED:
|
||||
raise RuntimeError("unscale_() is being called after step().")
|
||||
|
||||
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
|
||||
assert self._scale is not None
|
||||
if device_supports_fp64:
|
||||
inv_scale = self._scale.double().reciprocal().float()
|
||||
else:
|
||||
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
|
||||
found_inf = torch.full(
|
||||
(1,), 0.0, dtype=torch.float32, device=self._scale.device
|
||||
)
|
||||
|
||||
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
|
||||
optimizer, inv_scale, found_inf, False
|
||||
)
|
||||
optimizer_state["stage"] = OptState.UNSCALED
|
||||
|
||||
def update(self, new_scale=None):
|
||||
"""
|
||||
Updates the scale factor.
|
||||
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
|
||||
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
|
||||
the scale is multiplied by ``growth_factor`` to increase it.
|
||||
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
|
||||
used directly, it's used to fill GradScaler's internal scale tensor. So if
|
||||
``new_scale`` was a tensor, later in-place changes to that tensor will not further
|
||||
affect the scale GradScaler uses internally.)
|
||||
Args:
|
||||
new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
|
||||
.. warning::
|
||||
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
|
||||
been invoked for all optimizers used this iteration.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return
|
||||
|
||||
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
|
||||
|
||||
if new_scale is not None:
|
||||
# Accept a new user-defined scale.
|
||||
if isinstance(new_scale, float):
|
||||
self._scale.fill_(new_scale) # type: ignore[union-attr]
|
||||
else:
|
||||
reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
|
||||
assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
|
||||
assert new_scale.numel() == 1, reason
|
||||
assert new_scale.requires_grad is False, reason
|
||||
self._scale.copy_(new_scale) # type: ignore[union-attr]
|
||||
else:
|
||||
# Consume shared inf/nan data collected from optimizers to update the scale.
|
||||
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
|
||||
found_infs = [
|
||||
found_inf.to(device="cpu", non_blocking=True)
|
||||
for state in self._per_optimizer_states.values()
|
||||
for found_inf in state["found_inf_per_device"].values()
|
||||
]
|
||||
|
||||
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
|
||||
|
||||
found_inf_combined = found_infs[0]
|
||||
if len(found_infs) > 1:
|
||||
for i in range(1, len(found_infs)):
|
||||
found_inf_combined += found_infs[i]
|
||||
|
||||
to_device = _scale.device
|
||||
_scale = _scale.to("cpu")
|
||||
_growth_tracker = _growth_tracker.to("cpu")
|
||||
|
||||
core._amp_update_scale_(
|
||||
_scale,
|
||||
_growth_tracker,
|
||||
found_inf_combined,
|
||||
self._growth_factor,
|
||||
self._backoff_factor,
|
||||
self._growth_interval,
|
||||
)
|
||||
|
||||
_scale = _scale.to(to_device)
|
||||
_growth_tracker = _growth_tracker.to(to_device)
|
||||
# To prepare for next iteration, clear the data collected from optimizers this iteration.
|
||||
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
|
||||
|
||||
def gradscaler_init():
|
||||
torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
||||
torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
|
||||
torch.xpu.amp.GradScaler.unscale_ = unscale_
|
||||
torch.xpu.amp.GradScaler.update = update
|
||||
return torch.xpu.amp.GradScaler
|
||||
@@ -4,17 +4,23 @@ from contextlib import nullcontext
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
|
||||
if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0' and (torch.xpu.get_device_properties("xpu").total_memory / 1024 / 1024 / 1024) > 4.1:
|
||||
try:
|
||||
x = torch.ones((33000,33000), dtype=torch.float32, device="xpu")
|
||||
del x
|
||||
torch.xpu.empty_cache()
|
||||
can_allocate_plus_4gb = True
|
||||
except Exception:
|
||||
can_allocate_plus_4gb = False
|
||||
torch_version = float(torch.__version__[:3])
|
||||
current_xpu_device = f"xpu:{torch.xpu.current_device()}"
|
||||
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties(current_xpu_device).has_fp64
|
||||
|
||||
if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0':
|
||||
if (torch.xpu.get_device_properties(current_xpu_device).total_memory / 1024 / 1024 / 1024) > 4.1:
|
||||
try:
|
||||
x = torch.ones((33000,33000), dtype=torch.float32, device=current_xpu_device)
|
||||
del x
|
||||
torch.xpu.empty_cache()
|
||||
use_dynamic_attention = False
|
||||
except Exception:
|
||||
use_dynamic_attention = True
|
||||
else:
|
||||
use_dynamic_attention = True
|
||||
else:
|
||||
can_allocate_plus_4gb = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '-1')
|
||||
use_dynamic_attention = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '1')
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
||||
|
||||
@@ -22,32 +28,67 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr
|
||||
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
|
||||
if isinstance(device_ids, list) and len(device_ids) > 1:
|
||||
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
|
||||
return module.to("xpu")
|
||||
return module.to(f"xpu:{torch.xpu.current_device()}")
|
||||
|
||||
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
||||
return nullcontext()
|
||||
|
||||
@property
|
||||
def is_cuda(self):
|
||||
return self.device.type == 'xpu' or self.device.type == 'cuda'
|
||||
return self.device.type == "xpu" or self.device.type == "cuda"
|
||||
|
||||
def check_device(device):
|
||||
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
|
||||
def check_device_type(device, device_type: str) -> bool:
|
||||
if device is None or type(device) not in {str, int, torch.device}:
|
||||
return False
|
||||
else:
|
||||
return bool(torch.device(device).type == device_type)
|
||||
|
||||
def return_xpu(device):
|
||||
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu"
|
||||
def check_cuda(device) -> bool:
|
||||
return bool(isinstance(device, int) or check_device_type(device, "cuda"))
|
||||
|
||||
def return_xpu(device): # keep the device instance type, aka return string if the input is string
|
||||
return f"xpu:{torch.xpu.current_device()}" if device is None else f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu"
|
||||
|
||||
|
||||
# Autocast
|
||||
original_autocast_init = torch.amp.autocast_mode.autocast.__init__
|
||||
@wraps(torch.amp.autocast_mode.autocast.__init__)
|
||||
def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
|
||||
if device_type == "cuda":
|
||||
def autocast_init(self, device_type=None, dtype=None, enabled=True, cache_enabled=None):
|
||||
if device_type is None or check_cuda(device_type):
|
||||
return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
||||
else:
|
||||
return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
||||
|
||||
|
||||
original_grad_scaler_init = torch.amp.grad_scaler.GradScaler.__init__
|
||||
@wraps(torch.amp.grad_scaler.GradScaler.__init__)
|
||||
def GradScaler_init(self, device: str = None, init_scale: float = 2.0**16, growth_factor: float = 2.0, backoff_factor: float = 0.5, growth_interval: int = 2000, enabled: bool = True):
|
||||
if device is None or check_cuda(device):
|
||||
return original_grad_scaler_init(self, device=return_xpu(device), init_scale=init_scale, growth_factor=growth_factor, backoff_factor=backoff_factor, growth_interval=growth_interval, enabled=enabled)
|
||||
else:
|
||||
return original_grad_scaler_init(self, device=device, init_scale=init_scale, growth_factor=growth_factor, backoff_factor=backoff_factor, growth_interval=growth_interval, enabled=enabled)
|
||||
|
||||
|
||||
original_is_autocast_enabled = torch.is_autocast_enabled
|
||||
@wraps(torch.is_autocast_enabled)
|
||||
def torch_is_autocast_enabled(device_type=None):
|
||||
if device_type is None or check_cuda(device_type):
|
||||
return original_is_autocast_enabled(return_xpu(device_type))
|
||||
else:
|
||||
return original_is_autocast_enabled(device_type)
|
||||
|
||||
|
||||
original_get_autocast_dtype = torch.get_autocast_dtype
|
||||
@wraps(torch.get_autocast_dtype)
|
||||
def torch_get_autocast_dtype(device_type=None):
|
||||
if device_type is None or check_cuda(device_type) or check_device_type(device_type, "xpu"):
|
||||
return torch.bfloat16
|
||||
else:
|
||||
return original_get_autocast_dtype(device_type)
|
||||
|
||||
|
||||
# Latent Antialias CPU Offload:
|
||||
# IPEX 2.5 and above has partial support but doesn't really work most of the time.
|
||||
original_interpolate = torch.nn.functional.interpolate
|
||||
@wraps(torch.nn.functional.interpolate)
|
||||
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
||||
@@ -66,23 +107,22 @@ original_from_numpy = torch.from_numpy
|
||||
@wraps(torch.from_numpy)
|
||||
def from_numpy(ndarray):
|
||||
if ndarray.dtype == float:
|
||||
return original_from_numpy(ndarray.astype('float32'))
|
||||
return original_from_numpy(ndarray.astype("float32"))
|
||||
else:
|
||||
return original_from_numpy(ndarray)
|
||||
|
||||
original_as_tensor = torch.as_tensor
|
||||
@wraps(torch.as_tensor)
|
||||
def as_tensor(data, dtype=None, device=None):
|
||||
if check_device(device):
|
||||
if check_cuda(device):
|
||||
device = return_xpu(device)
|
||||
if isinstance(data, np.ndarray) and data.dtype == float and not (
|
||||
(isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
|
||||
if isinstance(data, np.ndarray) and data.dtype == float and not check_device_type(device, "cpu"):
|
||||
return original_as_tensor(data, dtype=torch.float32, device=device)
|
||||
else:
|
||||
return original_as_tensor(data, dtype=dtype, device=device)
|
||||
|
||||
|
||||
if can_allocate_plus_4gb:
|
||||
if not use_dynamic_attention:
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
else:
|
||||
# 32 bit attention workarounds for Alchemist:
|
||||
@@ -106,7 +146,7 @@ original_torch_bmm = torch.bmm
|
||||
@wraps(torch.bmm)
|
||||
def torch_bmm(input, mat2, *, out=None):
|
||||
if input.dtype != mat2.dtype:
|
||||
mat2 = mat2.to(input.dtype)
|
||||
mat2 = mat2.to(dtype=input.dtype)
|
||||
return original_torch_bmm(input, mat2, out=out)
|
||||
|
||||
# Diffusers FreeU
|
||||
@@ -195,38 +235,36 @@ original_torch_tensor = torch.tensor
|
||||
@wraps(torch.tensor)
|
||||
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
|
||||
global device_supports_fp64
|
||||
if check_device(device):
|
||||
if check_cuda(device):
|
||||
device = return_xpu(device)
|
||||
if not device_supports_fp64:
|
||||
if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
|
||||
if check_device_type(device, "xpu"):
|
||||
if dtype == torch.float64:
|
||||
dtype = torch.float32
|
||||
elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
|
||||
dtype = torch.float32
|
||||
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
|
||||
|
||||
original_Tensor_to = torch.Tensor.to
|
||||
torch.Tensor.original_Tensor_to = torch.Tensor.to
|
||||
@wraps(torch.Tensor.to)
|
||||
def Tensor_to(self, device=None, *args, **kwargs):
|
||||
if check_device(device):
|
||||
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
|
||||
if check_cuda(device):
|
||||
return self.original_Tensor_to(return_xpu(device), *args, **kwargs)
|
||||
else:
|
||||
return original_Tensor_to(self, device, *args, **kwargs)
|
||||
return self.original_Tensor_to(device, *args, **kwargs)
|
||||
|
||||
original_Tensor_cuda = torch.Tensor.cuda
|
||||
@wraps(torch.Tensor.cuda)
|
||||
def Tensor_cuda(self, device=None, *args, **kwargs):
|
||||
if check_device(device):
|
||||
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
|
||||
if device is None or check_cuda(device):
|
||||
return self.to(return_xpu(device), *args, **kwargs)
|
||||
else:
|
||||
return original_Tensor_cuda(self, device, *args, **kwargs)
|
||||
|
||||
original_Tensor_pin_memory = torch.Tensor.pin_memory
|
||||
@wraps(torch.Tensor.pin_memory)
|
||||
def Tensor_pin_memory(self, device=None, *args, **kwargs):
|
||||
if device is None:
|
||||
device = "xpu"
|
||||
if check_device(device):
|
||||
if device is None or check_cuda(device):
|
||||
return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
|
||||
else:
|
||||
return original_Tensor_pin_memory(self, device, *args, **kwargs)
|
||||
@@ -234,23 +272,32 @@ def Tensor_pin_memory(self, device=None, *args, **kwargs):
|
||||
original_UntypedStorage_init = torch.UntypedStorage.__init__
|
||||
@wraps(torch.UntypedStorage.__init__)
|
||||
def UntypedStorage_init(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
if check_cuda(device):
|
||||
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_UntypedStorage_init(*args, device=device, **kwargs)
|
||||
|
||||
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
|
||||
@wraps(torch.UntypedStorage.cuda)
|
||||
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
|
||||
if check_device(device):
|
||||
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
|
||||
else:
|
||||
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
|
||||
if torch_version >= 2.4:
|
||||
original_UntypedStorage_to = torch.UntypedStorage.to
|
||||
@wraps(torch.UntypedStorage.to)
|
||||
def UntypedStorage_to(self, *args, device=None, **kwargs):
|
||||
if check_cuda(device):
|
||||
return original_UntypedStorage_to(self, *args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_UntypedStorage_to(self, *args, device=device, **kwargs)
|
||||
|
||||
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
|
||||
@wraps(torch.UntypedStorage.cuda)
|
||||
def UntypedStorage_cuda(self, device=None, non_blocking=False, **kwargs):
|
||||
if device is None or check_cuda(device):
|
||||
return self.to(device=return_xpu(device), non_blocking=non_blocking, **kwargs)
|
||||
else:
|
||||
return original_UntypedStorage_cuda(self, device=device, non_blocking=non_blocking, **kwargs)
|
||||
|
||||
original_torch_empty = torch.empty
|
||||
@wraps(torch.empty)
|
||||
def torch_empty(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
if check_cuda(device):
|
||||
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_empty(*args, device=device, **kwargs)
|
||||
@@ -260,7 +307,7 @@ original_torch_randn = torch.randn
|
||||
def torch_randn(*args, device=None, dtype=None, **kwargs):
|
||||
if dtype is bytes:
|
||||
dtype = None
|
||||
if check_device(device):
|
||||
if check_cuda(device):
|
||||
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_randn(*args, device=device, **kwargs)
|
||||
@@ -268,7 +315,7 @@ def torch_randn(*args, device=None, dtype=None, **kwargs):
|
||||
original_torch_ones = torch.ones
|
||||
@wraps(torch.ones)
|
||||
def torch_ones(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
if check_cuda(device):
|
||||
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_ones(*args, device=device, **kwargs)
|
||||
@@ -276,7 +323,7 @@ def torch_ones(*args, device=None, **kwargs):
|
||||
original_torch_zeros = torch.zeros
|
||||
@wraps(torch.zeros)
|
||||
def torch_zeros(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
if check_cuda(device):
|
||||
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_zeros(*args, device=device, **kwargs)
|
||||
@@ -284,7 +331,7 @@ def torch_zeros(*args, device=None, **kwargs):
|
||||
original_torch_full = torch.full
|
||||
@wraps(torch.full)
|
||||
def torch_full(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
if check_cuda(device):
|
||||
return original_torch_full(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_full(*args, device=device, **kwargs)
|
||||
@@ -292,63 +339,91 @@ def torch_full(*args, device=None, **kwargs):
|
||||
original_torch_linspace = torch.linspace
|
||||
@wraps(torch.linspace)
|
||||
def torch_linspace(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
if check_cuda(device):
|
||||
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_linspace(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_eye = torch.eye
|
||||
@wraps(torch.eye)
|
||||
def torch_eye(*args, device=None, **kwargs):
|
||||
if check_cuda(device):
|
||||
return original_torch_eye(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_eye(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_load = torch.load
|
||||
@wraps(torch.load)
|
||||
def torch_load(f, map_location=None, *args, **kwargs):
|
||||
if map_location is None:
|
||||
map_location = "xpu"
|
||||
if check_device(map_location):
|
||||
if map_location is None or check_cuda(map_location):
|
||||
return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
|
||||
else:
|
||||
return original_torch_load(f, *args, map_location=map_location, **kwargs)
|
||||
|
||||
original_torch_Generator = torch.Generator
|
||||
@wraps(torch.Generator)
|
||||
def torch_Generator(device=None):
|
||||
if check_device(device):
|
||||
return original_torch_Generator(return_xpu(device))
|
||||
else:
|
||||
return original_torch_Generator(device)
|
||||
|
||||
@wraps(torch.cuda.synchronize)
|
||||
def torch_cuda_synchronize(device=None):
|
||||
if check_device(device):
|
||||
if check_cuda(device):
|
||||
return torch.xpu.synchronize(return_xpu(device))
|
||||
else:
|
||||
return torch.xpu.synchronize(device)
|
||||
|
||||
@wraps(torch.cuda.device)
|
||||
def torch_cuda_device(device):
|
||||
if check_cuda(device):
|
||||
return torch.xpu.device(return_xpu(device))
|
||||
else:
|
||||
return torch.xpu.device(device)
|
||||
|
||||
@wraps(torch.cuda.set_device)
|
||||
def torch_cuda_set_device(device):
|
||||
if check_cuda(device):
|
||||
torch.xpu.set_device(return_xpu(device))
|
||||
else:
|
||||
torch.xpu.set_device(device)
|
||||
|
||||
# torch.Generator has to be a class for isinstance checks
|
||||
original_torch_Generator = torch.Generator
|
||||
class torch_Generator(original_torch_Generator):
|
||||
def __new__(self, device=None):
|
||||
# can't hijack __init__ because of C override so use return super().__new__
|
||||
if check_cuda(device):
|
||||
return super().__new__(self, return_xpu(device))
|
||||
else:
|
||||
return super().__new__(self, device)
|
||||
|
||||
|
||||
# Hijack Functions:
|
||||
def ipex_hijacks(legacy=True):
|
||||
global device_supports_fp64, can_allocate_plus_4gb
|
||||
if legacy and float(torch.__version__[:3]) < 2.5:
|
||||
torch.nn.functional.interpolate = interpolate
|
||||
def ipex_hijacks():
|
||||
global device_supports_fp64
|
||||
if torch_version >= 2.4:
|
||||
torch.UntypedStorage.cuda = UntypedStorage_cuda
|
||||
torch.UntypedStorage.to = UntypedStorage_to
|
||||
torch.tensor = torch_tensor
|
||||
torch.Tensor.to = Tensor_to
|
||||
torch.Tensor.cuda = Tensor_cuda
|
||||
torch.Tensor.pin_memory = Tensor_pin_memory
|
||||
torch.UntypedStorage.__init__ = UntypedStorage_init
|
||||
torch.UntypedStorage.cuda = UntypedStorage_cuda
|
||||
torch.empty = torch_empty
|
||||
torch.randn = torch_randn
|
||||
torch.ones = torch_ones
|
||||
torch.zeros = torch_zeros
|
||||
torch.full = torch_full
|
||||
torch.linspace = torch_linspace
|
||||
torch.eye = torch_eye
|
||||
torch.load = torch_load
|
||||
torch.Generator = torch_Generator
|
||||
torch.cuda.synchronize = torch_cuda_synchronize
|
||||
torch.cuda.device = torch_cuda_device
|
||||
torch.cuda.set_device = torch_cuda_set_device
|
||||
|
||||
torch.Generator = torch_Generator
|
||||
torch._C.Generator = torch_Generator
|
||||
|
||||
torch.backends.cuda.sdp_kernel = return_null_context
|
||||
torch.nn.DataParallel = DummyDataParallel
|
||||
torch.UntypedStorage.is_cuda = is_cuda
|
||||
torch.amp.autocast_mode.autocast.__init__ = autocast_init
|
||||
|
||||
torch.nn.functional.interpolate = interpolate
|
||||
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
||||
torch.nn.functional.group_norm = functional_group_norm
|
||||
torch.nn.functional.layer_norm = functional_layer_norm
|
||||
@@ -364,4 +439,28 @@ def ipex_hijacks(legacy=True):
|
||||
if not device_supports_fp64:
|
||||
torch.from_numpy = from_numpy
|
||||
torch.as_tensor = as_tensor
|
||||
return device_supports_fp64, can_allocate_plus_4gb
|
||||
|
||||
# AMP:
|
||||
torch.amp.grad_scaler.GradScaler.__init__ = GradScaler_init
|
||||
torch.is_autocast_enabled = torch_is_autocast_enabled
|
||||
torch.get_autocast_gpu_dtype = torch_get_autocast_dtype
|
||||
torch.get_autocast_dtype = torch_get_autocast_dtype
|
||||
|
||||
if hasattr(torch.xpu, "amp"):
|
||||
if not hasattr(torch.xpu.amp, "custom_fwd"):
|
||||
torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd
|
||||
torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd
|
||||
if not hasattr(torch.xpu.amp, "GradScaler"):
|
||||
torch.xpu.amp.GradScaler = torch.amp.grad_scaler.GradScaler
|
||||
torch.cuda.amp = torch.xpu.amp
|
||||
else:
|
||||
if not hasattr(torch.amp, "custom_fwd"):
|
||||
torch.amp.custom_fwd = torch.cuda.amp.custom_fwd
|
||||
torch.amp.custom_bwd = torch.cuda.amp.custom_bwd
|
||||
torch.cuda.amp = torch.amp
|
||||
|
||||
if not hasattr(torch.cuda.amp, "common"):
|
||||
torch.cuda.amp.common = nullcontext()
|
||||
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
|
||||
|
||||
return device_supports_fp64
|
||||
|
||||
1392
library/lumina_models.py
Normal file
1392
library/lumina_models.py
Normal file
File diff suppressed because it is too large
Load Diff
1098
library/lumina_train_util.py
Normal file
1098
library/lumina_train_util.py
Normal file
File diff suppressed because it is too large
Load Diff
233
library/lumina_util.py
Normal file
233
library/lumina_util.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import replace
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import load_file
|
||||
from transformers import Gemma2Config, Gemma2Model
|
||||
|
||||
from library.utils import setup_logging
|
||||
from library import lumina_models, flux_models
|
||||
from library.utils import load_safetensors
|
||||
import logging
|
||||
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MODEL_VERSION_LUMINA_V2 = "lumina2"
|
||||
|
||||
|
||||
def load_lumina_model(
|
||||
ckpt_path: str,
|
||||
dtype: Optional[torch.dtype],
|
||||
device: torch.device,
|
||||
disable_mmap: bool = False,
|
||||
use_flash_attn: bool = False,
|
||||
use_sage_attn: bool = False,
|
||||
):
|
||||
"""
|
||||
Load the Lumina model from the checkpoint path.
|
||||
|
||||
Args:
|
||||
ckpt_path (str): Path to the checkpoint.
|
||||
dtype (torch.dtype): The data type for the model.
|
||||
device (torch.device): The device to load the model on.
|
||||
disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
|
||||
use_flash_attn (bool, optional): Whether to use flash attention. Defaults to False.
|
||||
|
||||
Returns:
|
||||
model (lumina_models.NextDiT): The loaded model.
|
||||
"""
|
||||
logger.info("Building Lumina")
|
||||
with torch.device("meta"):
|
||||
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to(dtype)
|
||||
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
|
||||
info = model.load_state_dict(state_dict, strict=False, assign=True)
|
||||
logger.info(f"Loaded Lumina: {info}")
|
||||
return model
|
||||
|
||||
|
||||
def load_ae(
|
||||
ckpt_path: str,
|
||||
dtype: torch.dtype,
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
) -> flux_models.AutoEncoder:
|
||||
"""
|
||||
Load the AutoEncoder model from the checkpoint path.
|
||||
|
||||
Args:
|
||||
ckpt_path (str): Path to the checkpoint.
|
||||
dtype (torch.dtype): The data type for the model.
|
||||
device (Union[str, torch.device]): The device to load the model on.
|
||||
disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
|
||||
|
||||
Returns:
|
||||
ae (flux_models.AutoEncoder): The loaded model.
|
||||
"""
|
||||
logger.info("Building AutoEncoder")
|
||||
with torch.device("meta"):
|
||||
# dev and schnell have the same AE params
|
||||
ae = flux_models.AutoEncoder(flux_models.configs["schnell"].ae_params).to(dtype)
|
||||
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
|
||||
info = ae.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded AE: {info}")
|
||||
return ae
|
||||
|
||||
|
||||
def load_gemma2(
|
||||
ckpt_path: Optional[str],
|
||||
dtype: torch.dtype,
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
state_dict: Optional[dict] = None,
|
||||
) -> Gemma2Model:
|
||||
"""
|
||||
Load the Gemma2 model from the checkpoint path.
|
||||
|
||||
Args:
|
||||
ckpt_path (str): Path to the checkpoint.
|
||||
dtype (torch.dtype): The data type for the model.
|
||||
device (Union[str, torch.device]): The device to load the model on.
|
||||
disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
|
||||
state_dict (Optional[dict], optional): The state dict to load. Defaults to None.
|
||||
|
||||
Returns:
|
||||
gemma2 (Gemma2Model): The loaded model
|
||||
"""
|
||||
logger.info("Building Gemma2")
|
||||
GEMMA2_CONFIG = {
|
||||
"_name_or_path": "google/gemma-2-2b",
|
||||
"architectures": ["Gemma2Model"],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"attn_logit_softcapping": 50.0,
|
||||
"bos_token_id": 2,
|
||||
"cache_implementation": "hybrid",
|
||||
"eos_token_id": 1,
|
||||
"final_logit_softcapping": 30.0,
|
||||
"head_dim": 256,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"hidden_size": 2304,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 9216,
|
||||
"max_position_embeddings": 8192,
|
||||
"model_type": "gemma2",
|
||||
"num_attention_heads": 8,
|
||||
"num_hidden_layers": 26,
|
||||
"num_key_value_heads": 4,
|
||||
"pad_token_id": 0,
|
||||
"query_pre_attn_scalar": 256,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_theta": 10000.0,
|
||||
"sliding_window": 4096,
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": "4.44.2",
|
||||
"use_cache": True,
|
||||
"vocab_size": 256000,
|
||||
}
|
||||
|
||||
config = Gemma2Config(**GEMMA2_CONFIG)
|
||||
with init_empty_weights():
|
||||
gemma2 = Gemma2Model._from_config(config)
|
||||
|
||||
if state_dict is not None:
|
||||
sd = state_dict
|
||||
else:
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
||||
|
||||
for key in list(sd.keys()):
|
||||
new_key = key.replace("model.", "")
|
||||
if new_key == key:
|
||||
break # the model doesn't have annoying prefix
|
||||
sd[new_key] = sd.pop(key)
|
||||
|
||||
info = gemma2.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded Gemma2: {info}")
|
||||
return gemma2
|
||||
|
||||
|
||||
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
|
||||
"""
|
||||
x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
|
||||
"""
|
||||
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
|
||||
return x
|
||||
|
||||
|
||||
def pack_latents(x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2
|
||||
"""
|
||||
x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
DIFFUSERS_TO_ALPHA_VLLM_MAP: dict[str, str] = {
|
||||
# Embedding layers
|
||||
"time_caption_embed.caption_embedder.0.weight": "cap_embedder.0.weight",
|
||||
"time_caption_embed.caption_embedder.1.weight": "cap_embedder.1.weight",
|
||||
"text_embedder.1.bias": "cap_embedder.1.bias",
|
||||
"patch_embedder.proj.weight": "x_embedder.weight",
|
||||
"patch_embedder.proj.bias": "x_embedder.bias",
|
||||
# Attention modulation
|
||||
"transformer_blocks.().adaln_modulation.1.weight": "layers.().adaLN_modulation.1.weight",
|
||||
"transformer_blocks.().adaln_modulation.1.bias": "layers.().adaLN_modulation.1.bias",
|
||||
# Final layers
|
||||
"final_adaln_modulation.1.weight": "final_layer.adaLN_modulation.1.weight",
|
||||
"final_adaln_modulation.1.bias": "final_layer.adaLN_modulation.1.bias",
|
||||
"final_linear.weight": "final_layer.linear.weight",
|
||||
"final_linear.bias": "final_layer.linear.bias",
|
||||
# Noise refiner
|
||||
"single_transformer_blocks.().adaln_modulation.1.weight": "noise_refiner.().adaLN_modulation.1.weight",
|
||||
"single_transformer_blocks.().adaln_modulation.1.bias": "noise_refiner.().adaLN_modulation.1.bias",
|
||||
"single_transformer_blocks.().attn.to_qkv.weight": "noise_refiner.().attention.qkv.weight",
|
||||
"single_transformer_blocks.().attn.to_out.0.weight": "noise_refiner.().attention.out.weight",
|
||||
# Normalization
|
||||
"transformer_blocks.().norm1.weight": "layers.().attention_norm1.weight",
|
||||
"transformer_blocks.().norm2.weight": "layers.().attention_norm2.weight",
|
||||
# FFN
|
||||
"transformer_blocks.().ff.net.0.proj.weight": "layers.().feed_forward.w1.weight",
|
||||
"transformer_blocks.().ff.net.2.weight": "layers.().feed_forward.w2.weight",
|
||||
"transformer_blocks.().ff.net.4.weight": "layers.().feed_forward.w3.weight",
|
||||
}
|
||||
|
||||
|
||||
def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict:
|
||||
"""Convert Diffusers checkpoint to Alpha-VLLM format"""
|
||||
logger.info("Converting Diffusers checkpoint to Alpha-VLLM format")
|
||||
new_sd = sd.copy() # Preserve original keys
|
||||
|
||||
for diff_key, alpha_key in DIFFUSERS_TO_ALPHA_VLLM_MAP.items():
|
||||
# Handle block-specific patterns
|
||||
if '().' in diff_key:
|
||||
for block_idx in range(num_double_blocks):
|
||||
block_alpha_key = alpha_key.replace('().', f'{block_idx}.')
|
||||
block_diff_key = diff_key.replace('().', f'{block_idx}.')
|
||||
|
||||
# Search for and convert block-specific keys
|
||||
for input_key, value in list(sd.items()):
|
||||
if input_key == block_diff_key:
|
||||
new_sd[block_alpha_key] = value
|
||||
else:
|
||||
# Handle static keys
|
||||
if diff_key in sd:
|
||||
print(f"Replacing {diff_key} with {alpha_key}")
|
||||
new_sd[alpha_key] = sd[diff_key]
|
||||
else:
|
||||
print(f"Not found: {diff_key}")
|
||||
|
||||
|
||||
logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format")
|
||||
return new_sd
|
||||
@@ -61,6 +61,8 @@ ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc.
|
||||
# ARCH_SD3_UNKNOWN = "stable-diffusion-3"
|
||||
ARCH_FLUX_1_DEV = "flux-1-dev"
|
||||
ARCH_FLUX_1_UNKNOWN = "flux-1"
|
||||
ARCH_LUMINA_2 = "lumina-2"
|
||||
ARCH_LUMINA_UNKNOWN = "lumina"
|
||||
|
||||
ADAPTER_LORA = "lora"
|
||||
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
|
||||
@@ -69,6 +71,7 @@ IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
|
||||
IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
|
||||
IMPL_DIFFUSERS = "diffusers"
|
||||
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
|
||||
IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0"
|
||||
|
||||
PRED_TYPE_EPSILON = "epsilon"
|
||||
PRED_TYPE_V = "v"
|
||||
@@ -123,6 +126,7 @@ def build_metadata(
|
||||
clip_skip: Optional[int] = None,
|
||||
sd3: Optional[str] = None,
|
||||
flux: Optional[str] = None,
|
||||
lumina: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
sd3: only supports "m", flux: only supports "dev"
|
||||
@@ -146,6 +150,11 @@ def build_metadata(
|
||||
arch = ARCH_FLUX_1_DEV
|
||||
else:
|
||||
arch = ARCH_FLUX_1_UNKNOWN
|
||||
elif lumina is not None:
|
||||
if lumina == "lumina2":
|
||||
arch = ARCH_LUMINA_2
|
||||
else:
|
||||
arch = ARCH_LUMINA_UNKNOWN
|
||||
elif v2:
|
||||
if v_parameterization:
|
||||
arch = ARCH_SD_V2_768_V
|
||||
@@ -167,6 +176,9 @@ def build_metadata(
|
||||
if flux is not None:
|
||||
# Flux
|
||||
impl = IMPL_FLUX
|
||||
elif lumina is not None:
|
||||
# Lumina
|
||||
impl = IMPL_LUMINA
|
||||
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
||||
# Stable Diffusion ckpt, TI, SDXL LoRA
|
||||
impl = IMPL_STABILITY_AI
|
||||
@@ -225,7 +237,7 @@ def build_metadata(
|
||||
reso = (reso[0], reso[0])
|
||||
else:
|
||||
# resolution is defined in dataset, so use default
|
||||
if sdxl or sd3 is not None or flux is not None:
|
||||
if sdxl or sd3 is not None or flux is not None or lumina is not None:
|
||||
reso = 1024
|
||||
elif v2 and v_parameterization:
|
||||
reso = 768
|
||||
|
||||
@@ -1080,7 +1080,7 @@ class MMDiT(nn.Module):
|
||||
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
|
||||
|
||||
self.offloader = custom_offloading_utils.ModelOffloader(
|
||||
self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True
|
||||
self.joint_blocks, self.blocks_to_swap, device # , debug=True
|
||||
)
|
||||
print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
|
||||
|
||||
@@ -1088,7 +1088,7 @@ class MMDiT(nn.Module):
|
||||
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
|
||||
if self.blocks_to_swap:
|
||||
save_blocks = self.joint_blocks
|
||||
self.joint_blocks = None
|
||||
self.joint_blocks = nn.ModuleList()
|
||||
|
||||
self.to(device)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import Any, List, Optional, Tuple, Union, Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -430,9 +430,21 @@ class LatentsCachingStrategy:
|
||||
bucket_reso: Tuple[int, int],
|
||||
npz_path: str,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
apply_alpha_mask: bool,
|
||||
multi_resolution: bool = False,
|
||||
):
|
||||
) -> bool:
|
||||
"""
|
||||
Args:
|
||||
latents_stride: stride of latents
|
||||
bucket_reso: resolution of the bucket
|
||||
npz_path: path to the npz file
|
||||
flip_aug: whether to flip images
|
||||
apply_alpha_mask: whether to apply alpha mask
|
||||
multi_resolution: whether to use multi-resolution latents
|
||||
|
||||
Returns:
|
||||
bool
|
||||
"""
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(npz_path):
|
||||
@@ -451,7 +463,7 @@ class LatentsCachingStrategy:
|
||||
return False
|
||||
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
|
||||
return False
|
||||
if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
|
||||
if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
@@ -462,22 +474,35 @@ class LatentsCachingStrategy:
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def _default_cache_batch_latents(
|
||||
self,
|
||||
encode_by_vae,
|
||||
vae_device,
|
||||
vae_dtype,
|
||||
encode_by_vae: Callable,
|
||||
vae_device: torch.device,
|
||||
vae_dtype: torch.dtype,
|
||||
image_infos: List,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
apply_alpha_mask: bool,
|
||||
random_crop: bool,
|
||||
multi_resolution: bool = False,
|
||||
):
|
||||
"""
|
||||
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
|
||||
|
||||
Args:
|
||||
encode_by_vae: function to encode images by VAE
|
||||
vae_device: device to use for VAE
|
||||
vae_dtype: dtype to use for VAE
|
||||
image_infos: list of ImageInfo
|
||||
flip_aug: whether to flip images
|
||||
apply_alpha_mask: whether to apply alpha mask
|
||||
random_crop: whether to random crop images
|
||||
multi_resolution: whether to use multi-resolution latents
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
from library import train_util # import here to avoid circular import
|
||||
|
||||
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
|
||||
image_infos, alpha_mask, random_crop
|
||||
image_infos, apply_alpha_mask, random_crop
|
||||
)
|
||||
img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)
|
||||
|
||||
@@ -519,12 +544,40 @@ class LatentsCachingStrategy:
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
"""
|
||||
for SD/SDXL
|
||||
|
||||
Args:
|
||||
npz_path (str): Path to the npz file.
|
||||
bucket_reso (Tuple[int, int]): The resolution of the bucket.
|
||||
|
||||
Returns:
|
||||
Tuple[
|
||||
Optional[np.ndarray],
|
||||
Optional[List[int]],
|
||||
Optional[List[int]],
|
||||
Optional[np.ndarray],
|
||||
Optional[np.ndarray]
|
||||
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
|
||||
"""
|
||||
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
|
||||
|
||||
def _default_load_latents_from_disk(
|
||||
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
"""
|
||||
Args:
|
||||
latents_stride (Optional[int]): Stride for latents. If None, load all latents.
|
||||
npz_path (str): Path to the npz file.
|
||||
bucket_reso (Tuple[int, int]): The resolution of the bucket.
|
||||
|
||||
Returns:
|
||||
Tuple[
|
||||
Optional[np.ndarray],
|
||||
Optional[List[int]],
|
||||
Optional[List[int]],
|
||||
Optional[np.ndarray],
|
||||
Optional[np.ndarray]
|
||||
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
|
||||
"""
|
||||
if latents_stride is None:
|
||||
key_reso_suffix = ""
|
||||
else:
|
||||
@@ -552,6 +605,19 @@ class LatentsCachingStrategy:
|
||||
alpha_mask=None,
|
||||
key_reso_suffix="",
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
npz_path (str): Path to the npz file.
|
||||
latents_tensor (torch.Tensor): Latent tensor
|
||||
original_size (List[int]): Original size of the image
|
||||
crop_ltrb (List[int]): Crop left top right bottom
|
||||
flipped_latents_tensor (Optional[torch.Tensor]): Flipped latent tensor
|
||||
alpha_mask (Optional[torch.Tensor]): Alpha mask
|
||||
key_reso_suffix (str): Key resolution suffix
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
kwargs = {}
|
||||
|
||||
if os.path.exists(npz_path):
|
||||
|
||||
375
library/strategy_lumina.py
Normal file
375
library/strategy_lumina.py
Normal file
@@ -0,0 +1,375 @@
|
||||
import glob
|
||||
import os
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModel, Gemma2Model, GemmaTokenizerFast
|
||||
from library import train_util
|
||||
from library.strategy_base import (
|
||||
LatentsCachingStrategy,
|
||||
TokenizeStrategy,
|
||||
TextEncodingStrategy,
|
||||
TextEncoderOutputsCachingStrategy,
|
||||
)
|
||||
import numpy as np
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
GEMMA_ID = "google/gemma-2-2b"
|
||||
|
||||
|
||||
class LuminaTokenizeStrategy(TokenizeStrategy):
|
||||
def __init__(
|
||||
self, system_prompt:str, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None
|
||||
) -> None:
|
||||
self.tokenizer: GemmaTokenizerFast = AutoTokenizer.from_pretrained(
|
||||
GEMMA_ID, cache_dir=tokenizer_cache_dir
|
||||
)
|
||||
self.tokenizer.padding_side = "right"
|
||||
|
||||
if system_prompt is None:
|
||||
system_prompt = ""
|
||||
system_prompt_special_token = "<Prompt Start>"
|
||||
system_prompt = f"{system_prompt} {system_prompt_special_token} " if system_prompt else ""
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
if max_length is None:
|
||||
self.max_length = 256
|
||||
else:
|
||||
self.max_length = max_length
|
||||
|
||||
def tokenize(
|
||||
self, text: Union[str, List[str]], is_negative: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
text (Union[str, List[str]]): Text to tokenize
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]:
|
||||
token input ids, attention_masks
|
||||
"""
|
||||
text = [text] if isinstance(text, str) else text
|
||||
|
||||
# In training, we always add system prompt (is_negative=False)
|
||||
if not is_negative:
|
||||
# Add system prompt to the beginning of each text
|
||||
text = [self.system_prompt + t for t in text]
|
||||
|
||||
encodings = self.tokenizer(
|
||||
text,
|
||||
max_length=self.max_length,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
pad_to_multiple_of=8,
|
||||
)
|
||||
return (encodings.input_ids, encodings.attention_mask)
|
||||
|
||||
def tokenize_with_weights(
|
||||
self, text: str | List[str]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
text (Union[str, List[str]]): Text to tokenize
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
|
||||
token input ids, attention_masks, weights
|
||||
"""
|
||||
# Gemma doesn't support weighted prompts, return uniform weights
|
||||
tokens, attention_masks = self.tokenize(text)
|
||||
weights = [torch.ones_like(t) for t in tokens]
|
||||
return tokens, attention_masks, weights
|
||||
|
||||
|
||||
class LuminaTextEncodingStrategy(TextEncodingStrategy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def encode_tokens(
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
tokens: Tuple[torch.Tensor, torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
|
||||
models (List[Any]): Text encoders
|
||||
tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hidden_states, input_ids, attention_masks
|
||||
"""
|
||||
text_encoder = models[0]
|
||||
# Check model or torch dynamo OptimizedModule
|
||||
assert isinstance(text_encoder, Gemma2Model) or isinstance(text_encoder._orig_mod, Gemma2Model), f"text encoder is not Gemma2Model {text_encoder.__class__.__name__}"
|
||||
input_ids, attention_masks = tokens
|
||||
|
||||
outputs = text_encoder(
|
||||
input_ids=input_ids.to(text_encoder.device),
|
||||
attention_mask=attention_masks.to(text_encoder.device),
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return outputs.hidden_states[-2], input_ids, attention_masks
|
||||
|
||||
def encode_tokens_with_weights(
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
tokens: Tuple[torch.Tensor, torch.Tensor],
|
||||
weights: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
|
||||
models (List[Any]): Text encoders
|
||||
tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks
|
||||
weights_list (List[torch.Tensor]): Currently unused
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hidden_states, input_ids, attention_masks
|
||||
"""
|
||||
# For simplicity, use uniform weighting
|
||||
return self.encode_tokens(tokenize_strategy, models, tokens)
|
||||
|
||||
|
||||
class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_lumina_te.npz"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_to_disk: bool,
|
||||
batch_size: int,
|
||||
skip_disk_cache_validity_check: bool,
|
||||
is_partial: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
cache_to_disk,
|
||||
batch_size,
|
||||
skip_disk_cache_validity_check,
|
||||
is_partial,
|
||||
)
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
return (
|
||||
os.path.splitext(image_abs_path)[0]
|
||||
+ LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
)
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
|
||||
"""
|
||||
Args:
|
||||
npz_path (str): Path to the npz file.
|
||||
|
||||
Returns:
|
||||
bool: True if the npz file is expected to be cached.
|
||||
"""
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(npz_path):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "hidden_state" not in npz:
|
||||
return False
|
||||
if "attention_mask" not in npz:
|
||||
return False
|
||||
if "input_ids" not in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
"""
|
||||
Load outputs from a npz file
|
||||
|
||||
Returns:
|
||||
List[np.ndarray]: hidden_state, input_ids, attention_mask
|
||||
"""
|
||||
data = np.load(npz_path)
|
||||
hidden_state = data["hidden_state"]
|
||||
attention_mask = data["attention_mask"]
|
||||
input_ids = data["input_ids"]
|
||||
return [hidden_state, input_ids, attention_mask]
|
||||
|
||||
@torch.no_grad()
|
||||
def cache_batch_outputs(
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
text_encoding_strategy: TextEncodingStrategy,
|
||||
batch: List[train_util.ImageInfo],
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
|
||||
models (List[Any]): Text encoders
|
||||
text_encoding_strategy (LuminaTextEncodingStrategy):
|
||||
infos (List): List of ImageInfo
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy)
|
||||
assert isinstance(tokenize_strategy, LuminaTokenizeStrategy)
|
||||
|
||||
captions = [info.caption for info in batch]
|
||||
|
||||
if self.is_weighted:
|
||||
tokens, attention_masks, weights_list = (
|
||||
tokenize_strategy.tokenize_with_weights(captions)
|
||||
)
|
||||
hidden_state, input_ids, attention_masks = (
|
||||
text_encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy,
|
||||
models,
|
||||
(tokens, attention_masks),
|
||||
weights_list,
|
||||
)
|
||||
)
|
||||
else:
|
||||
tokens = tokenize_strategy.tokenize(captions)
|
||||
hidden_state, input_ids, attention_masks = (
|
||||
text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, models, tokens
|
||||
)
|
||||
)
|
||||
|
||||
if hidden_state.dtype != torch.float32:
|
||||
hidden_state = hidden_state.float()
|
||||
|
||||
hidden_state = hidden_state.cpu().numpy()
|
||||
attention_mask = attention_masks.cpu().numpy() # (B, S)
|
||||
input_ids = input_ids.cpu().numpy() # (B, S)
|
||||
|
||||
|
||||
for i, info in enumerate(batch):
|
||||
hidden_state_i = hidden_state[i]
|
||||
attention_mask_i = attention_mask[i]
|
||||
input_ids_i = input_ids[i]
|
||||
|
||||
if self.cache_to_disk:
|
||||
assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
hidden_state=hidden_state_i,
|
||||
attention_mask=attention_mask_i,
|
||||
input_ids=input_ids_i,
|
||||
)
|
||||
else:
|
||||
info.text_encoder_outputs = [
|
||||
hidden_state_i,
|
||||
input_ids_i,
|
||||
attention_mask_i,
|
||||
]
|
||||
|
||||
|
||||
class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz"
|
||||
|
||||
def __init__(
|
||||
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
|
||||
) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
|
||||
@property
|
||||
def cache_suffix(self) -> str:
|
||||
return LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX
|
||||
|
||||
def get_latents_npz_path(
|
||||
self, absolute_path: str, image_size: Tuple[int, int]
|
||||
) -> str:
|
||||
return (
|
||||
os.path.splitext(absolute_path)[0]
|
||||
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
|
||||
+ LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX
|
||||
)
|
||||
|
||||
def is_disk_cached_latents_expected(
|
||||
self,
|
||||
bucket_reso: Tuple[int, int],
|
||||
npz_path: str,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
) -> bool:
|
||||
"""
|
||||
Args:
|
||||
bucket_reso (Tuple[int, int]): The resolution of the bucket.
|
||||
npz_path (str): Path to the npz file.
|
||||
flip_aug (bool): Whether to flip the image.
|
||||
alpha_mask (bool): Whether to apply
|
||||
"""
|
||||
return self._default_is_disk_cached_latents_expected(
|
||||
8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
|
||||
)
|
||||
|
||||
def load_latents_from_disk(
|
||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[
|
||||
Optional[np.ndarray],
|
||||
Optional[List[int]],
|
||||
Optional[List[int]],
|
||||
Optional[np.ndarray],
|
||||
Optional[np.ndarray],
|
||||
]:
|
||||
"""
|
||||
Args:
|
||||
npz_path (str): Path to the npz file.
|
||||
bucket_reso (Tuple[int, int]): The resolution of the bucket.
|
||||
|
||||
Returns:
|
||||
Tuple[
|
||||
Optional[np.ndarray],
|
||||
Optional[List[int]],
|
||||
Optional[List[int]],
|
||||
Optional[np.ndarray],
|
||||
Optional[np.ndarray],
|
||||
]: Tuple of latent tensors, attention_mask, input_ids, latents, latents_unet
|
||||
"""
|
||||
return self._default_load_latents_from_disk(
|
||||
8, npz_path, bucket_reso
|
||||
) # support multi-resolution
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def cache_batch_latents(
|
||||
self,
|
||||
model,
|
||||
batch: List,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
random_crop: bool,
|
||||
):
|
||||
encode_by_vae = lambda img_tensor: model.encode(img_tensor).to("cpu")
|
||||
vae_device = model.device
|
||||
vae_dtype = model.dtype
|
||||
|
||||
self._default_cache_batch_latents(
|
||||
encode_by_vae,
|
||||
vae_device,
|
||||
vae_dtype,
|
||||
batch,
|
||||
flip_aug,
|
||||
alpha_mask,
|
||||
random_crop,
|
||||
multi_resolution=True,
|
||||
)
|
||||
|
||||
if not train_util.HIGH_VRAM:
|
||||
train_util.clean_memory_on_device(model.device)
|
||||
@@ -1060,8 +1060,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
|
||||
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
|
||||
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
||||
if len(img_ar_errors) == 0:
|
||||
mean_img_ar_error = 0 # avoid NaN
|
||||
else:
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
||||
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
|
||||
logger.info(f"mean ar error (without repeats): {mean_img_ar_error}")
|
||||
|
||||
@@ -3480,6 +3483,7 @@ def get_sai_model_spec(
|
||||
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
|
||||
sd3: str = None,
|
||||
flux: str = None,
|
||||
lumina: str = None,
|
||||
):
|
||||
timestamp = time.time()
|
||||
|
||||
@@ -3515,6 +3519,7 @@ def get_sai_model_spec(
|
||||
clip_skip=args.clip_skip, # None or int
|
||||
sd3=sd3,
|
||||
flux=flux,
|
||||
lumina=lumina,
|
||||
)
|
||||
return metadata
|
||||
|
||||
@@ -5495,6 +5500,11 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
|
||||
|
||||
def patch_accelerator_for_fp16_training(accelerator):
|
||||
|
||||
from accelerate import DistributedType
|
||||
if accelerator.distributed_type == DistributedType.DEEPSPEED:
|
||||
return
|
||||
|
||||
org_unscale_grads = accelerator.scaler._unscale_grads_
|
||||
|
||||
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
||||
@@ -6178,6 +6188,11 @@ def line_to_prompt_dict(line: str) -> dict:
|
||||
prompt_dict["scale"] = float(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"g ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # guidance scale
|
||||
prompt_dict["guidance_scale"] = float(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||
if m: # negative prompt
|
||||
prompt_dict["negative_prompt"] = m.group(1)
|
||||
@@ -6193,6 +6208,17 @@ def line_to_prompt_dict(line: str) -> dict:
|
||||
prompt_dict["controlnet_image"] = m.group(1)
|
||||
continue
|
||||
|
||||
m = re.match(r"ctr (.+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
prompt_dict["cfg_trunc_ratio"] = float(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"rcfg (.+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
prompt_dict["renorm_cfg"] = float(m.group(1))
|
||||
continue
|
||||
|
||||
|
||||
except ValueError as ex:
|
||||
logger.error(f"Exception in parsing / 解析エラー: {parg}")
|
||||
logger.error(ex)
|
||||
|
||||
415
lumina_minimal_inference.py
Normal file
415
lumina_minimal_inference.py
Normal file
@@ -0,0 +1,415 @@
|
||||
# Minimum Inference Code for Lumina
|
||||
# Based on flux_minimal_inference.py
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from tqdm import tqdm
|
||||
from transformers import Gemma2Model
|
||||
from library.flux_models import AutoEncoder
|
||||
|
||||
from library import (
|
||||
device_utils,
|
||||
lumina_models,
|
||||
lumina_train_util,
|
||||
lumina_util,
|
||||
sd3_train_utils,
|
||||
strategy_lumina,
|
||||
)
|
||||
import networks.lora_lumina as lora_lumina
|
||||
from library.device_utils import get_preferred_device, init_ipex
|
||||
from library.utils import setup_logging, str_to_dtype
|
||||
|
||||
init_ipex()
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_image(
|
||||
model: lumina_models.NextDiT,
|
||||
gemma2: Gemma2Model,
|
||||
ae: AutoEncoder,
|
||||
prompt: str,
|
||||
system_prompt: str,
|
||||
seed: Optional[int],
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
negative_prompt: Optional[str],
|
||||
args,
|
||||
cfg_trunc_ratio: float = 0.25,
|
||||
renorm_cfg: float = 1.0,
|
||||
):
|
||||
#
|
||||
# 0. Prepare arguments
|
||||
#
|
||||
device = get_preferred_device()
|
||||
if args.device:
|
||||
device = torch.device(args.device)
|
||||
|
||||
dtype = str_to_dtype(args.dtype)
|
||||
ae_dtype = str_to_dtype(args.ae_dtype)
|
||||
gemma2_dtype = str_to_dtype(args.gemma2_dtype)
|
||||
|
||||
#
|
||||
# 1. Prepare models
|
||||
#
|
||||
# model.to(device, dtype=dtype)
|
||||
model.to(dtype)
|
||||
model.eval()
|
||||
|
||||
gemma2.to(device, dtype=gemma2_dtype)
|
||||
gemma2.eval()
|
||||
|
||||
ae.to(ae_dtype)
|
||||
ae.eval()
|
||||
|
||||
#
|
||||
# 2. Encode prompts
|
||||
#
|
||||
logger.info("Encoding prompts...")
|
||||
|
||||
tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy(system_prompt, args.gemma2_max_token_length)
|
||||
encoding_strategy = strategy_lumina.LuminaTextEncodingStrategy()
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||
with torch.no_grad():
|
||||
gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks)
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True)
|
||||
with torch.no_grad():
|
||||
neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks)
|
||||
|
||||
# Unpack Gemma2 outputs
|
||||
prompt_hidden_states, _, prompt_attention_mask = gemma2_conds
|
||||
uncond_hidden_states, _, uncond_attention_mask = neg_gemma2_conds
|
||||
|
||||
if args.offload:
|
||||
print("Offloading models to CPU to save VRAM...")
|
||||
gemma2.to("cpu")
|
||||
device_utils.clean_memory()
|
||||
|
||||
model.to(device)
|
||||
|
||||
#
|
||||
# 3. Prepare latents
|
||||
#
|
||||
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
|
||||
logger.info(f"Seed: {seed}")
|
||||
torch.manual_seed(seed)
|
||||
|
||||
latent_height = image_height // 8
|
||||
latent_width = image_width // 8
|
||||
latent_channels = 16
|
||||
|
||||
latents = torch.randn(
|
||||
(1, latent_channels, latent_height, latent_width),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
generator=torch.Generator(device=device).manual_seed(seed),
|
||||
)
|
||||
|
||||
#
|
||||
# 4. Denoise
|
||||
#
|
||||
logger.info("Denoising...")
|
||||
scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
||||
scheduler.set_timesteps(steps, device=device)
|
||||
timesteps = scheduler.timesteps
|
||||
|
||||
# # compare with lumina_train_util.retrieve_timesteps
|
||||
# lumina_timestep = lumina_train_util.retrieve_timesteps(scheduler, num_inference_steps=steps)
|
||||
# print(f"Using timesteps: {timesteps}")
|
||||
# print(f"vs Lumina timesteps: {lumina_timestep}") # should be the same
|
||||
|
||||
with torch.autocast(device_type=device.type, dtype=dtype), torch.no_grad():
|
||||
latents = lumina_train_util.denoise(
|
||||
scheduler,
|
||||
model,
|
||||
latents.to(device),
|
||||
prompt_hidden_states.to(device),
|
||||
prompt_attention_mask.to(device),
|
||||
uncond_hidden_states.to(device),
|
||||
uncond_attention_mask.to(device),
|
||||
timesteps,
|
||||
guidance_scale,
|
||||
cfg_trunc_ratio,
|
||||
renorm_cfg,
|
||||
)
|
||||
|
||||
if args.offload:
|
||||
model.to("cpu")
|
||||
device_utils.clean_memory()
|
||||
ae.to(device)
|
||||
|
||||
#
|
||||
# 5. Decode latents
|
||||
#
|
||||
logger.info("Decoding image...")
|
||||
latents = latents / ae.scale_factor + ae.shift_factor
|
||||
with torch.no_grad():
|
||||
image = ae.decode(latents.to(ae_dtype))
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
image = (image * 255).round().astype("uint8")
|
||||
|
||||
#
|
||||
# 6. Save image
|
||||
#
|
||||
pil_image = Image.fromarray(image[0])
|
||||
output_dir = args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||
seed_suffix = f"_{seed}"
|
||||
output_path = os.path.join(output_dir, f"image_{ts_str}{seed_suffix}.png")
|
||||
pil_image.save(output_path)
|
||||
logger.info(f"Image saved to {output_path}")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Lumina DiT model path / Lumina DiTモデルのパス",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gemma2_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Gemma2 model path / Gemma2モデルのパス",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ae_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Autoencoder model path / Autoencoderモデルのパス",
|
||||
)
|
||||
parser.add_argument("--prompt", type=str, default="A beautiful sunset over the mountains", help="Prompt for image generation")
|
||||
parser.add_argument("--negative_prompt", type=str, default="", help="Negative prompt for image generation, default is empty")
|
||||
parser.add_argument("--output_dir", type=str, default="outputs", help="Output directory for generated images")
|
||||
parser.add_argument("--seed", type=int, default=None, help="Random seed")
|
||||
parser.add_argument("--steps", type=int, default=36, help="Number of inference steps")
|
||||
parser.add_argument("--guidance_scale", type=float, default=3.5, help="Guidance scale for classifier-free guidance")
|
||||
parser.add_argument("--image_width", type=int, default=1024, help="Image width")
|
||||
parser.add_argument("--image_height", type=int, default=1024, help="Image height")
|
||||
parser.add_argument("--dtype", type=str, default="bf16", help="Data type for model (bf16, fp16, float)")
|
||||
parser.add_argument("--gemma2_dtype", type=str, default="bf16", help="Data type for Gemma2 (bf16, fp16, float)")
|
||||
parser.add_argument("--ae_dtype", type=str, default="bf16", help="Data type for Autoencoder (bf16, fp16, float)")
|
||||
parser.add_argument("--device", type=str, default=None, help="Device to use (e.g., 'cuda:0')")
|
||||
parser.add_argument("--offload", action="store_true", help="Offload models to CPU to save VRAM")
|
||||
parser.add_argument("--system_prompt", type=str, default="", help="System prompt for Gemma2 model")
|
||||
parser.add_argument(
|
||||
"--gemma2_max_token_length",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Max token length for Gemma2 tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discrete_flow_shift",
|
||||
type=float,
|
||||
default=6.0,
|
||||
help="Shift value for FlowMatchEulerDiscreteScheduler",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cfg_trunc_ratio",
|
||||
type=float,
|
||||
default=0.25,
|
||||
help="TBD",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--renorm_cfg",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="TBD",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_flash_attn",
|
||||
action="store_true",
|
||||
help="Use flash attention for Lumina model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_sage_attn",
|
||||
action="store_true",
|
||||
help="Use sage attention for Lumina model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_weights",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="LoRA weights, each argument is a `path;multiplier` (semi-colon separated)",
|
||||
)
|
||||
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
|
||||
parser.add_argument(
|
||||
"--interactive",
|
||||
action="store_true",
|
||||
help="Enable interactive mode for generating multiple images / 対話モードで複数の画像を生成する",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info("Loading models...")
|
||||
device = get_preferred_device()
|
||||
if args.device:
|
||||
device = torch.device(args.device)
|
||||
|
||||
# Load Lumina DiT model
|
||||
model = lumina_util.load_lumina_model(
|
||||
args.pretrained_model_name_or_path,
|
||||
dtype=None, # Load in fp32 and then convert
|
||||
device="cpu",
|
||||
use_flash_attn=args.use_flash_attn,
|
||||
use_sage_attn=args.use_sage_attn,
|
||||
)
|
||||
|
||||
# Load Gemma2
|
||||
gemma2 = lumina_util.load_gemma2(args.gemma2_path, dtype=None, device="cpu")
|
||||
|
||||
# Load Autoencoder
|
||||
ae = lumina_util.load_ae(args.ae_path, dtype=None, device="cpu")
|
||||
|
||||
# LoRA
|
||||
lora_models = []
|
||||
for weights_file in args.lora_weights:
|
||||
if ";" in weights_file:
|
||||
weights_file, multiplier = weights_file.split(";")
|
||||
multiplier = float(multiplier)
|
||||
else:
|
||||
multiplier = 1.0
|
||||
|
||||
weights_sd = load_file(weights_file)
|
||||
lora_model, _ = lora_lumina.create_network_from_weights(multiplier, None, ae, [gemma2], model, weights_sd, True)
|
||||
|
||||
if args.merge_lora_weights:
|
||||
lora_model.merge_to([gemma2], model, weights_sd)
|
||||
else:
|
||||
lora_model.apply_to([gemma2], model)
|
||||
info = lora_model.load_state_dict(weights_sd, strict=True)
|
||||
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
|
||||
lora_model.to(device)
|
||||
lora_model.set_multiplier(multiplier)
|
||||
lora_model.eval()
|
||||
|
||||
lora_models.append(lora_model)
|
||||
|
||||
if not args.interactive:
|
||||
generate_image(
|
||||
model,
|
||||
gemma2,
|
||||
ae,
|
||||
args.prompt,
|
||||
args.system_prompt,
|
||||
args.seed,
|
||||
args.image_width,
|
||||
args.image_height,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
args.negative_prompt,
|
||||
args,
|
||||
args.cfg_trunc_ratio,
|
||||
args.renorm_cfg,
|
||||
)
|
||||
else:
|
||||
# Interactive mode loop
|
||||
image_width = args.image_width
|
||||
image_height = args.image_height
|
||||
steps = args.steps
|
||||
guidance_scale = args.guidance_scale
|
||||
cfg_trunc_ratio = args.cfg_trunc_ratio
|
||||
renorm_cfg = args.renorm_cfg
|
||||
|
||||
print("Entering interactive mode.")
|
||||
while True:
|
||||
print(
|
||||
"\nEnter prompt (or 'exit'). Options: --w <int> --h <int> --s <int> --d <int> --g <float> --n <str> --ctr <float> --rcfg <float> --m <m1,m2...>"
|
||||
)
|
||||
user_input = input()
|
||||
if user_input.lower() == "exit":
|
||||
break
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
# Parse options
|
||||
options = user_input.split("--")
|
||||
prompt = options[0].strip()
|
||||
|
||||
# Set defaults for each generation
|
||||
seed = None # New random seed each time unless specified
|
||||
negative_prompt = args.negative_prompt # Reset to default
|
||||
|
||||
for opt in options[1:]:
|
||||
try:
|
||||
opt = opt.strip()
|
||||
if not opt:
|
||||
continue
|
||||
|
||||
key, value = (opt.split(None, 1) + [""])[:2]
|
||||
|
||||
if key == "w":
|
||||
image_width = int(value)
|
||||
elif key == "h":
|
||||
image_height = int(value)
|
||||
elif key == "s":
|
||||
steps = int(value)
|
||||
elif key == "d":
|
||||
seed = int(value)
|
||||
elif key == "g":
|
||||
guidance_scale = float(value)
|
||||
elif key == "n":
|
||||
negative_prompt = value if value != "-" else ""
|
||||
elif key == "ctr":
|
||||
cfg_trunc_ratio = float(value)
|
||||
elif key == "rcfg":
|
||||
renorm_cfg = float(value)
|
||||
elif key == "m":
|
||||
multipliers = value.split(",")
|
||||
if len(multipliers) != len(lora_models):
|
||||
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
|
||||
continue
|
||||
for i, lora_model in enumerate(lora_models):
|
||||
lora_model.set_multiplier(float(multipliers[i].strip()))
|
||||
else:
|
||||
logger.warning(f"Unknown option: --{key}")
|
||||
|
||||
except (ValueError, IndexError) as e:
|
||||
logger.error(f"Invalid value for option --{key}: '{value}'. Error: {e}")
|
||||
|
||||
generate_image(
|
||||
model,
|
||||
gemma2,
|
||||
ae,
|
||||
prompt,
|
||||
args.system_prompt,
|
||||
seed,
|
||||
image_width,
|
||||
image_height,
|
||||
steps,
|
||||
guidance_scale,
|
||||
negative_prompt,
|
||||
args,
|
||||
cfg_trunc_ratio,
|
||||
renorm_cfg,
|
||||
)
|
||||
|
||||
logger.info("Done.")
|
||||
953
lumina_train.py
Normal file
953
lumina_train.py
Normal file
@@ -0,0 +1,953 @@
|
||||
# training with captions
|
||||
|
||||
# Swap blocks between CPU and GPU:
|
||||
# This implementation is inspired by and based on the work of 2kpr.
|
||||
# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading.
|
||||
# The original idea has been adapted and extended to fit the current project's needs.
|
||||
|
||||
# Key features:
|
||||
# - CPU offloading during forward and backward passes
|
||||
# - Use of fused optimizer and grad_hook for efficient gradient processing
|
||||
# - Per-block fused optimizer instances
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
from library import (
|
||||
deepspeed_utils,
|
||||
lumina_train_util,
|
||||
lumina_util,
|
||||
strategy_base,
|
||||
strategy_lumina,
|
||||
)
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import library.config_util as config_util
|
||||
|
||||
# import library.sdxl_train_util as sdxl_train_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments
|
||||
|
||||
|
||||
def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||
deepspeed_utils.prepare_deepspeed_args(args)
|
||||
setup_logging(args, reset=True)
|
||||
|
||||
# temporary: backward compatibility for deprecated options. remove in the future
|
||||
if not args.skip_cache_check:
|
||||
args.skip_cache_check = args.skip_latents_validity_check
|
||||
|
||||
# assert (
|
||||
# not args.weighted_captions
|
||||
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
|
||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||
logger.warning(
|
||||
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
|
||||
)
|
||||
args.cache_text_encoder_outputs = True
|
||||
|
||||
if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
|
||||
logger.warning(
|
||||
"cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります"
|
||||
)
|
||||
args.gradient_checkpointing = True
|
||||
|
||||
# assert (
|
||||
# args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||
# ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_dreambooth_method = args.in_json is None
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed) # 乱数系列を初期化する
|
||||
|
||||
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
||||
if args.cache_latents:
|
||||
latents_caching_strategy = strategy_lumina.LuminaLatentsCachingStrategy(
|
||||
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
|
||||
)
|
||||
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
||||
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(
|
||||
ConfigSanitizer(True, True, args.masked_loss, True)
|
||||
)
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
logger.info("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
||||
args.train_data_dir, args.reg_data_dir
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
logger.info("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group, val_dataset_group = (
|
||||
config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = (
|
||||
train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
)
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認
|
||||
|
||||
if args.debug_dataset:
|
||||
if args.cache_text_encoder_outputs:
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
|
||||
strategy_lumina.LuminaTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
args.text_encoder_batch_size,
|
||||
args.skip_cache_check,
|
||||
False,
|
||||
)
|
||||
)
|
||||
strategy_base.TokenizeStrategy.set_strategy(
|
||||
strategy_lumina.LuminaTokenizeStrategy(args.system_prompt)
|
||||
)
|
||||
|
||||
train_dataset_group.set_current_strategies()
|
||||
train_util.debug_dataset(train_dataset_group, True)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
logger.error(
|
||||
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
||||
)
|
||||
return
|
||||
|
||||
if cache_latents:
|
||||
assert (
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert (
|
||||
train_dataset_group.is_text_encoder_output_cacheable()
|
||||
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
|
||||
# load VAE for caching latents
|
||||
ae = None
|
||||
if cache_latents:
|
||||
ae = lumina_util.load_ae(
|
||||
args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors
|
||||
)
|
||||
ae.to(accelerator.device, dtype=weight_dtype)
|
||||
ae.requires_grad_(False)
|
||||
ae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(ae, accelerator)
|
||||
|
||||
ae.to("cpu") # if no sampling, vae can be deleted
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# prepare tokenize strategy
|
||||
if args.gemma2_max_token_length is None:
|
||||
gemma2_max_token_length = 256
|
||||
else:
|
||||
gemma2_max_token_length = args.gemma2_max_token_length
|
||||
|
||||
lumina_tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy(
|
||||
args.system_prompt, gemma2_max_token_length
|
||||
)
|
||||
strategy_base.TokenizeStrategy.set_strategy(lumina_tokenize_strategy)
|
||||
|
||||
# load gemma2 for caching text encoder outputs
|
||||
gemma2 = lumina_util.load_gemma2(
|
||||
args.gemma2, weight_dtype, "cpu", args.disable_mmap_load_safetensors
|
||||
)
|
||||
gemma2.eval()
|
||||
gemma2.requires_grad_(False)
|
||||
|
||||
text_encoding_strategy = strategy_lumina.LuminaTextEncodingStrategy()
|
||||
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
||||
|
||||
# cache text encoder outputs
|
||||
sample_prompts_te_outputs = None
|
||||
if args.cache_text_encoder_outputs:
|
||||
# Text Encodes are eval and no grad here
|
||||
gemma2.to(accelerator.device)
|
||||
|
||||
text_encoder_caching_strategy = (
|
||||
strategy_lumina.LuminaTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
args.text_encoder_batch_size,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
)
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
|
||||
text_encoder_caching_strategy
|
||||
)
|
||||
|
||||
with accelerator.autocast():
|
||||
train_dataset_group.new_cache_text_encoder_outputs([gemma2], accelerator)
|
||||
|
||||
# cache sample prompt's embeddings to free text encoder's memory
|
||||
if args.sample_prompts is not None:
|
||||
logger.info(
|
||||
f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}"
|
||||
)
|
||||
|
||||
text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = (
|
||||
strategy_base.TextEncodingStrategy.get_strategy()
|
||||
)
|
||||
|
||||
prompts = train_util.load_prompts(args.sample_prompts)
|
||||
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
for prompt_dict in prompts:
|
||||
for i, p in enumerate([
|
||||
prompt_dict.get("prompt", ""),
|
||||
prompt_dict.get("negative_prompt", ""),
|
||||
]):
|
||||
if p not in sample_prompts_te_outputs:
|
||||
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
||||
tokens_and_masks = lumina_tokenize_strategy.tokenize(p, i == 1) # i == 1 means negative prompt
|
||||
sample_prompts_te_outputs[p] = (
|
||||
text_encoding_strategy.encode_tokens(
|
||||
lumina_tokenize_strategy,
|
||||
[gemma2],
|
||||
tokens_and_masks,
|
||||
)
|
||||
)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# now we can delete Text Encoders to free memory
|
||||
gemma2 = None
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# load lumina
|
||||
nextdit = lumina_util.load_lumina_model(
|
||||
args.pretrained_model_name_or_path,
|
||||
loading_dtype,
|
||||
torch.device("cpu"),
|
||||
disable_mmap=args.disable_mmap_load_safetensors,
|
||||
use_flash_attn=args.use_flash_attn,
|
||||
)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
nextdit.enable_gradient_checkpointing(
|
||||
cpu_offload=args.cpu_offload_checkpointing
|
||||
)
|
||||
|
||||
nextdit.requires_grad_(True)
|
||||
|
||||
# block swap
|
||||
|
||||
# backward compatibility
|
||||
# if args.blocks_to_swap is None:
|
||||
# blocks_to_swap = args.double_blocks_to_swap or 0
|
||||
# if args.single_blocks_to_swap is not None:
|
||||
# blocks_to_swap += args.single_blocks_to_swap // 2
|
||||
# if blocks_to_swap > 0:
|
||||
# logger.warning(
|
||||
# "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead."
|
||||
# " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。"
|
||||
# )
|
||||
# logger.info(
|
||||
# f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}."
|
||||
# )
|
||||
# args.blocks_to_swap = blocks_to_swap
|
||||
# del blocks_to_swap
|
||||
|
||||
# is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||
# if is_swapping_blocks:
|
||||
# # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
|
||||
# # This idea is based on 2kpr's great work. Thank you!
|
||||
# logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||
# flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
||||
|
||||
if not cache_latents:
|
||||
# load VAE here if not cached
|
||||
ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu")
|
||||
ae.requires_grad_(False)
|
||||
ae.eval()
|
||||
ae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
training_models = []
|
||||
params_to_optimize = []
|
||||
training_models.append(nextdit)
|
||||
name_and_params = list(nextdit.named_parameters())
|
||||
# single param group for now
|
||||
params_to_optimize.append(
|
||||
{"params": [p for _, p in name_and_params], "lr": args.learning_rate}
|
||||
)
|
||||
param_names = [[n for n, _ in name_and_params]]
|
||||
|
||||
# calculate number of trainable parameters
|
||||
n_params = 0
|
||||
for group in params_to_optimize:
|
||||
for p in group["params"]:
|
||||
n_params += p.numel()
|
||||
|
||||
accelerator.print(f"number of trainable parameters: {n_params}")
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
|
||||
if args.blockwise_fused_optimizers:
|
||||
# fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
|
||||
# Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
|
||||
# This balances memory usage and management complexity.
|
||||
|
||||
# split params into groups. currently different learning rates are not supported
|
||||
grouped_params = []
|
||||
param_group = {}
|
||||
for group in params_to_optimize:
|
||||
named_parameters = list(nextdit.named_parameters())
|
||||
assert len(named_parameters) == len(
|
||||
group["params"]
|
||||
), "number of parameters does not match"
|
||||
for p, np in zip(group["params"], named_parameters):
|
||||
# determine target layer and block index for each parameter
|
||||
block_type = "other" # double, single or other
|
||||
if np[0].startswith("double_blocks"):
|
||||
block_index = int(np[0].split(".")[1])
|
||||
block_type = "double"
|
||||
elif np[0].startswith("single_blocks"):
|
||||
block_index = int(np[0].split(".")[1])
|
||||
block_type = "single"
|
||||
else:
|
||||
block_index = -1
|
||||
|
||||
param_group_key = (block_type, block_index)
|
||||
if param_group_key not in param_group:
|
||||
param_group[param_group_key] = []
|
||||
param_group[param_group_key].append(p)
|
||||
|
||||
block_types_and_indices = []
|
||||
for param_group_key, param_group in param_group.items():
|
||||
block_types_and_indices.append(param_group_key)
|
||||
grouped_params.append({"params": param_group, "lr": args.learning_rate})
|
||||
|
||||
num_params = 0
|
||||
for p in param_group:
|
||||
num_params += p.numel()
|
||||
accelerator.print(f"block {param_group_key}: {num_params} parameters")
|
||||
|
||||
# prepare optimizers for each group
|
||||
optimizers = []
|
||||
for group in grouped_params:
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
|
||||
optimizers.append(optimizer)
|
||||
optimizer = optimizers[0] # avoid error in the following code
|
||||
|
||||
logger.info(
|
||||
f"using {len(optimizers)} optimizers for blockwise fused optimizers"
|
||||
)
|
||||
|
||||
if train_util.is_schedulefree_optimizer(optimizers[0], args):
|
||||
raise ValueError(
|
||||
"Schedule-free optimizer is not supported with blockwise fused optimizers"
|
||||
)
|
||||
optimizer_train_fn = lambda: None # dummy function
|
||||
optimizer_eval_fn = lambda: None # dummy function
|
||||
else:
|
||||
_, _, optimizer = train_util.get_optimizer(
|
||||
args, trainable_params=params_to_optimize
|
||||
)
|
||||
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
|
||||
optimizer, args
|
||||
)
|
||||
|
||||
# prepare dataloader
|
||||
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||
# some strategies can be None
|
||||
train_dataset_group.set_current_strategies()
|
||||
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
n_workers = min(
|
||||
args.max_data_loader_n_workers, os.cpu_count()
|
||||
) # cpu_count or max_data_loader_n_workers
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader)
|
||||
/ accelerator.num_processes
|
||||
/ args.gradient_accumulation_steps
|
||||
)
|
||||
accelerator.print(
|
||||
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
||||
)
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# lr schedulerを用意する
|
||||
if args.blockwise_fused_optimizers:
|
||||
# prepare lr schedulers for each optimizer
|
||||
lr_schedulers = [
|
||||
train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
for optimizer in optimizers
|
||||
]
|
||||
lr_scheduler = lr_schedulers[0] # avoid error in the following code
|
||||
else:
|
||||
lr_scheduler = train_util.get_scheduler_fix(
|
||||
args, optimizer, accelerator.num_processes
|
||||
)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
||||
if args.full_fp16:
|
||||
assert (
|
||||
args.mixed_precision == "fp16"
|
||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
accelerator.print("enable full fp16 training.")
|
||||
nextdit.to(weight_dtype)
|
||||
if gemma2 is not None:
|
||||
gemma2.to(weight_dtype)
|
||||
elif args.full_bf16:
|
||||
assert (
|
||||
args.mixed_precision == "bf16"
|
||||
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
||||
accelerator.print("enable full bf16 training.")
|
||||
nextdit.to(weight_dtype)
|
||||
if gemma2 is not None:
|
||||
gemma2.to(weight_dtype)
|
||||
|
||||
# if we don't cache text encoder outputs, move them to device
|
||||
if not args.cache_text_encoder_outputs:
|
||||
gemma2.to(accelerator.device)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
if args.deepspeed:
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, nextdit=nextdit)
|
||||
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
training_models = [ds_model]
|
||||
|
||||
else:
|
||||
# accelerator does some magic
|
||||
# if we doesn't swap blocks, we can move the model to device
|
||||
nextdit = accelerator.prepare(
|
||||
nextdit, device_placement=[not is_swapping_blocks]
|
||||
)
|
||||
if is_swapping_blocks:
|
||||
accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(
|
||||
accelerator.device
|
||||
) # reduce peak memory usage
|
||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
# During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do.
|
||||
# -> But we think it's ok to patch accelerator even if deepspeed is enabled.
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
# resumeする
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
if args.fused_backward_pass:
|
||||
# use fused optimizer for backward pass: other optimizers will be supported in the future
|
||||
import library.adafactor_fused
|
||||
|
||||
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
||||
|
||||
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
|
||||
for parameter, param_name in zip(param_group["params"], param_name_group):
|
||||
if parameter.requires_grad:
|
||||
|
||||
def create_grad_hook(p_name, p_group):
|
||||
def grad_hook(tensor: torch.Tensor):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
||||
optimizer.step_param(tensor, p_group)
|
||||
tensor.grad = None
|
||||
|
||||
return grad_hook
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(
|
||||
create_grad_hook(param_name, param_group)
|
||||
)
|
||||
|
||||
elif args.blockwise_fused_optimizers:
|
||||
# prepare for additional optimizers and lr schedulers
|
||||
for i in range(1, len(optimizers)):
|
||||
optimizers[i] = accelerator.prepare(optimizers[i])
|
||||
lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
|
||||
|
||||
# counters are used to determine when to step the optimizer
|
||||
global optimizer_hooked_count
|
||||
global num_parameters_per_group
|
||||
global parameter_optimizer_map
|
||||
|
||||
optimizer_hooked_count = {}
|
||||
num_parameters_per_group = [0] * len(optimizers)
|
||||
parameter_optimizer_map = {}
|
||||
|
||||
for opt_idx, optimizer in enumerate(optimizers):
|
||||
for param_group in optimizer.param_groups:
|
||||
for parameter in param_group["params"]:
|
||||
if parameter.requires_grad:
|
||||
|
||||
def grad_hook(parameter: torch.Tensor):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(
|
||||
parameter, args.max_grad_norm
|
||||
)
|
||||
|
||||
i = parameter_optimizer_map[parameter]
|
||||
optimizer_hooked_count[i] += 1
|
||||
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
|
||||
optimizers[i].step()
|
||||
optimizers[i].zero_grad(set_to_none=True)
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(grad_hook)
|
||||
parameter_optimizer_map[parameter] = opt_idx
|
||||
num_parameters_per_group[opt_idx] += 1
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(
|
||||
len(train_dataloader) / args.gradient_accumulation_steps
|
||||
)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = (
|
||||
math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
)
|
||||
|
||||
# 学習する
|
||||
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
accelerator.print("running training / 学習開始")
|
||||
accelerator.print(
|
||||
f" num examples / サンプル数: {train_dataset_group.num_train_images}"
|
||||
)
|
||||
accelerator.print(
|
||||
f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}"
|
||||
)
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
accelerator.print(
|
||||
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
||||
)
|
||||
# accelerator.print(
|
||||
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
||||
# )
|
||||
accelerator.print(
|
||||
f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}"
|
||||
)
|
||||
accelerator.print(
|
||||
f" total optimization steps / 学習ステップ数: {args.max_train_steps}"
|
||||
)
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(args.max_train_steps),
|
||||
smoothing=0,
|
||||
disable=not accelerator.is_local_main_process,
|
||||
desc="steps",
|
||||
)
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
num_train_timesteps=1000, shift=args.discrete_flow_shift
|
||||
)
|
||||
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
|
||||
config=train_util.get_sanitized_config_or_none(args),
|
||||
init_kwargs=init_kwargs,
|
||||
)
|
||||
|
||||
if is_swapping_blocks:
|
||||
accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward()
|
||||
|
||||
# For --sample_at_first
|
||||
optimizer_eval_fn()
|
||||
lumina_train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
0,
|
||||
global_step,
|
||||
nextdit,
|
||||
ae,
|
||||
gemma2,
|
||||
sample_prompts_te_outputs,
|
||||
)
|
||||
optimizer_train_fn()
|
||||
if len(accelerator.trackers) > 0:
|
||||
# log empty object to commit the sample images to wandb
|
||||
accelerator.log({}, step=0)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
epoch = 0 # avoid error when max_train_steps is 0
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
for m in training_models:
|
||||
m.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
|
||||
if args.blockwise_fused_optimizers:
|
||||
optimizer_hooked_count = {
|
||||
i: 0 for i in range(len(optimizers))
|
||||
} # reset counter for each step
|
||||
|
||||
with accelerator.accumulate(*training_models):
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(
|
||||
accelerator.device, dtype=weight_dtype
|
||||
)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
# encode images to latents. images are [-1, 1]
|
||||
latents = ae.encode(batch["images"].to(ae.dtype)).to(
|
||||
accelerator.device, dtype=weight_dtype
|
||||
)
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encoder_conds = text_encoder_outputs_list
|
||||
else:
|
||||
# not cached or training, so get from text encoders
|
||||
tokens_and_masks = batch["input_ids_list"]
|
||||
with torch.no_grad():
|
||||
input_ids = [
|
||||
ids.to(accelerator.device)
|
||||
for ids in batch["input_ids_list"]
|
||||
]
|
||||
text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
lumina_tokenize_strategy,
|
||||
[gemma2],
|
||||
input_ids,
|
||||
)
|
||||
if args.full_fp16:
|
||||
text_encoder_conds = [
|
||||
c.to(weight_dtype) for c in text_encoder_conds
|
||||
]
|
||||
|
||||
# TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
# get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = (
|
||||
lumina_train_util.get_noisy_model_input_and_timesteps(
|
||||
args,
|
||||
noise_scheduler_copy,
|
||||
latents,
|
||||
noise,
|
||||
accelerator.device,
|
||||
weight_dtype,
|
||||
)
|
||||
)
|
||||
# call model
|
||||
gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds
|
||||
|
||||
with accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
model_pred = nextdit(
|
||||
x=img, # image latents (B, C, H, W)
|
||||
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
||||
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
|
||||
cap_mask=gemma2_attn_mask.to(
|
||||
dtype=torch.int32
|
||||
), # Gemma2的attention mask
|
||||
)
|
||||
# apply model prediction type
|
||||
model_pred, weighting = lumina_train_util.apply_model_prediction_type(
|
||||
args, model_pred, noisy_model_input, sigmas
|
||||
)
|
||||
|
||||
# flow matching loss: this is different from SD3
|
||||
target = noise - latents
|
||||
|
||||
# calculate loss
|
||||
huber_c = train_util.get_huber_threshold_if_needed(
|
||||
args, timesteps, noise_scheduler
|
||||
)
|
||||
loss = train_util.conditional_loss(
|
||||
model_pred.float(), target.float(), args.loss_type, "none", huber_c
|
||||
)
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
if args.masked_loss or (
|
||||
"alpha_masks" in batch and batch["alpha_masks"] is not None
|
||||
):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
# backward
|
||||
accelerator.backward(loss)
|
||||
|
||||
if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = []
|
||||
for m in training_models:
|
||||
params_to_clip.extend(m.parameters())
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
else:
|
||||
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
|
||||
lr_scheduler.step()
|
||||
if args.blockwise_fused_optimizers:
|
||||
for i in range(1, len(optimizers)):
|
||||
lr_schedulers[i].step()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
optimizer_eval_fn()
|
||||
lumina_train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
None,
|
||||
global_step,
|
||||
nextdit,
|
||||
ae,
|
||||
gemma2,
|
||||
sample_prompts_te_outputs,
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if (
|
||||
args.save_every_n_steps is not None
|
||||
and global_step % args.save_every_n_steps == 0
|
||||
):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
lumina_train_util.save_lumina_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
False,
|
||||
accelerator,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(nextdit),
|
||||
)
|
||||
optimizer_train_fn()
|
||||
|
||||
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
||||
if len(accelerator.trackers) > 0:
|
||||
logs = {"loss": current_loss}
|
||||
train_util.append_lr_to_logs(
|
||||
logs, lr_scheduler, args.optimizer_type, including_unet=True
|
||||
)
|
||||
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if len(accelerator.trackers) > 0:
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
optimizer_eval_fn()
|
||||
if args.save_every_n_epochs is not None:
|
||||
if accelerator.is_main_process:
|
||||
lumina_train_util.save_lumina_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
True,
|
||||
accelerator,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(nextdit),
|
||||
)
|
||||
|
||||
lumina_train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
epoch + 1,
|
||||
global_step,
|
||||
nextdit,
|
||||
ae,
|
||||
gemma2,
|
||||
sample_prompts_te_outputs,
|
||||
)
|
||||
optimizer_train_fn()
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
# if is_main_process:
|
||||
nextdit = accelerator.unwrap_model(nextdit)
|
||||
|
||||
accelerator.end_training()
|
||||
optimizer_eval_fn()
|
||||
|
||||
if args.save_state or args.save_state_on_train_end:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
if is_main_process:
|
||||
lumina_train_util.save_lumina_model_on_train_end(
|
||||
args, save_dtype, epoch, global_step, nextdit
|
||||
)
|
||||
logger.info("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser) # TODO split this
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_masked_loss_arguments(parser)
|
||||
deepspeed_utils.add_deepspeed_arguments(parser)
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
add_custom_train_arguments(parser) # TODO remove this from here
|
||||
train_util.add_dit_training_arguments(parser)
|
||||
lumina_train_util.add_lumina_train_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--mem_eff_save",
|
||||
action="store_true",
|
||||
help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fused_optimizer_groups",
|
||||
type=int,
|
||||
default=None,
|
||||
help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--blockwise_fused_optimizers",
|
||||
action="store_true",
|
||||
help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_latents_validity_check",
|
||||
action="store_true",
|
||||
help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpu_offload_checkpointing",
|
||||
action="store_true",
|
||||
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
train_util.verify_command_line_training_args(args)
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
383
lumina_train_network.py
Normal file
383
lumina_train_network.py
Normal file
@@ -0,0 +1,383 @@
|
||||
import argparse
|
||||
import copy
|
||||
from typing import Any, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from library.device_utils import clean_memory_on_device, init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch import Tensor
|
||||
from accelerate import Accelerator
|
||||
|
||||
|
||||
import train_network
|
||||
from library import (
|
||||
lumina_models,
|
||||
lumina_util,
|
||||
lumina_train_util,
|
||||
sd3_train_utils,
|
||||
strategy_base,
|
||||
strategy_lumina,
|
||||
train_util,
|
||||
)
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sample_prompts_te_outputs = None
|
||||
self.is_swapping_blocks: bool = False
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group, val_dataset_group):
|
||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
|
||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||
logger.warning("Enabling cache_text_encoder_outputs due to disk caching")
|
||||
args.cache_text_encoder_outputs = True
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
self.train_gemma2 = not args.network_train_unet_only
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
loading_dtype = None if args.fp8_base else weight_dtype
|
||||
|
||||
model = lumina_util.load_lumina_model(
|
||||
args.pretrained_model_name_or_path,
|
||||
loading_dtype,
|
||||
torch.device("cpu"),
|
||||
disable_mmap=args.disable_mmap_load_safetensors,
|
||||
use_flash_attn=args.use_flash_attn,
|
||||
use_sage_attn=args.use_sage_attn,
|
||||
)
|
||||
|
||||
if args.fp8_base:
|
||||
# check dtype of model
|
||||
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
|
||||
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
|
||||
elif model.dtype == torch.float8_e4m3fn:
|
||||
logger.info("Loaded fp8 Lumina 2 model")
|
||||
else:
|
||||
logger.info(
|
||||
"Cast Lumina 2 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
|
||||
" / Lumina 2モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
|
||||
)
|
||||
model.to(torch.float8_e4m3fn)
|
||||
|
||||
if args.blocks_to_swap:
|
||||
logger.info(f"Lumina 2: Enabling block swap: {args.blocks_to_swap}")
|
||||
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
||||
self.is_swapping_blocks = True
|
||||
|
||||
gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu")
|
||||
gemma2.eval()
|
||||
ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu")
|
||||
|
||||
return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model
|
||||
|
||||
def get_tokenize_strategy(self, args):
|
||||
return strategy_lumina.LuminaTokenizeStrategy(args.system_prompt, args.gemma2_max_token_length, args.tokenizer_cache_dir)
|
||||
|
||||
def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy):
|
||||
return [tokenize_strategy.tokenizer]
|
||||
|
||||
def get_latents_caching_strategy(self, args):
|
||||
return strategy_lumina.LuminaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
|
||||
|
||||
def get_text_encoding_strategy(self, args):
|
||||
return strategy_lumina.LuminaTextEncodingStrategy()
|
||||
|
||||
def get_text_encoders_train_flags(self, args, text_encoders):
|
||||
return [self.train_gemma2]
|
||||
|
||||
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||
if args.cache_text_encoder_outputs:
|
||||
# if the text encoders is trained, we need tokenization, so is_partial is True
|
||||
return strategy_lumina.LuminaTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
args.text_encoder_batch_size,
|
||||
args.skip_cache_check,
|
||||
is_partial=self.train_gemma2,
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
def cache_text_encoder_outputs_if_needed(
|
||||
self,
|
||||
args,
|
||||
accelerator: Accelerator,
|
||||
unet,
|
||||
vae,
|
||||
text_encoders,
|
||||
dataset,
|
||||
weight_dtype,
|
||||
):
|
||||
if args.cache_text_encoder_outputs:
|
||||
if not args.lowram:
|
||||
# メモリ消費を減らす
|
||||
logger.info("move vae and unet to cpu to save memory")
|
||||
org_vae_device = vae.device
|
||||
org_unet_device = unet.device
|
||||
vae.to("cpu")
|
||||
unet.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
|
||||
logger.info("move text encoders to gpu")
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
|
||||
|
||||
if text_encoders[0].dtype == torch.float8_e4m3fn:
|
||||
# if we load fp8 weights, the model is already fp8, so we use it as is
|
||||
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
|
||||
else:
|
||||
# otherwise, we need to convert it to target dtype
|
||||
text_encoders[0].to(weight_dtype)
|
||||
|
||||
with accelerator.autocast():
|
||||
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
|
||||
|
||||
# cache sample prompts
|
||||
if args.sample_prompts is not None:
|
||||
logger.info(f"cache Text Encoder outputs for sample prompts: {args.sample_prompts}")
|
||||
|
||||
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
|
||||
assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
|
||||
assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
|
||||
|
||||
sample_prompts = train_util.load_prompts(args.sample_prompts)
|
||||
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
for prompt_dict in sample_prompts:
|
||||
prompts = [
|
||||
prompt_dict.get("prompt", ""),
|
||||
prompt_dict.get("negative_prompt", ""),
|
||||
]
|
||||
for i, prompt in enumerate(prompts):
|
||||
if prompt in sample_prompts_te_outputs:
|
||||
continue
|
||||
|
||||
logger.info(f"cache Text Encoder outputs for prompt: {prompt}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt, i == 1) # i == 1 means negative prompt
|
||||
sample_prompts_te_outputs[prompt] = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
text_encoders,
|
||||
tokens_and_masks,
|
||||
)
|
||||
|
||||
self.sample_prompts_te_outputs = sample_prompts_te_outputs
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# move back to cpu
|
||||
if not self.is_train_text_encoder(args):
|
||||
logger.info("move Gemma 2 back to cpu")
|
||||
text_encoders[0].to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
if not args.lowram:
|
||||
logger.info("move vae and unet back to original device")
|
||||
vae.to(org_vae_device)
|
||||
unet.to(org_unet_device)
|
||||
else:
|
||||
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
def sample_images(
|
||||
self,
|
||||
accelerator,
|
||||
args,
|
||||
epoch,
|
||||
global_step,
|
||||
device,
|
||||
vae,
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
lumina,
|
||||
):
|
||||
lumina_train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
epoch,
|
||||
global_step,
|
||||
lumina,
|
||||
vae,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoder),
|
||||
self.sample_prompts_te_outputs,
|
||||
)
|
||||
|
||||
# Remaining methods maintain similar structure to flux implementation
|
||||
# with Lumina-specific model calls and strategies
|
||||
|
||||
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
||||
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
||||
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
return noise_scheduler
|
||||
|
||||
def encode_images_to_latents(self, args, vae, images):
|
||||
return vae.encode(images)
|
||||
|
||||
# not sure, they use same flux vae
|
||||
def shift_scale_latents(self, args, latents):
|
||||
return latents
|
||||
|
||||
def get_noise_pred_and_target(
|
||||
self,
|
||||
args,
|
||||
accelerator: Accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
text_encoder_conds: Tuple[Tensor, Tensor, Tensor], # (hidden_states, input_ids, attention_masks)
|
||||
dit: lumina_models.NextDiT,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True,
|
||||
):
|
||||
assert isinstance(noise_scheduler, sd3_train_utils.FlowMatchEulerDiscreteScheduler)
|
||||
noise = torch.randn_like(latents)
|
||||
# get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = lumina_train_util.get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
||||
)
|
||||
|
||||
# ensure the hidden state will require grad
|
||||
if args.gradient_checkpointing:
|
||||
noisy_model_input.requires_grad_(True)
|
||||
for t in text_encoder_conds:
|
||||
if t is not None and t.dtype.is_floating_point:
|
||||
t.requires_grad_(True)
|
||||
|
||||
# Unpack Gemma2 outputs
|
||||
gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds
|
||||
|
||||
def call_dit(img, gemma2_hidden_states, gemma2_attn_mask, timesteps):
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
# NextDiT forward expects (x, t, cap_feats, cap_mask)
|
||||
model_pred = dit(
|
||||
x=img, # image latents (B, C, H, W)
|
||||
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
||||
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
|
||||
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask
|
||||
)
|
||||
return model_pred
|
||||
|
||||
model_pred = call_dit(
|
||||
img=noisy_model_input,
|
||||
gemma2_hidden_states=gemma2_hidden_states,
|
||||
gemma2_attn_mask=gemma2_attn_mask,
|
||||
timesteps=timesteps,
|
||||
)
|
||||
|
||||
# apply model prediction type
|
||||
model_pred, weighting = lumina_train_util.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
||||
|
||||
# flow matching loss
|
||||
target = latents - noise
|
||||
|
||||
# differential output preservation
|
||||
if "custom_attributes" in batch:
|
||||
diff_output_pr_indices = []
|
||||
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
||||
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
|
||||
diff_output_pr_indices.append(i)
|
||||
|
||||
if len(diff_output_pr_indices) > 0:
|
||||
network.set_multiplier(0.0)
|
||||
with torch.no_grad():
|
||||
model_pred_prior = call_dit(
|
||||
img=noisy_model_input[diff_output_pr_indices],
|
||||
gemma2_hidden_states=gemma2_hidden_states[diff_output_pr_indices],
|
||||
timesteps=timesteps[diff_output_pr_indices],
|
||||
gemma2_attn_mask=(gemma2_attn_mask[diff_output_pr_indices]),
|
||||
)
|
||||
network.set_multiplier(1.0)
|
||||
|
||||
# model_pred_prior = lumina_util.unpack_latents(
|
||||
# model_pred_prior, packed_latent_height, packed_latent_width
|
||||
# )
|
||||
model_pred_prior, _ = lumina_train_util.apply_model_prediction_type(
|
||||
args,
|
||||
model_pred_prior,
|
||||
noisy_model_input[diff_output_pr_indices],
|
||||
sigmas[diff_output_pr_indices] if sigmas is not None else None,
|
||||
)
|
||||
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
||||
|
||||
return model_pred, target, timesteps, weighting
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
return loss
|
||||
|
||||
def get_sai_model_spec(self, args):
|
||||
return train_util.get_sai_model_spec(None, args, False, True, False, lumina="lumina2")
|
||||
|
||||
def update_metadata(self, metadata, args):
|
||||
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
||||
metadata["ss_logit_mean"] = args.logit_mean
|
||||
metadata["ss_logit_std"] = args.logit_std
|
||||
metadata["ss_mode_scale"] = args.mode_scale
|
||||
metadata["ss_timestep_sampling"] = args.timestep_sampling
|
||||
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
|
||||
metadata["ss_model_prediction_type"] = args.model_prediction_type
|
||||
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
||||
|
||||
def is_text_encoder_not_needed_for_training(self, args):
|
||||
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
|
||||
|
||||
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
|
||||
text_encoder.embed_tokens.requires_grad_(True)
|
||||
|
||||
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
|
||||
logger.info(f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
|
||||
text_encoder.to(te_weight_dtype) # fp8
|
||||
text_encoder.embed_tokens.to(dtype=weight_dtype)
|
||||
|
||||
def prepare_unet_with_accelerator(
|
||||
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
||||
) -> torch.nn.Module:
|
||||
if not self.is_swapping_blocks:
|
||||
return super().prepare_unet_with_accelerator(args, accelerator, unet)
|
||||
|
||||
# if we doesn't swap blocks, we can move the model to device
|
||||
nextdit = unet
|
||||
assert isinstance(nextdit, lumina_models.NextDiT)
|
||||
nextdit = accelerator.prepare(nextdit, device_placement=[not self.is_swapping_blocks])
|
||||
accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
|
||||
accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward()
|
||||
|
||||
return nextdit
|
||||
|
||||
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||
if self.is_swapping_blocks:
|
||||
# prepare for next forward: because backward pass is not called, we need to prepare it here
|
||||
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = train_network.setup_parser()
|
||||
train_util.add_dit_training_arguments(parser)
|
||||
lumina_train_util.add_lumina_train_arguments(parser)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
args = parser.parse_args()
|
||||
train_util.verify_command_line_training_args(args)
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
trainer = LuminaNetworkTrainer()
|
||||
trainer.train(args)
|
||||
@@ -955,26 +955,26 @@ class LoRANetwork(torch.nn.Module):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.update_grad_norms()
|
||||
|
||||
def grad_norms(self) -> Tensor:
|
||||
def grad_norms(self) -> Tensor | None:
|
||||
grad_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
|
||||
grad_norms.append(lora.grad_norms.mean(dim=0))
|
||||
return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([])
|
||||
return torch.stack(grad_norms) if len(grad_norms) > 0 else None
|
||||
|
||||
def weight_norms(self) -> Tensor:
|
||||
def weight_norms(self) -> Tensor | None:
|
||||
weight_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
|
||||
weight_norms.append(lora.weight_norms.mean(dim=0))
|
||||
return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([])
|
||||
return torch.stack(weight_norms) if len(weight_norms) > 0 else None
|
||||
|
||||
def combined_weight_norms(self) -> Tensor:
|
||||
def combined_weight_norms(self) -> Tensor | None:
|
||||
combined_weight_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
|
||||
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
|
||||
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
|
||||
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None
|
||||
|
||||
|
||||
def load_weights(self, file):
|
||||
|
||||
1038
networks/lora_lumina.py
Normal file
1038
networks/lora_lumina.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -6,3 +6,4 @@ filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::UserWarning
|
||||
ignore::FutureWarning
|
||||
pythonpath = .
|
||||
|
||||
@@ -640,23 +640,14 @@ def train(args):
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
else:
|
||||
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
|
||||
else:
|
||||
chunks = [
|
||||
batch["images"][i : i + args.vae_batch_size]
|
||||
for i in range(0, len(batch["images"]), args.vae_batch_size)
|
||||
]
|
||||
list_latents = []
|
||||
for chunk in chunks:
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
list_latents.append(
|
||||
vae.encode(chunk.to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
|
||||
)
|
||||
latents = torch.cat(list_latents, dim=0)
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
|
||||
295
tests/library/test_lumina_models.py
Normal file
295
tests/library/test_lumina_models.py
Normal file
@@ -0,0 +1,295 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from library.lumina_models import (
|
||||
LuminaParams,
|
||||
to_cuda,
|
||||
to_cpu,
|
||||
RopeEmbedder,
|
||||
TimestepEmbedder,
|
||||
modulate,
|
||||
NextDiT,
|
||||
)
|
||||
|
||||
cuda_required = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
|
||||
|
||||
def test_lumina_params():
|
||||
# Test default configuration
|
||||
default_params = LuminaParams()
|
||||
assert default_params.patch_size == 2
|
||||
assert default_params.in_channels == 4
|
||||
assert default_params.axes_dims == [36, 36, 36]
|
||||
assert default_params.axes_lens == [300, 512, 512]
|
||||
|
||||
# Test 2B config
|
||||
config_2b = LuminaParams.get_2b_config()
|
||||
assert config_2b.dim == 2304
|
||||
assert config_2b.in_channels == 16
|
||||
assert config_2b.n_layers == 26
|
||||
assert config_2b.n_heads == 24
|
||||
assert config_2b.cap_feat_dim == 2304
|
||||
|
||||
# Test 7B config
|
||||
config_7b = LuminaParams.get_7b_config()
|
||||
assert config_7b.dim == 4096
|
||||
assert config_7b.n_layers == 32
|
||||
assert config_7b.n_heads == 32
|
||||
assert config_7b.axes_dims == [64, 64, 64]
|
||||
|
||||
|
||||
@cuda_required
|
||||
def test_to_cuda_to_cpu():
|
||||
# Test tensor conversion
|
||||
x = torch.tensor([1, 2, 3])
|
||||
x_cuda = to_cuda(x)
|
||||
x_cpu = to_cpu(x_cuda)
|
||||
assert x.cpu().tolist() == x_cpu.tolist()
|
||||
|
||||
# Test list conversion
|
||||
list_data = [torch.tensor([1]), torch.tensor([2])]
|
||||
list_cuda = to_cuda(list_data)
|
||||
assert all(tensor.device.type == "cuda" for tensor in list_cuda)
|
||||
|
||||
list_cpu = to_cpu(list_cuda)
|
||||
assert all(not tensor.device.type == "cuda" for tensor in list_cpu)
|
||||
|
||||
# Test dict conversion
|
||||
dict_data = {"a": torch.tensor([1]), "b": torch.tensor([2])}
|
||||
dict_cuda = to_cuda(dict_data)
|
||||
assert all(tensor.device.type == "cuda" for tensor in dict_cuda.values())
|
||||
|
||||
dict_cpu = to_cpu(dict_cuda)
|
||||
assert all(not tensor.device.type == "cuda" for tensor in dict_cpu.values())
|
||||
|
||||
|
||||
def test_timestep_embedder():
|
||||
# Test initialization
|
||||
hidden_size = 256
|
||||
freq_emb_size = 128
|
||||
embedder = TimestepEmbedder(hidden_size, freq_emb_size)
|
||||
assert embedder.frequency_embedding_size == freq_emb_size
|
||||
|
||||
# Test timestep embedding
|
||||
t = torch.tensor([0.5, 1.0, 2.0])
|
||||
emb_dim = freq_emb_size
|
||||
embeddings = TimestepEmbedder.timestep_embedding(t, emb_dim)
|
||||
|
||||
assert embeddings.shape == (3, emb_dim)
|
||||
assert embeddings.dtype == torch.float32
|
||||
|
||||
# Ensure embeddings are unique for different input times
|
||||
assert not torch.allclose(embeddings[0], embeddings[1])
|
||||
|
||||
# Test forward pass
|
||||
t_emb = embedder(t)
|
||||
assert t_emb.shape == (3, hidden_size)
|
||||
|
||||
|
||||
def test_rope_embedder_simple():
|
||||
rope_embedder = RopeEmbedder()
|
||||
batch_size, seq_len = 2, 10
|
||||
|
||||
# Create position_ids with valid ranges for each axis
|
||||
position_ids = torch.stack(
|
||||
[
|
||||
torch.zeros(batch_size, seq_len, dtype=torch.int64), # First axis: only 0 is valid
|
||||
torch.randint(0, 512, (batch_size, seq_len), dtype=torch.int64), # Second axis: 0-511
|
||||
torch.randint(0, 512, (batch_size, seq_len), dtype=torch.int64), # Third axis: 0-511
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
freqs_cis = rope_embedder(position_ids)
|
||||
# RoPE embeddings work in pairs, so output dimension is half of total axes_dims
|
||||
expected_dim = sum(rope_embedder.axes_dims) // 2 # 128 // 2 = 64
|
||||
assert freqs_cis.shape == (batch_size, seq_len, expected_dim)
|
||||
|
||||
|
||||
def test_modulate():
|
||||
# Test modulation with different scales
|
||||
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
|
||||
scale = torch.tensor([1.5, 2.0])
|
||||
|
||||
modulated_x = modulate(x, scale)
|
||||
|
||||
# Check that modulation scales correctly
|
||||
# The function does x * (1 + scale), so:
|
||||
# For scale [1.5, 2.0], (1 + scale) = [2.5, 3.0]
|
||||
expected_x = torch.tensor([[2.5 * 1.0, 2.5 * 2.0], [3.0 * 3.0, 3.0 * 4.0]])
|
||||
# Which equals: [[2.5, 5.0], [9.0, 12.0]]
|
||||
|
||||
assert torch.allclose(modulated_x, expected_x)
|
||||
|
||||
|
||||
def test_nextdit_parameter_count_optimized():
|
||||
# The constraint is: (dim // n_heads) == sum(axes_dims)
|
||||
# So for dim=120, n_heads=4: 120//4 = 30, so sum(axes_dims) must = 30
|
||||
model_small = NextDiT(
|
||||
patch_size=2,
|
||||
in_channels=4, # Smaller
|
||||
dim=120, # 120 // 4 = 30
|
||||
n_layers=2, # Much fewer layers
|
||||
n_heads=4, # Fewer heads
|
||||
n_kv_heads=2,
|
||||
axes_dims=[10, 10, 10], # sum = 30
|
||||
axes_lens=[10, 32, 32], # Smaller
|
||||
)
|
||||
param_count_small = model_small.parameter_count()
|
||||
assert param_count_small > 0
|
||||
|
||||
# For dim=192, n_heads=6: 192//6 = 32, so sum(axes_dims) must = 32
|
||||
model_medium = NextDiT(
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
dim=192, # 192 // 6 = 32
|
||||
n_layers=4, # More layers
|
||||
n_heads=6,
|
||||
n_kv_heads=3,
|
||||
axes_dims=[10, 11, 11], # sum = 32
|
||||
axes_lens=[10, 32, 32],
|
||||
)
|
||||
param_count_medium = model_medium.parameter_count()
|
||||
assert param_count_medium > param_count_small
|
||||
print(f"Small model: {param_count_small:,} parameters")
|
||||
print(f"Medium model: {param_count_medium:,} parameters")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_precompute_freqs_cis():
|
||||
# Test precompute_freqs_cis
|
||||
dim = [16, 56, 56]
|
||||
end = [1, 512, 512]
|
||||
theta = 10000.0
|
||||
|
||||
freqs_cis = NextDiT.precompute_freqs_cis(dim, end, theta)
|
||||
|
||||
# Check number of frequency tensors
|
||||
assert len(freqs_cis) == len(dim)
|
||||
|
||||
# Check each frequency tensor
|
||||
for i, (d, e) in enumerate(zip(dim, end)):
|
||||
assert freqs_cis[i].shape == (e, d // 2)
|
||||
assert freqs_cis[i].dtype == torch.complex128
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_nextdit_patchify_and_embed():
|
||||
"""Test the patchify_and_embed method which is crucial for training"""
|
||||
# Create a small NextDiT model for testing
|
||||
# The constraint is: (dim // n_heads) == sum(axes_dims)
|
||||
# For dim=120, n_heads=4: 120//4 = 30, so sum(axes_dims) must = 30
|
||||
model = NextDiT(
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
dim=120, # 120 // 4 = 30
|
||||
n_layers=1, # Minimal layers for faster testing
|
||||
n_refiner_layers=1, # Minimal refiner layers
|
||||
n_heads=4,
|
||||
n_kv_heads=2,
|
||||
axes_dims=[10, 10, 10], # sum = 30
|
||||
axes_lens=[10, 32, 32],
|
||||
cap_feat_dim=120, # Match dim for consistency
|
||||
)
|
||||
|
||||
# Prepare test inputs
|
||||
batch_size = 2
|
||||
height, width = 64, 64 # Must be divisible by patch_size (2)
|
||||
caption_seq_len = 8
|
||||
|
||||
# Create mock inputs
|
||||
x = torch.randn(batch_size, 4, height, width) # Image latents
|
||||
cap_feats = torch.randn(batch_size, caption_seq_len, 120) # Caption features
|
||||
cap_mask = torch.ones(batch_size, caption_seq_len, dtype=torch.bool) # All valid tokens
|
||||
# Make second batch have shorter caption
|
||||
cap_mask[1, 6:] = False # Only first 6 tokens are valid for second batch
|
||||
t = torch.randn(batch_size, 120) # Timestep embeddings
|
||||
|
||||
# Call patchify_and_embed
|
||||
joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = model.patchify_and_embed(
|
||||
x, cap_feats, cap_mask, t
|
||||
)
|
||||
|
||||
# Validate outputs
|
||||
image_seq_len = (height // 2) * (width // 2) # patch_size = 2
|
||||
expected_seq_lengths = [caption_seq_len + image_seq_len, 6 + image_seq_len] # Second batch has shorter caption
|
||||
max_seq_len = max(expected_seq_lengths)
|
||||
|
||||
# Check joint hidden states shape
|
||||
assert joint_hidden_states.shape == (batch_size, max_seq_len, 120)
|
||||
assert joint_hidden_states.dtype == torch.float32
|
||||
|
||||
# Check attention mask shape and values
|
||||
assert attention_mask.shape == (batch_size, max_seq_len)
|
||||
assert attention_mask.dtype == torch.bool
|
||||
# First batch should have all positions valid up to its sequence length
|
||||
assert torch.all(attention_mask[0, : expected_seq_lengths[0]])
|
||||
assert torch.all(~attention_mask[0, expected_seq_lengths[0] :])
|
||||
# Second batch should have all positions valid up to its sequence length
|
||||
assert torch.all(attention_mask[1, : expected_seq_lengths[1]])
|
||||
assert torch.all(~attention_mask[1, expected_seq_lengths[1] :])
|
||||
|
||||
# Check freqs_cis shape
|
||||
assert freqs_cis.shape == (batch_size, max_seq_len, sum(model.axes_dims) // 2)
|
||||
|
||||
# Check effective caption lengths
|
||||
assert l_effective_cap_len == [caption_seq_len, 6]
|
||||
|
||||
# Check sequence lengths
|
||||
assert seq_lengths == expected_seq_lengths
|
||||
|
||||
# Validate that the joint hidden states contain non-zero values where attention mask is True
|
||||
for i in range(batch_size):
|
||||
valid_positions = attention_mask[i]
|
||||
# Check that valid positions have meaningful data (not all zeros)
|
||||
valid_data = joint_hidden_states[i][valid_positions]
|
||||
assert not torch.allclose(valid_data, torch.zeros_like(valid_data))
|
||||
|
||||
# Check that invalid positions are zeros
|
||||
if valid_positions.sum() < max_seq_len:
|
||||
invalid_data = joint_hidden_states[i][~valid_positions]
|
||||
assert torch.allclose(invalid_data, torch.zeros_like(invalid_data))
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_nextdit_patchify_and_embed_edge_cases():
|
||||
"""Test edge cases for patchify_and_embed"""
|
||||
# Create minimal model
|
||||
model = NextDiT(
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
dim=60, # 60 // 3 = 20
|
||||
n_layers=1,
|
||||
n_refiner_layers=1,
|
||||
n_heads=3,
|
||||
n_kv_heads=1,
|
||||
axes_dims=[8, 6, 6], # sum = 20
|
||||
axes_lens=[10, 16, 16],
|
||||
cap_feat_dim=60,
|
||||
)
|
||||
|
||||
# Test with empty captions (all masked)
|
||||
batch_size = 1
|
||||
height, width = 32, 32
|
||||
caption_seq_len = 4
|
||||
|
||||
x = torch.randn(batch_size, 4, height, width)
|
||||
cap_feats = torch.randn(batch_size, caption_seq_len, 60)
|
||||
cap_mask = torch.zeros(batch_size, caption_seq_len, dtype=torch.bool) # All tokens masked
|
||||
t = torch.randn(batch_size, 60)
|
||||
|
||||
joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = model.patchify_and_embed(
|
||||
x, cap_feats, cap_mask, t
|
||||
)
|
||||
|
||||
# With all captions masked, effective length should be 0
|
||||
assert l_effective_cap_len == [0]
|
||||
|
||||
# Sequence length should just be the image sequence length
|
||||
image_seq_len = (height // 2) * (width // 2)
|
||||
assert seq_lengths == [image_seq_len]
|
||||
|
||||
# Joint hidden states should only contain image data
|
||||
assert joint_hidden_states.shape == (batch_size, image_seq_len, 60)
|
||||
assert attention_mask.shape == (batch_size, image_seq_len)
|
||||
assert torch.all(attention_mask[0]) # All image positions should be valid
|
||||
241
tests/library/test_lumina_train_util.py
Normal file
241
tests/library/test_lumina_train_util.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import pytest
|
||||
import torch
|
||||
import math
|
||||
|
||||
from library.lumina_train_util import (
|
||||
batchify,
|
||||
time_shift,
|
||||
get_lin_function,
|
||||
get_schedule,
|
||||
compute_density_for_timestep_sampling,
|
||||
get_sigmas,
|
||||
compute_loss_weighting_for_sd3,
|
||||
get_noisy_model_input_and_timesteps,
|
||||
apply_model_prediction_type,
|
||||
retrieve_timesteps,
|
||||
)
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
|
||||
|
||||
def test_batchify():
|
||||
# Test case with no batch size specified
|
||||
prompts = [
|
||||
{"prompt": "test1"},
|
||||
{"prompt": "test2"},
|
||||
{"prompt": "test3"}
|
||||
]
|
||||
batchified = list(batchify(prompts))
|
||||
assert len(batchified) == 1
|
||||
assert len(batchified[0]) == 3
|
||||
|
||||
# Test case with batch size specified
|
||||
batchified_sized = list(batchify(prompts, batch_size=2))
|
||||
assert len(batchified_sized) == 2
|
||||
assert len(batchified_sized[0]) == 2
|
||||
assert len(batchified_sized[1]) == 1
|
||||
|
||||
# Test batching with prompts having same parameters
|
||||
prompts_with_params = [
|
||||
{"prompt": "test1", "width": 512, "height": 512},
|
||||
{"prompt": "test2", "width": 512, "height": 512},
|
||||
{"prompt": "test3", "width": 1024, "height": 1024}
|
||||
]
|
||||
batchified_params = list(batchify(prompts_with_params))
|
||||
assert len(batchified_params) == 2
|
||||
|
||||
# Test invalid batch size
|
||||
with pytest.raises(ValueError):
|
||||
list(batchify(prompts, batch_size=0))
|
||||
with pytest.raises(ValueError):
|
||||
list(batchify(prompts, batch_size=-1))
|
||||
|
||||
|
||||
def test_time_shift():
|
||||
# Test standard parameters
|
||||
t = torch.tensor([0.5])
|
||||
mu = 1.0
|
||||
sigma = 1.0
|
||||
result = time_shift(mu, sigma, t)
|
||||
assert 0 <= result <= 1
|
||||
|
||||
# Test with edge cases
|
||||
t_edges = torch.tensor([0.0, 1.0])
|
||||
result_edges = time_shift(1.0, 1.0, t_edges)
|
||||
|
||||
# Check that results are bounded within [0, 1]
|
||||
assert torch.all(result_edges >= 0)
|
||||
assert torch.all(result_edges <= 1)
|
||||
|
||||
|
||||
def test_get_lin_function():
|
||||
# Default parameters
|
||||
func = get_lin_function()
|
||||
assert func(256) == 0.5
|
||||
assert func(4096) == 1.15
|
||||
|
||||
# Custom parameters
|
||||
custom_func = get_lin_function(x1=100, x2=1000, y1=0.1, y2=0.9)
|
||||
assert custom_func(100) == 0.1
|
||||
assert custom_func(1000) == 0.9
|
||||
|
||||
|
||||
def test_get_schedule():
|
||||
# Basic schedule
|
||||
schedule = get_schedule(num_steps=10, image_seq_len=256)
|
||||
assert len(schedule) == 10
|
||||
assert all(0 <= x <= 1 for x in schedule)
|
||||
|
||||
# Test different sequence lengths
|
||||
short_schedule = get_schedule(num_steps=5, image_seq_len=128)
|
||||
long_schedule = get_schedule(num_steps=15, image_seq_len=1024)
|
||||
assert len(short_schedule) == 5
|
||||
assert len(long_schedule) == 15
|
||||
|
||||
# Test with shift disabled
|
||||
unshifted_schedule = get_schedule(num_steps=10, image_seq_len=256, shift=False)
|
||||
assert torch.allclose(
|
||||
torch.tensor(unshifted_schedule),
|
||||
torch.linspace(1, 1/10, 10)
|
||||
)
|
||||
|
||||
|
||||
def test_compute_density_for_timestep_sampling():
|
||||
# Test uniform sampling
|
||||
uniform_samples = compute_density_for_timestep_sampling("uniform", batch_size=100)
|
||||
assert len(uniform_samples) == 100
|
||||
assert torch.all((uniform_samples >= 0) & (uniform_samples <= 1))
|
||||
|
||||
# Test logit normal sampling
|
||||
logit_normal_samples = compute_density_for_timestep_sampling(
|
||||
"logit_normal", batch_size=100, logit_mean=0.0, logit_std=1.0
|
||||
)
|
||||
assert len(logit_normal_samples) == 100
|
||||
assert torch.all((logit_normal_samples >= 0) & (logit_normal_samples <= 1))
|
||||
|
||||
# Test mode sampling
|
||||
mode_samples = compute_density_for_timestep_sampling(
|
||||
"mode", batch_size=100, mode_scale=0.5
|
||||
)
|
||||
assert len(mode_samples) == 100
|
||||
assert torch.all((mode_samples >= 0) & (mode_samples <= 1))
|
||||
|
||||
|
||||
def test_get_sigmas():
|
||||
# Create a mock noise scheduler
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
|
||||
device = torch.device('cpu')
|
||||
|
||||
# Test with default parameters
|
||||
timesteps = torch.tensor([100, 500, 900])
|
||||
sigmas = get_sigmas(scheduler, timesteps, device)
|
||||
|
||||
# Check shape and basic properties
|
||||
assert sigmas.shape[0] == 3
|
||||
assert torch.all(sigmas >= 0)
|
||||
|
||||
# Test with different n_dim
|
||||
sigmas_4d = get_sigmas(scheduler, timesteps, device, n_dim=4)
|
||||
assert sigmas_4d.ndim == 4
|
||||
|
||||
# Test with different dtype
|
||||
sigmas_float16 = get_sigmas(scheduler, timesteps, device, dtype=torch.float16)
|
||||
assert sigmas_float16.dtype == torch.float16
|
||||
|
||||
|
||||
def test_compute_loss_weighting_for_sd3():
|
||||
# Prepare some mock sigmas
|
||||
sigmas = torch.tensor([0.1, 0.5, 1.0])
|
||||
|
||||
# Test sigma_sqrt weighting
|
||||
sqrt_weighting = compute_loss_weighting_for_sd3("sigma_sqrt", sigmas)
|
||||
assert torch.allclose(sqrt_weighting, 1 / (sigmas**2), rtol=1e-5)
|
||||
|
||||
# Test cosmap weighting
|
||||
cosmap_weighting = compute_loss_weighting_for_sd3("cosmap", sigmas)
|
||||
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
||||
expected_cosmap = 2 / (math.pi * bot)
|
||||
assert torch.allclose(cosmap_weighting, expected_cosmap, rtol=1e-5)
|
||||
|
||||
# Test default weighting
|
||||
default_weighting = compute_loss_weighting_for_sd3("unknown", sigmas)
|
||||
assert torch.all(default_weighting == 1)
|
||||
|
||||
|
||||
def test_apply_model_prediction_type():
|
||||
# Create mock args and tensors
|
||||
class MockArgs:
|
||||
model_prediction_type = "raw"
|
||||
weighting_scheme = "sigma_sqrt"
|
||||
|
||||
args = MockArgs()
|
||||
model_pred = torch.tensor([1.0, 2.0, 3.0])
|
||||
noisy_model_input = torch.tensor([0.5, 1.0, 1.5])
|
||||
sigmas = torch.tensor([0.1, 0.5, 1.0])
|
||||
|
||||
# Test raw prediction type
|
||||
raw_pred, raw_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
||||
assert torch.all(raw_pred == model_pred)
|
||||
assert raw_weighting is None
|
||||
|
||||
# Test additive prediction type
|
||||
args.model_prediction_type = "additive"
|
||||
additive_pred, _ = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
||||
assert torch.all(additive_pred == model_pred + noisy_model_input)
|
||||
|
||||
# Test sigma scaled prediction type
|
||||
args.model_prediction_type = "sigma_scaled"
|
||||
sigma_scaled_pred, sigma_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
||||
assert torch.all(sigma_scaled_pred == model_pred * (-sigmas) + noisy_model_input)
|
||||
assert sigma_weighting is not None
|
||||
|
||||
|
||||
def test_retrieve_timesteps():
|
||||
# Create a mock scheduler
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
|
||||
|
||||
# Test with num_inference_steps
|
||||
timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=50)
|
||||
assert len(timesteps) == 50
|
||||
assert n_steps == 50
|
||||
|
||||
# Test error handling with simultaneous timesteps and sigmas
|
||||
with pytest.raises(ValueError):
|
||||
retrieve_timesteps(scheduler, timesteps=[1, 2, 3], sigmas=[0.1, 0.2, 0.3])
|
||||
|
||||
|
||||
def test_get_noisy_model_input_and_timesteps():
|
||||
# Create a mock args and setup
|
||||
class MockArgs:
|
||||
timestep_sampling = "uniform"
|
||||
weighting_scheme = "sigma_sqrt"
|
||||
sigmoid_scale = 1.0
|
||||
discrete_flow_shift = 6.0
|
||||
|
||||
args = MockArgs()
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
|
||||
device = torch.device('cpu')
|
||||
|
||||
# Prepare mock latents and noise
|
||||
latents = torch.randn(4, 16, 64, 64)
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
# Test uniform sampling
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
|
||||
args, scheduler, latents, noise, device, torch.float32
|
||||
)
|
||||
|
||||
# Validate output shapes and types
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape[0] == latents.shape[0]
|
||||
assert noisy_input.dtype == torch.float32
|
||||
assert timesteps.dtype == torch.float32
|
||||
|
||||
# Test different sampling methods
|
||||
sampling_methods = ["sigmoid", "shift", "nextdit_shift"]
|
||||
for method in sampling_methods:
|
||||
args.timestep_sampling = method
|
||||
noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(
|
||||
args, scheduler, latents, noise, device, torch.float32
|
||||
)
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape[0] == latents.shape[0]
|
||||
112
tests/library/test_lumina_util.py
Normal file
112
tests/library/test_lumina_util.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import torch
|
||||
from torch.nn.modules import conv
|
||||
|
||||
from library import lumina_util
|
||||
|
||||
|
||||
def test_unpack_latents():
|
||||
# Create a test tensor
|
||||
# Shape: [batch, height*width, channels*patch_height*patch_width]
|
||||
x = torch.randn(2, 4, 16) # 2 batches, 4 tokens, 16 channels
|
||||
packed_latent_height = 2
|
||||
packed_latent_width = 2
|
||||
|
||||
# Unpack the latents
|
||||
unpacked = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width)
|
||||
|
||||
# Check output shape
|
||||
# Expected shape: [batch, channels, height*patch_height, width*patch_width]
|
||||
assert unpacked.shape == (2, 4, 4, 4)
|
||||
|
||||
|
||||
def test_pack_latents():
|
||||
# Create a test tensor
|
||||
# Shape: [batch, channels, height*patch_height, width*patch_width]
|
||||
x = torch.randn(2, 4, 4, 4)
|
||||
|
||||
# Pack the latents
|
||||
packed = lumina_util.pack_latents(x)
|
||||
|
||||
# Check output shape
|
||||
# Expected shape: [batch, height*width, channels*patch_height*patch_width]
|
||||
assert packed.shape == (2, 4, 16)
|
||||
|
||||
|
||||
def test_convert_diffusers_sd_to_alpha_vllm():
|
||||
num_double_blocks = 2
|
||||
# Predefined test cases based on the actual conversion map
|
||||
test_cases = [
|
||||
# Static key conversions with possible list mappings
|
||||
{
|
||||
"original_keys": ["time_caption_embed.caption_embedder.0.weight"],
|
||||
"original_pattern": ["time_caption_embed.caption_embedder.0.weight"],
|
||||
"expected_converted_keys": ["cap_embedder.0.weight"],
|
||||
},
|
||||
{
|
||||
"original_keys": ["patch_embedder.proj.weight"],
|
||||
"original_pattern": ["patch_embedder.proj.weight"],
|
||||
"expected_converted_keys": ["x_embedder.weight"],
|
||||
},
|
||||
{
|
||||
"original_keys": ["transformer_blocks.0.norm1.weight"],
|
||||
"original_pattern": ["transformer_blocks.().norm1.weight"],
|
||||
"expected_converted_keys": ["layers.0.attention_norm1.weight"],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
for test_case in test_cases:
|
||||
for original_key, original_pattern, expected_converted_key in zip(
|
||||
test_case["original_keys"], test_case["original_pattern"], test_case["expected_converted_keys"]
|
||||
):
|
||||
# Create test state dict
|
||||
test_sd = {original_key: torch.randn(10, 10)}
|
||||
|
||||
# Convert the state dict
|
||||
converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks)
|
||||
|
||||
# Verify conversion (handle both string and list keys)
|
||||
# Find the correct converted key
|
||||
match_found = False
|
||||
if expected_converted_key in converted_sd:
|
||||
# Verify tensor preservation
|
||||
assert torch.allclose(converted_sd[expected_converted_key], test_sd[original_key], atol=1e-6), (
|
||||
f"Tensor mismatch for {original_key}"
|
||||
)
|
||||
match_found = True
|
||||
break
|
||||
|
||||
assert match_found, f"Failed to convert {original_key}"
|
||||
|
||||
# Ensure original key is also present
|
||||
assert original_key in converted_sd
|
||||
|
||||
# Test with block-specific keys
|
||||
block_specific_cases = [
|
||||
{
|
||||
"original_pattern": "transformer_blocks.().norm1.weight",
|
||||
"converted_pattern": "layers.().attention_norm1.weight",
|
||||
}
|
||||
]
|
||||
|
||||
for case in block_specific_cases:
|
||||
for block_idx in range(2): # Test multiple block indices
|
||||
# Prepare block-specific keys
|
||||
block_original_key = case["original_pattern"].replace("()", str(block_idx))
|
||||
block_converted_key = case["converted_pattern"].replace("()", str(block_idx))
|
||||
print(block_original_key, block_converted_key)
|
||||
|
||||
# Create test state dict
|
||||
test_sd = {block_original_key: torch.randn(10, 10)}
|
||||
|
||||
# Convert the state dict
|
||||
converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks)
|
||||
|
||||
# Verify conversion
|
||||
# assert block_converted_key in converted_sd, f"Failed to convert block key {block_original_key}"
|
||||
assert torch.allclose(converted_sd[block_converted_key], test_sd[block_original_key], atol=1e-6), (
|
||||
f"Tensor mismatch for block key {block_original_key}"
|
||||
)
|
||||
|
||||
# Ensure original key is also present
|
||||
assert block_original_key in converted_sd
|
||||
241
tests/library/test_strategy_lumina.py
Normal file
241
tests/library/test_strategy_lumina.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import os
|
||||
import tempfile
|
||||
import torch
|
||||
import numpy as np
|
||||
from unittest.mock import patch
|
||||
from transformers import Gemma2Model
|
||||
|
||||
from library.strategy_lumina import (
|
||||
LuminaTokenizeStrategy,
|
||||
LuminaTextEncodingStrategy,
|
||||
LuminaTextEncoderOutputsCachingStrategy,
|
||||
LuminaLatentsCachingStrategy,
|
||||
)
|
||||
|
||||
|
||||
class SimpleMockGemma2Model:
|
||||
"""Lightweight mock that avoids initializing the actual Gemma2Model"""
|
||||
|
||||
def __init__(self, hidden_size=2304):
|
||||
self.device = torch.device("cpu")
|
||||
self._hidden_size = hidden_size
|
||||
self._orig_mod = self # For dynamic compilation compatibility
|
||||
|
||||
def __call__(self, input_ids, attention_mask, output_hidden_states=False, return_dict=False):
|
||||
# Create a mock output object with hidden states
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden_size = self._hidden_size
|
||||
|
||||
class MockOutput:
|
||||
def __init__(self, hidden_states):
|
||||
self.hidden_states = hidden_states
|
||||
|
||||
mock_hidden_states = [
|
||||
torch.randn(batch_size, seq_len, hidden_size, device=input_ids.device)
|
||||
for _ in range(3) # Mimic multiple layers of hidden states
|
||||
]
|
||||
|
||||
return MockOutput(mock_hidden_states)
|
||||
|
||||
|
||||
def test_lumina_tokenize_strategy():
|
||||
# Test default initialization
|
||||
try:
|
||||
tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
|
||||
except OSError as e:
|
||||
# If the tokenizer is not found (due to gated repo), we can skip the test
|
||||
print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
|
||||
return
|
||||
assert tokenize_strategy.max_length == 256
|
||||
assert tokenize_strategy.tokenizer.padding_side == "right"
|
||||
|
||||
# Test tokenization of a single string
|
||||
text = "Hello"
|
||||
tokens, attention_mask = tokenize_strategy.tokenize(text)
|
||||
|
||||
assert tokens.ndim == 2
|
||||
assert attention_mask.ndim == 2
|
||||
assert tokens.shape == attention_mask.shape
|
||||
assert tokens.shape[1] == 256 # max_length
|
||||
|
||||
# Test tokenize_with_weights
|
||||
tokens, attention_mask, weights = tokenize_strategy.tokenize_with_weights(text)
|
||||
assert len(weights) == 1
|
||||
assert torch.all(weights[0] == 1)
|
||||
|
||||
|
||||
def test_lumina_text_encoding_strategy():
|
||||
# Create strategies
|
||||
try:
|
||||
tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
|
||||
except OSError as e:
|
||||
# If the tokenizer is not found (due to gated repo), we can skip the test
|
||||
print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
|
||||
return
|
||||
encoding_strategy = LuminaTextEncodingStrategy()
|
||||
|
||||
# Create a mock model
|
||||
mock_model = SimpleMockGemma2Model()
|
||||
|
||||
# Patch the isinstance check to accept our simple mock
|
||||
original_isinstance = isinstance
|
||||
with patch("library.strategy_lumina.isinstance") as mock_isinstance:
|
||||
|
||||
def custom_isinstance(obj, class_or_tuple):
|
||||
if obj is mock_model and class_or_tuple is Gemma2Model:
|
||||
return True
|
||||
if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model:
|
||||
return True
|
||||
return original_isinstance(obj, class_or_tuple)
|
||||
|
||||
mock_isinstance.side_effect = custom_isinstance
|
||||
|
||||
# Prepare sample text
|
||||
text = "Test encoding strategy"
|
||||
tokens, attention_mask = tokenize_strategy.tokenize(text)
|
||||
|
||||
# Perform encoding
|
||||
hidden_states, input_ids, attention_masks = encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, [mock_model], (tokens, attention_mask)
|
||||
)
|
||||
|
||||
# Validate outputs
|
||||
assert original_isinstance(hidden_states, torch.Tensor)
|
||||
assert original_isinstance(input_ids, torch.Tensor)
|
||||
assert original_isinstance(attention_masks, torch.Tensor)
|
||||
|
||||
# Check the shape of the second-to-last hidden state
|
||||
assert hidden_states.ndim == 3
|
||||
|
||||
# Test weighted encoding (which falls back to standard encoding for Lumina)
|
||||
weights = [torch.ones_like(tokens)]
|
||||
hidden_states_w, input_ids_w, attention_masks_w = encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy, [mock_model], (tokens, attention_mask), weights
|
||||
)
|
||||
|
||||
# For the mock, we can't guarantee identical outputs since each call returns random tensors
|
||||
# Instead, check that the outputs have the same shape and are tensors
|
||||
assert hidden_states_w.shape == hidden_states.shape
|
||||
assert original_isinstance(hidden_states_w, torch.Tensor)
|
||||
assert torch.allclose(input_ids, input_ids_w) # Input IDs should be the same
|
||||
assert torch.allclose(attention_masks, attention_masks_w) # Attention masks should be the same
|
||||
|
||||
|
||||
def test_lumina_text_encoder_outputs_caching_strategy():
|
||||
# Create a temporary directory for caching
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Create a cache file path
|
||||
cache_file = os.path.join(tmpdir, "test_outputs.npz")
|
||||
|
||||
# Create the caching strategy
|
||||
caching_strategy = LuminaTextEncoderOutputsCachingStrategy(
|
||||
cache_to_disk=True,
|
||||
batch_size=1,
|
||||
skip_disk_cache_validity_check=False,
|
||||
)
|
||||
|
||||
# Create a mock class for ImageInfo
|
||||
class MockImageInfo:
|
||||
def __init__(self, caption, cache_path):
|
||||
self.caption = caption
|
||||
self.text_encoder_outputs_npz = cache_path
|
||||
|
||||
# Create a sample input info
|
||||
image_info = MockImageInfo("Test caption", cache_file)
|
||||
|
||||
# Simulate a batch
|
||||
batch = [image_info]
|
||||
|
||||
# Create mock strategies and model
|
||||
try:
|
||||
tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
|
||||
except OSError as e:
|
||||
# If the tokenizer is not found (due to gated repo), we can skip the test
|
||||
print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
|
||||
return
|
||||
encoding_strategy = LuminaTextEncodingStrategy()
|
||||
mock_model = SimpleMockGemma2Model()
|
||||
|
||||
# Patch the isinstance check to accept our simple mock
|
||||
original_isinstance = isinstance
|
||||
with patch("library.strategy_lumina.isinstance") as mock_isinstance:
|
||||
|
||||
def custom_isinstance(obj, class_or_tuple):
|
||||
if obj is mock_model and class_or_tuple is Gemma2Model:
|
||||
return True
|
||||
if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model:
|
||||
return True
|
||||
return original_isinstance(obj, class_or_tuple)
|
||||
|
||||
mock_isinstance.side_effect = custom_isinstance
|
||||
|
||||
# Call cache_batch_outputs
|
||||
caching_strategy.cache_batch_outputs(tokenize_strategy, [mock_model], encoding_strategy, batch)
|
||||
|
||||
# Verify the npz file was created
|
||||
assert os.path.exists(cache_file), f"Cache file not created at {cache_file}"
|
||||
|
||||
# Verify the is_disk_cached_outputs_expected method
|
||||
assert caching_strategy.is_disk_cached_outputs_expected(cache_file)
|
||||
|
||||
# Test loading from npz
|
||||
loaded_data = caching_strategy.load_outputs_npz(cache_file)
|
||||
assert len(loaded_data) == 3 # hidden_state, input_ids, attention_mask
|
||||
|
||||
|
||||
def test_lumina_latents_caching_strategy():
|
||||
# Create a temporary directory for caching
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Prepare a mock absolute path
|
||||
abs_path = os.path.join(tmpdir, "test_image.png")
|
||||
|
||||
# Use smaller image size for faster testing
|
||||
image_size = (64, 64)
|
||||
|
||||
# Create a smaller dummy image for testing
|
||||
test_image = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
|
||||
|
||||
# Create the caching strategy
|
||||
caching_strategy = LuminaLatentsCachingStrategy(cache_to_disk=True, batch_size=1, skip_disk_cache_validity_check=False)
|
||||
|
||||
# Create a simple mock VAE
|
||||
class MockVAE:
|
||||
def __init__(self):
|
||||
self.device = torch.device("cpu")
|
||||
self.dtype = torch.float32
|
||||
|
||||
def encode(self, x):
|
||||
# Return smaller encoded tensor for faster processing
|
||||
encoded = torch.randn(1, 4, 8, 8, device=x.device)
|
||||
return type("EncodedLatents", (), {"to": lambda *args, **kwargs: encoded})
|
||||
|
||||
# Prepare a mock batch
|
||||
class MockImageInfo:
|
||||
def __init__(self, path, image):
|
||||
self.absolute_path = path
|
||||
self.image = image
|
||||
self.image_path = path
|
||||
self.bucket_reso = image_size
|
||||
self.resized_size = image_size
|
||||
self.resize_interpolation = "lanczos"
|
||||
# Specify full path to the latents npz file
|
||||
self.latents_npz = os.path.join(tmpdir, f"{os.path.splitext(os.path.basename(path))[0]}_0064x0064_lumina.npz")
|
||||
|
||||
batch = [MockImageInfo(abs_path, test_image)]
|
||||
|
||||
# Call cache_batch_latents
|
||||
mock_vae = MockVAE()
|
||||
caching_strategy.cache_batch_latents(mock_vae, batch, flip_aug=False, alpha_mask=False, random_crop=False)
|
||||
|
||||
# Generate the expected npz path
|
||||
npz_path = caching_strategy.get_latents_npz_path(abs_path, image_size)
|
||||
|
||||
# Verify the file was created
|
||||
assert os.path.exists(npz_path), f"NPZ file not created at {npz_path}"
|
||||
|
||||
# Verify is_disk_cached_latents_expected
|
||||
assert caching_strategy.is_disk_cached_latents_expected(image_size, npz_path, False, False)
|
||||
|
||||
# Test loading from disk
|
||||
loaded_data = caching_strategy.load_latents_from_disk(npz_path, image_size)
|
||||
assert len(loaded_data) == 5 # Check for 5 expected elements
|
||||
408
tests/test_custom_offloading_utils.py
Normal file
408
tests/test_custom_offloading_utils.py
Normal file
@@ -0,0 +1,408 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from library.custom_offloading_utils import (
|
||||
synchronize_device,
|
||||
swap_weight_devices_cuda,
|
||||
swap_weight_devices_no_cuda,
|
||||
weighs_to_device,
|
||||
Offloader,
|
||||
ModelOffloader
|
||||
)
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, block_idx: int):
|
||||
super().__init__()
|
||||
self.block_idx = block_idx
|
||||
self.linear1 = nn.Linear(10, 5)
|
||||
self.linear2 = nn.Linear(5, 10)
|
||||
self.seq = nn.Sequential(nn.SiLU(), nn.Linear(10, 10))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = torch.relu(x)
|
||||
x = self.linear2(x)
|
||||
x = self.seq(x)
|
||||
return x
|
||||
|
||||
|
||||
class SimpleModel(nn.Module):
|
||||
def __init__(self, num_blocks=16):
|
||||
super().__init__()
|
||||
self.blocks = nn.ModuleList([
|
||||
TransformerBlock(i)
|
||||
for i in range(num_blocks)])
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
|
||||
# Device Synchronization Tests
|
||||
@patch('torch.cuda.synchronize')
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_cuda_synchronize(mock_cuda_sync):
|
||||
device = torch.device('cuda')
|
||||
synchronize_device(device)
|
||||
mock_cuda_sync.assert_called_once()
|
||||
|
||||
@patch('torch.xpu.synchronize')
|
||||
@pytest.mark.skipif(not torch.xpu.is_available(), reason="XPU not available")
|
||||
def test_xpu_synchronize(mock_xpu_sync):
|
||||
device = torch.device('xpu')
|
||||
synchronize_device(device)
|
||||
mock_xpu_sync.assert_called_once()
|
||||
|
||||
@patch('torch.mps.synchronize')
|
||||
@pytest.mark.skipif(not torch.xpu.is_available(), reason="MPS not available")
|
||||
def test_mps_synchronize(mock_mps_sync):
|
||||
device = torch.device('mps')
|
||||
synchronize_device(device)
|
||||
mock_mps_sync.assert_called_once()
|
||||
|
||||
|
||||
# Weights to Device Tests
|
||||
def test_weights_to_device():
|
||||
# Create a simple model with weights
|
||||
model = nn.Sequential(
|
||||
nn.Linear(10, 5),
|
||||
nn.ReLU(),
|
||||
nn.Linear(5, 2)
|
||||
)
|
||||
|
||||
# Start with CPU tensors
|
||||
device = torch.device('cpu')
|
||||
for module in model.modules():
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
assert module.weight.device == device
|
||||
|
||||
# Move to mock CUDA device
|
||||
mock_device = torch.device('cuda')
|
||||
with patch('torch.Tensor.to', return_value=torch.zeros(1).to(device)):
|
||||
weighs_to_device(model, mock_device)
|
||||
|
||||
# Since we mocked the to() function, we can only verify modules were processed
|
||||
# but can't check actual device movement
|
||||
|
||||
|
||||
# Swap Weight Devices Tests
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_swap_weight_devices_cuda():
|
||||
device = torch.device('cuda')
|
||||
layer_to_cpu = SimpleModel()
|
||||
layer_to_cuda = SimpleModel()
|
||||
|
||||
# Move layer to CUDA to move to CPU
|
||||
layer_to_cpu.to(device)
|
||||
|
||||
with patch('torch.Tensor.to', return_value=torch.zeros(1)):
|
||||
with patch('torch.Tensor.copy_'):
|
||||
swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda)
|
||||
|
||||
assert layer_to_cpu.device.type == 'cpu'
|
||||
assert layer_to_cuda.device.type == 'cuda'
|
||||
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.synchronize_device')
|
||||
def test_swap_weight_devices_no_cuda(mock_sync_device):
|
||||
device = torch.device('cpu')
|
||||
layer_to_cpu = SimpleModel()
|
||||
layer_to_cuda = SimpleModel()
|
||||
|
||||
with patch('torch.Tensor.to', return_value=torch.zeros(1)):
|
||||
with patch('torch.Tensor.copy_'):
|
||||
swap_weight_devices_no_cuda(device, layer_to_cpu, layer_to_cuda)
|
||||
|
||||
# Verify synchronize_device was called twice
|
||||
assert mock_sync_device.call_count == 2
|
||||
|
||||
|
||||
# Offloader Tests
|
||||
@pytest.fixture
|
||||
def offloader():
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
return Offloader(
|
||||
num_blocks=4,
|
||||
blocks_to_swap=2,
|
||||
device=device,
|
||||
debug=False
|
||||
)
|
||||
|
||||
|
||||
def test_offloader_init(offloader):
|
||||
assert offloader.num_blocks == 4
|
||||
assert offloader.blocks_to_swap == 2
|
||||
assert hasattr(offloader, 'thread_pool')
|
||||
assert offloader.futures == {}
|
||||
assert offloader.cuda_available == (offloader.device.type == 'cuda')
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.swap_weight_devices_cuda')
|
||||
@patch('library.custom_offloading_utils.swap_weight_devices_no_cuda')
|
||||
def test_swap_weight_devices(mock_no_cuda, mock_cuda, offloader: Offloader):
|
||||
block_to_cpu = SimpleModel()
|
||||
block_to_cuda = SimpleModel()
|
||||
|
||||
# Force test for CUDA device
|
||||
offloader.cuda_available = True
|
||||
offloader.swap_weight_devices(block_to_cpu, block_to_cuda)
|
||||
mock_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda)
|
||||
mock_no_cuda.assert_not_called()
|
||||
|
||||
# Reset mocks
|
||||
mock_cuda.reset_mock()
|
||||
mock_no_cuda.reset_mock()
|
||||
|
||||
# Force test for non-CUDA device
|
||||
offloader.cuda_available = False
|
||||
offloader.swap_weight_devices(block_to_cpu, block_to_cuda)
|
||||
mock_no_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda)
|
||||
mock_cuda.assert_not_called()
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.Offloader.swap_weight_devices')
|
||||
def test_submit_move_blocks(mock_swap, offloader):
|
||||
blocks = [SimpleModel() for _ in range(4)]
|
||||
block_idx_to_cpu = 0
|
||||
block_idx_to_cuda = 2
|
||||
|
||||
# Mock the thread pool to execute synchronously
|
||||
future = MagicMock()
|
||||
future.result.return_value = (block_idx_to_cpu, block_idx_to_cuda)
|
||||
offloader.thread_pool.submit = MagicMock(return_value=future)
|
||||
|
||||
offloader._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
|
||||
|
||||
# Check that the future is stored with the correct key
|
||||
assert block_idx_to_cuda in offloader.futures
|
||||
|
||||
|
||||
def test_wait_blocks_move(offloader):
|
||||
block_idx = 2
|
||||
|
||||
# Test with no future for the block
|
||||
offloader._wait_blocks_move(block_idx) # Should not raise
|
||||
|
||||
# Create a fake future and test waiting
|
||||
future = MagicMock()
|
||||
future.result.return_value = (0, block_idx)
|
||||
offloader.futures[block_idx] = future
|
||||
|
||||
offloader._wait_blocks_move(block_idx)
|
||||
|
||||
# Check that the future was removed
|
||||
assert block_idx not in offloader.futures
|
||||
future.result.assert_called_once()
|
||||
|
||||
|
||||
# ModelOffloader Tests
|
||||
@pytest.fixture
|
||||
def model_offloader():
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
blocks_to_swap = 2
|
||||
blocks = SimpleModel(4).blocks
|
||||
return ModelOffloader(
|
||||
blocks=blocks,
|
||||
blocks_to_swap=blocks_to_swap,
|
||||
device=device,
|
||||
debug=False
|
||||
)
|
||||
|
||||
|
||||
def test_model_offloader_init(model_offloader):
|
||||
assert model_offloader.num_blocks == 4
|
||||
assert model_offloader.blocks_to_swap == 2
|
||||
assert hasattr(model_offloader, 'thread_pool')
|
||||
assert model_offloader.futures == {}
|
||||
assert len(model_offloader.remove_handles) > 0 # Should have registered hooks
|
||||
|
||||
|
||||
def test_create_backward_hook():
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
blocks_to_swap = 2
|
||||
blocks = SimpleModel(4).blocks
|
||||
model_offloader = ModelOffloader(
|
||||
blocks=blocks,
|
||||
blocks_to_swap=blocks_to_swap,
|
||||
device=device,
|
||||
debug=False
|
||||
)
|
||||
|
||||
# Test hook creation for swapping case (block 0)
|
||||
hook_swap = model_offloader.create_backward_hook(blocks, 0)
|
||||
assert hook_swap is None
|
||||
|
||||
# Test hook creation for waiting case (block 1)
|
||||
hook_wait = model_offloader.create_backward_hook(blocks, 1)
|
||||
assert hook_wait is not None
|
||||
|
||||
# Test hook creation for no action case (block 3)
|
||||
hook_none = model_offloader.create_backward_hook(blocks, 3)
|
||||
assert hook_none is None
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks')
|
||||
@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move')
|
||||
def test_backward_hook_execution(mock_wait, mock_submit):
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
blocks_to_swap = 2
|
||||
model = SimpleModel(4)
|
||||
blocks = model.blocks
|
||||
model_offloader = ModelOffloader(
|
||||
blocks=blocks,
|
||||
blocks_to_swap=blocks_to_swap,
|
||||
device=device,
|
||||
debug=False
|
||||
)
|
||||
|
||||
# Test swapping hook (block 1)
|
||||
hook_swap = model_offloader.create_backward_hook(blocks, 1)
|
||||
assert hook_swap is not None
|
||||
hook_swap(model, torch.zeros(1), torch.zeros(1))
|
||||
mock_submit.assert_called_once()
|
||||
|
||||
mock_submit.reset_mock()
|
||||
|
||||
# Test waiting hook (block 2)
|
||||
hook_wait = model_offloader.create_backward_hook(blocks, 2)
|
||||
assert hook_wait is not None
|
||||
hook_wait(model, torch.zeros(1), torch.zeros(1))
|
||||
assert mock_wait.call_count == 2
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.weighs_to_device')
|
||||
@patch('library.custom_offloading_utils.synchronize_device')
|
||||
@patch('library.custom_offloading_utils.clean_memory_on_device')
|
||||
def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weights_to_device, model_offloader):
|
||||
model = SimpleModel(4)
|
||||
blocks = model.blocks
|
||||
|
||||
with patch.object(nn.Module, 'to'):
|
||||
model_offloader.prepare_block_devices_before_forward(blocks)
|
||||
|
||||
# Check that weighs_to_device was called for each block
|
||||
assert mock_weights_to_device.call_count == 4
|
||||
|
||||
# Check that synchronize_device and clean_memory_on_device were called
|
||||
mock_sync.assert_called_once_with(model_offloader.device)
|
||||
mock_clean.assert_called_once_with(model_offloader.device)
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move')
|
||||
def test_wait_for_block(mock_wait, model_offloader):
|
||||
# Test with blocks_to_swap=0
|
||||
model_offloader.blocks_to_swap = 0
|
||||
model_offloader.wait_for_block(1)
|
||||
mock_wait.assert_not_called()
|
||||
|
||||
# Test with blocks_to_swap=2
|
||||
model_offloader.blocks_to_swap = 2
|
||||
block_idx = 1
|
||||
model_offloader.wait_for_block(block_idx)
|
||||
mock_wait.assert_called_once_with(block_idx)
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks')
|
||||
def test_submit_move_blocks(mock_submit, model_offloader):
|
||||
model = SimpleModel()
|
||||
blocks = model.blocks
|
||||
|
||||
# Test with blocks_to_swap=0
|
||||
model_offloader.blocks_to_swap = 0
|
||||
model_offloader.submit_move_blocks(blocks, 1)
|
||||
mock_submit.assert_not_called()
|
||||
|
||||
mock_submit.reset_mock()
|
||||
model_offloader.blocks_to_swap = 2
|
||||
|
||||
# Test within swap range
|
||||
block_idx = 1
|
||||
model_offloader.submit_move_blocks(blocks, block_idx)
|
||||
mock_submit.assert_called_once()
|
||||
|
||||
mock_submit.reset_mock()
|
||||
|
||||
# Test outside swap range
|
||||
block_idx = 3
|
||||
model_offloader.submit_move_blocks(blocks, block_idx)
|
||||
mock_submit.assert_not_called()
|
||||
|
||||
|
||||
# Integration test for offloading in a realistic scenario
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_offloading_integration():
|
||||
device = torch.device('cuda')
|
||||
# Create a mini model with 4 blocks
|
||||
model = SimpleModel(5)
|
||||
model.to(device)
|
||||
blocks = model.blocks
|
||||
|
||||
# Initialize model offloader
|
||||
offloader = ModelOffloader(
|
||||
blocks=blocks,
|
||||
blocks_to_swap=2,
|
||||
device=device,
|
||||
debug=True
|
||||
)
|
||||
|
||||
# Prepare blocks for forward pass
|
||||
offloader.prepare_block_devices_before_forward(blocks)
|
||||
|
||||
# Simulate forward pass with offloading
|
||||
input_tensor = torch.randn(1, 10, device=device)
|
||||
x = input_tensor
|
||||
|
||||
for i, block in enumerate(blocks):
|
||||
# Wait for the current block to be ready
|
||||
offloader.wait_for_block(i)
|
||||
|
||||
# Process through the block
|
||||
x = block(x)
|
||||
|
||||
# Schedule moving weights for future blocks
|
||||
offloader.submit_move_blocks(blocks, i)
|
||||
|
||||
# Verify we get a valid output
|
||||
assert x.shape == (1, 10)
|
||||
assert not torch.isnan(x).any()
|
||||
|
||||
|
||||
# Error handling tests
|
||||
def test_offloader_assertion_error():
|
||||
with pytest.raises(AssertionError):
|
||||
device = torch.device('cpu')
|
||||
layer_to_cpu = SimpleModel()
|
||||
layer_to_cuda = nn.Linear(10, 5) # Different class
|
||||
swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run all tests when file is executed directly
|
||||
import sys
|
||||
|
||||
# Configure pytest command line arguments
|
||||
pytest_args = [
|
||||
"-v", # Verbose output
|
||||
"--color=yes", # Colored output
|
||||
__file__, # Run tests in this file
|
||||
]
|
||||
|
||||
# Add optional arguments from command line
|
||||
if len(sys.argv) > 1:
|
||||
pytest_args.extend(sys.argv[1:])
|
||||
|
||||
# Print info about test execution
|
||||
print(f"Running tests with PyTorch {torch.__version__}")
|
||||
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||
if torch.cuda.is_available():
|
||||
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
|
||||
|
||||
# Run the tests
|
||||
sys.exit(pytest.main(pytest_args))
|
||||
6
tests/test_fine_tune.py
Normal file
6
tests/test_fine_tune.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import fine_tune
|
||||
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the train_network imports without syntax errors
|
||||
assert True
|
||||
6
tests/test_flux_train.py
Normal file
6
tests/test_flux_train.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import flux_train
|
||||
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the train_network imports without syntax errors
|
||||
assert True
|
||||
5
tests/test_flux_train_network.py
Normal file
5
tests/test_flux_train_network.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import flux_train_network
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the flux_train_network imports without syntax errors
|
||||
assert True
|
||||
177
tests/test_lumina_train_network.py
Normal file
177
tests/test_lumina_train_network.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import pytest
|
||||
import torch
|
||||
from unittest.mock import MagicMock, patch
|
||||
import argparse
|
||||
|
||||
from library import lumina_models, lumina_util
|
||||
from lumina_train_network import LuminaNetworkTrainer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lumina_trainer():
|
||||
return LuminaNetworkTrainer()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_args():
|
||||
args = MagicMock()
|
||||
args.pretrained_model_name_or_path = "test_path"
|
||||
args.disable_mmap_load_safetensors = False
|
||||
args.use_flash_attn = False
|
||||
args.use_sage_attn = False
|
||||
args.fp8_base = False
|
||||
args.blocks_to_swap = None
|
||||
args.gemma2 = "test_gemma2_path"
|
||||
args.ae = "test_ae_path"
|
||||
args.cache_text_encoder_outputs = True
|
||||
args.cache_text_encoder_outputs_to_disk = False
|
||||
args.network_train_unet_only = False
|
||||
return args
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_accelerator():
|
||||
accelerator = MagicMock()
|
||||
accelerator.device = torch.device("cpu")
|
||||
accelerator.prepare.side_effect = lambda x, **kwargs: x
|
||||
accelerator.unwrap_model.side_effect = lambda x: x
|
||||
return accelerator
|
||||
|
||||
|
||||
def test_assert_extra_args(lumina_trainer, mock_args):
|
||||
train_dataset_group = MagicMock()
|
||||
train_dataset_group.verify_bucket_reso_steps = MagicMock()
|
||||
val_dataset_group = MagicMock()
|
||||
val_dataset_group.verify_bucket_reso_steps = MagicMock()
|
||||
|
||||
# Test with default settings
|
||||
lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group)
|
||||
|
||||
# Verify verify_bucket_reso_steps was called for both groups
|
||||
assert train_dataset_group.verify_bucket_reso_steps.call_count > 0
|
||||
assert val_dataset_group.verify_bucket_reso_steps.call_count > 0
|
||||
|
||||
# Check text encoder output caching
|
||||
assert lumina_trainer.train_gemma2 is (not mock_args.network_train_unet_only)
|
||||
assert mock_args.cache_text_encoder_outputs is True
|
||||
|
||||
|
||||
def test_load_target_model(lumina_trainer, mock_args, mock_accelerator):
|
||||
# Patch lumina_util methods
|
||||
with (
|
||||
patch("library.lumina_util.load_lumina_model") as mock_load_lumina_model,
|
||||
patch("library.lumina_util.load_gemma2") as mock_load_gemma2,
|
||||
patch("library.lumina_util.load_ae") as mock_load_ae,
|
||||
):
|
||||
# Create mock models
|
||||
mock_model = MagicMock(spec=lumina_models.NextDiT)
|
||||
mock_model.dtype = torch.float32
|
||||
mock_gemma2 = MagicMock()
|
||||
mock_ae = MagicMock()
|
||||
|
||||
mock_load_lumina_model.return_value = mock_model
|
||||
mock_load_gemma2.return_value = mock_gemma2
|
||||
mock_load_ae.return_value = mock_ae
|
||||
|
||||
# Test load_target_model
|
||||
version, gemma2_list, ae, model = lumina_trainer.load_target_model(mock_args, torch.float32, mock_accelerator)
|
||||
|
||||
# Verify calls and return values
|
||||
assert version == lumina_util.MODEL_VERSION_LUMINA_V2
|
||||
assert gemma2_list == [mock_gemma2]
|
||||
assert ae == mock_ae
|
||||
assert model == mock_model
|
||||
|
||||
# Verify load calls
|
||||
mock_load_lumina_model.assert_called_once()
|
||||
mock_load_gemma2.assert_called_once()
|
||||
mock_load_ae.assert_called_once()
|
||||
|
||||
|
||||
def test_get_strategies(lumina_trainer, mock_args):
|
||||
# Test tokenize strategy
|
||||
try:
|
||||
tokenize_strategy = lumina_trainer.get_tokenize_strategy(mock_args)
|
||||
assert tokenize_strategy.__class__.__name__ == "LuminaTokenizeStrategy"
|
||||
except OSError as e:
|
||||
# If the tokenizer is not found (due to gated repo), we can skip the test
|
||||
print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
|
||||
|
||||
# Test latents caching strategy
|
||||
latents_strategy = lumina_trainer.get_latents_caching_strategy(mock_args)
|
||||
assert latents_strategy.__class__.__name__ == "LuminaLatentsCachingStrategy"
|
||||
|
||||
# Test text encoding strategy
|
||||
text_encoding_strategy = lumina_trainer.get_text_encoding_strategy(mock_args)
|
||||
assert text_encoding_strategy.__class__.__name__ == "LuminaTextEncodingStrategy"
|
||||
|
||||
|
||||
def test_text_encoder_output_caching_strategy(lumina_trainer, mock_args):
|
||||
# Call assert_extra_args to set train_gemma2
|
||||
train_dataset_group = MagicMock()
|
||||
train_dataset_group.verify_bucket_reso_steps = MagicMock()
|
||||
val_dataset_group = MagicMock()
|
||||
val_dataset_group.verify_bucket_reso_steps = MagicMock()
|
||||
lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group)
|
||||
|
||||
# With text encoder caching enabled
|
||||
mock_args.skip_cache_check = False
|
||||
mock_args.text_encoder_batch_size = 16
|
||||
strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args)
|
||||
|
||||
assert strategy.__class__.__name__ == "LuminaTextEncoderOutputsCachingStrategy"
|
||||
assert strategy.cache_to_disk is False # based on mock_args
|
||||
|
||||
# With text encoder caching disabled
|
||||
mock_args.cache_text_encoder_outputs = False
|
||||
strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args)
|
||||
assert strategy is None
|
||||
|
||||
|
||||
def test_noise_scheduler(lumina_trainer, mock_args):
|
||||
device = torch.device("cpu")
|
||||
noise_scheduler = lumina_trainer.get_noise_scheduler(mock_args, device)
|
||||
|
||||
assert noise_scheduler.__class__.__name__ == "FlowMatchEulerDiscreteScheduler"
|
||||
assert noise_scheduler.num_train_timesteps == 1000
|
||||
assert hasattr(lumina_trainer, "noise_scheduler_copy")
|
||||
|
||||
|
||||
def test_sai_model_spec(lumina_trainer, mock_args):
|
||||
with patch("library.train_util.get_sai_model_spec") as mock_get_spec:
|
||||
mock_get_spec.return_value = "test_spec"
|
||||
spec = lumina_trainer.get_sai_model_spec(mock_args)
|
||||
assert spec == "test_spec"
|
||||
mock_get_spec.assert_called_once_with(None, mock_args, False, True, False, lumina="lumina2")
|
||||
|
||||
|
||||
def test_update_metadata(lumina_trainer, mock_args):
|
||||
metadata = {}
|
||||
lumina_trainer.update_metadata(metadata, mock_args)
|
||||
|
||||
assert "ss_weighting_scheme" in metadata
|
||||
assert "ss_logit_mean" in metadata
|
||||
assert "ss_logit_std" in metadata
|
||||
assert "ss_mode_scale" in metadata
|
||||
assert "ss_timestep_sampling" in metadata
|
||||
assert "ss_sigmoid_scale" in metadata
|
||||
assert "ss_model_prediction_type" in metadata
|
||||
assert "ss_discrete_flow_shift" in metadata
|
||||
|
||||
|
||||
def test_is_text_encoder_not_needed_for_training(lumina_trainer, mock_args):
|
||||
# Test with text encoder output caching, but not training text encoder
|
||||
mock_args.cache_text_encoder_outputs = True
|
||||
with patch.object(lumina_trainer, "is_train_text_encoder", return_value=False):
|
||||
result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
|
||||
assert result is True
|
||||
|
||||
# Test with text encoder output caching and training text encoder
|
||||
with patch.object(lumina_trainer, "is_train_text_encoder", return_value=True):
|
||||
result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
|
||||
assert result is False
|
||||
|
||||
# Test with no text encoder output caching
|
||||
mock_args.cache_text_encoder_outputs = False
|
||||
result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
|
||||
assert result is False
|
||||
6
tests/test_sd3_train.py
Normal file
6
tests/test_sd3_train.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import sd3_train
|
||||
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the train_network imports without syntax errors
|
||||
assert True
|
||||
5
tests/test_sd3_train_network.py
Normal file
5
tests/test_sd3_train_network.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import sd3_train_network
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the flux_train_network imports without syntax errors
|
||||
assert True
|
||||
6
tests/test_sdxl_train.py
Normal file
6
tests/test_sdxl_train.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import sdxl_train
|
||||
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the train_network imports without syntax errors
|
||||
assert True
|
||||
6
tests/test_sdxl_train_network.py
Normal file
6
tests/test_sdxl_train_network.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import sdxl_train_network
|
||||
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the train_network imports without syntax errors
|
||||
assert True
|
||||
6
tests/test_train.py
Normal file
6
tests/test_train.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import train_db
|
||||
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the train_network imports without syntax errors
|
||||
assert True
|
||||
5
tests/test_train_network.py
Normal file
5
tests/test_train_network.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import train_network
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the train_network imports without syntax errors
|
||||
assert True
|
||||
5
tests/test_train_textual_inversion.py
Normal file
5
tests/test_train_textual_inversion.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import train_textual_inversion
|
||||
|
||||
def test_syntax():
|
||||
# Very simply testing that the train_network imports without syntax errors
|
||||
assert True
|
||||
@@ -175,7 +175,7 @@ class NetworkTrainer:
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(64)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
def load_target_model(self, args, weight_dtype, accelerator) -> tuple:
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
@@ -414,12 +414,13 @@ class NetworkTrainer:
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
||||
|
||||
|
||||
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
|
||||
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
|
||||
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
|
||||
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch['captions'])
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
@@ -1443,11 +1444,13 @@ class NetworkTrainer:
|
||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||
else:
|
||||
if hasattr(network, "weight_norms"):
|
||||
mean_norm = network.weight_norms().mean().item()
|
||||
mean_grad_norm = network.grad_norms().mean().item()
|
||||
mean_combined_norm = network.combined_weight_norms().mean().item()
|
||||
weight_norms = network.weight_norms()
|
||||
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
|
||||
mean_norm = weight_norms.mean().item() if weight_norms is not None else None
|
||||
grad_norms = network.grad_norms()
|
||||
mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None
|
||||
combined_weight_norms = network.combined_weight_norms()
|
||||
mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None
|
||||
maximum_norm = weight_norms.max().item() if weight_norms is not None else None
|
||||
keys_scaled = None
|
||||
max_mean_logs = {}
|
||||
else:
|
||||
@@ -1465,6 +1468,7 @@ class NetworkTrainer:
|
||||
self.sample_images(
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
|
||||
)
|
||||
progress_bar.unpause()
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
@@ -1678,6 +1682,7 @@ class NetworkTrainer:
|
||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||
|
||||
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
||||
progress_bar.unpause()
|
||||
optimizer_train_fn()
|
||||
|
||||
# end of epoch
|
||||
|
||||
Reference in New Issue
Block a user