mirror of https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-07 05:58:56 +00:00

Compare commits: 10 commits, `feature-ch` ... `new_cache`
Commits (SHA1): `f3a85060ef`, `f2322a23e2`, `70423ec61d`, `28e9352cc5`, `b72b9eaf11`, `744cf03136`, `2238b94e7b`, `665c04e649`, `3677094256`, `bdac55ebbc`
@@ -1,9 +0,0 @@

## About This File

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## 1. Project Context

Here is the essential context for our project. Please read and understand it thoroughly.

### Project Overview

@./context/01-overview.md
@@ -1,101 +0,0 @@

This file provides the overview and guidance for developers working with the codebase, including setup instructions, architecture details, and common commands.

## Project Architecture

### Core Training Framework

The codebase is built around a **strategy pattern architecture** that supports multiple diffusion model families:

- **`library/strategy_base.py`**: Base classes for tokenization, text encoding, latent caching, and training strategies
- **`library/strategy_*.py`**: Model-specific implementations for SD, SDXL, SD3, FLUX, etc.
- **`library/train_util.py`**: Core training utilities shared across all model types
- **`library/config_util.py`**: Configuration management with TOML support

### Model Support Structure

Each supported model family has a consistent structure:

- **Training scripts**: `{model}_train.py` (full fine-tuning) and `{model}_train_network.py` (LoRA/network training)
- **Model utilities**: `library/{model}_models.py`, `library/{model}_train_utils.py`, `library/{model}_utils.py`
- **Networks**: `networks/lora_{model}.py` and `networks/oft_{model}.py` for adapter training

### Supported Models

- **Stable Diffusion 1.x**: `train*.py`, `library/train_util.py`, `train_db.py` (for DreamBooth)
- **SDXL**: `sdxl_train*.py`, `library/sdxl_*`
- **SD3**: `sd3_train*.py`, `library/sd3_*`
- **FLUX.1**: `flux_train*.py`, `library/flux_*`

### Key Components

#### Memory Management

- **Block swapping**: CPU-GPU memory optimization via the `--blocks_to_swap` parameter; works with custom offloading. Only available for models with transformer architectures, such as SD3 and FLUX.1.
- **Custom offloading**: `library/custom_offloading_utils.py` for advanced memory management
- **Gradient checkpointing**: Reduces memory usage during training

#### Training Features

- **LoRA training**: Low-rank adaptation networks in `networks/lora*.py`
- **ControlNet training**: Conditional generation control
- **Textual Inversion**: Custom embedding training
- **Multi-resolution training**: Bucket-based aspect ratio handling
- **Validation loss**: Real-time training monitoring (LoRA training only)

#### Configuration System

Dataset configuration uses TOML files with structured validation:

```toml
[datasets.sample_dataset]
resolution = 1024
batch_size = 2

[[datasets.sample_dataset.subsets]]
image_dir = "path/to/images"
caption_extension = ".txt"
```

## Common Development Commands

### Training Commands Pattern

All training scripts follow this general pattern:

```bash
accelerate launch --mixed_precision bf16 {script_name}.py \
  --pretrained_model_name_or_path model.safetensors \
  --dataset_config config.toml \
  --output_dir output \
  --output_name model_name \
  [model-specific options]
```

### Memory Optimization

For low-VRAM environments, use block swapping:

```bash
# Add to any training command for memory reduction
--blocks_to_swap 10  # Swap 10 blocks to CPU (adjust the number as needed)
```

### Utility Scripts

Located in the `tools/` directory:

- `tools/merge_lora.py`: Merge LoRA weights into base models
- `tools/cache_latents.py`: Pre-cache VAE latents for faster training
- `tools/cache_text_encoder_outputs.py`: Pre-cache text encoder outputs

## Development Notes

### Strategy Pattern Implementation

When adding support for new models, implement the four core strategies:

1. `TokenizeStrategy`: Text tokenization handling
2. `TextEncodingStrategy`: Text encoder forward pass
3. `LatentsCachingStrategy`: VAE encoding/caching
4. `TextEncoderOutputsCachingStrategy`: Text encoder output caching
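
The shape of this pattern can be sketched as follows. This is an illustrative toy, not the actual signatures in `library/strategy_base.py` (the real base classes take tokenizers, devices, and caching parameters):

```python
from abc import ABC, abstractmethod

# Illustrative sketch of the strategy pattern: a per-model subclass
# implements the abstract interface declared by the base strategy.
class TokenizeStrategy(ABC):
    @abstractmethod
    def tokenize(self, text: str) -> list[int]: ...

class MyModelTokenizeStrategy(TokenizeStrategy):
    def tokenize(self, text: str) -> list[int]:
        # Toy whitespace "tokenizer" standing in for a real one
        return [len(word) for word in text.split()]

strategy: TokenizeStrategy = MyModelTokenizeStrategy()
print(strategy.tokenize("a new model"))  # [1, 3, 5]
```

Training code then depends only on the base interface, so adding a model means adding subclasses, not touching the training loop.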

### Testing Approach

- Unit tests focus on utility functions and model loading
- Integration tests validate training script syntax and basic execution
- Most tests use mocks to avoid requiring actual model files
- Add tests for new model support in `tests/test_{model}_*.py`

### Configuration System

- Use the `config_util.py` dataclasses for type-safe configuration
- Support both command-line arguments and TOML file configuration
- Validate configuration early in training scripts to prevent runtime errors
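
The "validate early" advice can be sketched with a hypothetical mini-config (the real dataclasses live in `library/config_util.py` and cover many more fields):

```python
from dataclasses import dataclass

# Hypothetical subset config: validation runs at construction time,
# so a bad value fails immediately instead of mid-training.
@dataclass
class SubsetConfig:
    image_dir: str
    caption_extension: str = ".txt"

    def __post_init__(self) -> None:
        if not self.caption_extension.startswith("."):
            raise ValueError(
                f"caption_extension must start with '.': {self.caption_extension!r}"
            )

cfg = SubsetConfig(image_dir="path/to/images")
assert cfg.caption_extension == ".txt"
```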

### Memory Management

- Always consider VRAM limitations when implementing features
- Use gradient checkpointing for large models
- Implement block swapping for models with transformer architectures
- Cache intermediate results (latents, text embeddings) when possible
@@ -1,9 +0,0 @@

## About This File

This file provides guidance to Gemini CLI (https://github.com/google-gemini/gemini-cli) when working with code in this repository.

## 1. Project Context

Here is the essential context for our project. Please read and understand it thoroughly.

### Project Overview

@./context/01-overview.md
.github/FUNDING.yml (vendored, 3 lines changed)

@@ -1,3 +0,0 @@

# These are supported funding model platforms

github: kohya-ss
.github/workflows/tests.yml (vendored, 5 lines changed)

@@ -12,9 +12,6 @@ on:
      - dev
      - sd3

# CKV2_GHA_1: "Ensure top-level permissions are not set to write-all"
permissions: read-all

jobs:
  build:
    runs-on: ${{ matrix.os }}

@@ -43,7 +40,7 @@ jobs:
    - name: Install dependencies
      run: |
        # Pre-install torch to pin the version (requirements.txt has dependencies such as transformers which require PyTorch)
-       pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4
+       pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision==0.19.0 pytest==8.3.4
        pip install -r requirements.txt

    - name: Test with pytest
.github/workflows/typos.yml (vendored, 3 lines changed)

@@ -12,9 +12,6 @@ on:
      - synchronize
      - reopened

# CKV2_GHA_1: "Ensure top-level permissions are not set to write-all"
permissions: read-all

jobs:
  build:
    runs-on: ubuntu-latest
.gitignore (vendored, 5 lines changed)

@@ -6,8 +6,3 @@ venv
build
.vscode
wandb
CLAUDE.md
GEMINI.md
.claude
.gemini
MagicMock
README.md (101 lines changed)

@@ -9,47 +9,11 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv

The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`

If you are using DeepSpeed, please install it with `pip install deepspeed==0.16.7`.

- [FLUX.1 training](#flux1-training)
- [SD3 training](#sd3-training)

### Recent Updates

Jul 10, 2025:
- The [AI Coding Agents](#for-developers-using-ai-coding-agents) section has been added to the README. It provides instructions for developers using AI coding agents such as Claude and Gemini to understand the project context and coding standards.

May 1, 2025:
- The error when training FLUX.1 with mixed precision in `flux_train.py` with DeepSpeed enabled has been resolved. Thanks to sharlynxy for PR [#2060](https://github.com/kohya-ss/sd-scripts/pull/2060). Please refer to the PR for details.
- If you enable DeepSpeed, please install it with `pip install deepspeed==0.16.7`.

Apr 27, 2025:
- FLUX.1 training now supports CFG scale in sample generation during training. Use the `--g` option to specify the CFG scale (note that `--l` is used as the embedded guidance scale). PR [#2064](https://github.com/kohya-ss/sd-scripts/pull/2064).
- See [here](#sample-image-generation-during-training) for details.
- If you have any issues with this, please let us know.

Apr 6, 2025:
- IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details.
- `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` are available.

Mar 30, 2025:
- LoRA-GGPO has been added for FLUX.1 LoRA training. Thank you to rockerBOO for PR [#1974](https://github.com/kohya-ss/sd-scripts/pull/1974).
- Specify `--network_args ggpo_sigma=0.03 ggpo_beta=0.01` on the command line, or `network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"]` in the .toml file. See the PR for details.
- The interpolation method for resizing the original image to the training size can now be specified. Thank you to rockerBOO for PR [#1936](https://github.com/kohya-ss/sd-scripts/pull/1936).

Mar 20, 2025:
- `pytorch-optimizer` has been added to requirements.txt. Thank you to gesen2egee for PR [#1985](https://github.com/kohya-ss/sd-scripts/pull/1985).
- For example, you can use the CAME optimizer with `--optimizer_type "pytorch_optimizer.CAME" --optimizer_args "weight_decay=0.01"`.

Mar 6, 2025:
- Added a utility script to merge the weights of SD3's DiT, VAE (optional), CLIP-L, CLIP-G, and T5XXL into a single .safetensors file. Run `tools/merge_sd3_safetensors.py`; see `--help` for usage. PR [#1960](https://github.com/kohya-ss/sd-scripts/pull/1960)

Feb 26, 2025:
- Improved the validation loss calculation in `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py`. PR [#1903](https://github.com/kohya-ss/sd-scripts/pull/1903)
- The validation loss uses fixed timestep sampling and a fixed random seed, so that it does not fluctuate due to random values.

Jan 25, 2025:
- `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thank you to rockerBOO!
@@ -57,30 +21,46 @@ Jan 25, 2025:

- It will be added to other scripts as well.
- As a current limitation, validation loss is not supported when `--blocks_to_swap` is specified, or when a schedule-free optimizer is used.

## For Developers Using AI Coding Agents

This repository provides recommended instructions to help AI agents like Claude and Gemini understand our project context and coding standards.

To use them, you need to opt in by creating your own configuration file in the project root.

**Quick Setup:**

1. Create a `CLAUDE.md` and/or `GEMINI.md` file in the project root.
2. Add the following line to your `CLAUDE.md` to import the repository's recommended prompt:

```markdown
@./.ai/claude.prompt.md
```

or for Gemini:

```markdown
@./.ai/gemini.prompt.md
```

3. You can now add your own personal instructions below the import line (e.g., `Always respond in Japanese.`).

This approach ensures that you have full control over the instructions given to your agent while benefiting from the shared project context. Your `CLAUDE.md` and `GEMINI.md` are already listed in `.gitignore`, so they won't be committed to the repository.

Removed changelog entries:

Dec 15, 2024:

- The RAdamScheduleFree optimizer is supported. PR [#1830](https://github.com/kohya-ss/sd-scripts/pull/1830) Thanks to nhamanasu!
- An update to `schedulefree==1.4` is required. Please update it individually or with `pip install --use-pep517 --upgrade -r requirements.txt`.
- Available with `--optimizer_type=RAdamScheduleFree`. There is no need to specify warm-up steps or a learning rate scheduler.

Dec 7, 2024:

- The option to specify the model name during ControlNet training differed between scripts. It has been unified; please specify `--controlnet_model_name_or_path`. PR [#1821](https://github.com/kohya-ss/sd-scripts/pull/1821) Thanks to sdbds!
<!--
Also, the ControlNet training script for SD has been changed from `train_controlnet.py` to `train_control_net.py`.
- `train_controlnet.py` is still available, but it will be removed in the future.
-->
- Fixed an issue where the saved model would be corrupted (pos_embed would not be saved) when `--enable_scaled_pos_embed` was specified in `sd3_train.py`.

Dec 3, 2024:

- `--blocks_to_swap` now works in FLUX.1 ControlNet training. Sample commands for 24GB VRAM and 16GB VRAM are added [here](#flux1-controlnet-training).

Dec 2, 2024:

- FLUX.1 ControlNet training is supported. PR [#1813](https://github.com/kohya-ss/sd-scripts/pull/1813). Thanks to minux302! See the PR and [here](#flux1-controlnet-training) for details.
- Not fully tested. Feedback is welcome.
- 80GB VRAM is required for 1024x1024 resolution, and 48GB VRAM is required for 512x512 resolution.
- Currently, it only works in a Linux environment (or Windows WSL2) because DeepSpeed is required.
- Multi-GPU training is not tested.

Dec 1, 2024:

- Pseudo-Huber loss is now available for FLUX.1 and SD3.5 training. See PR [#1808](https://github.com/kohya-ss/sd-scripts/pull/1808) for details. Thanks to recris!
- Specify `--loss_type huber` or `--loss_type smooth_l1` to use it. `--huber_c` and `--huber_scale` are also available.
- [Prodigy + ScheduleFree](https://github.com/LoganBooker/prodigy-plus-schedule-free) is supported. See PR [#1811](https://github.com/kohya-ss/sd-scripts/pull/1811) for details. Thanks to rockerBOO!

Nov 14, 2024:

- Improved the implementation of block swap and made it available for both FLUX.1 and SD3 LoRA training. See [FLUX.1 LoRA training](#flux1-lora-training) etc. for how to use the new options. Training is possible with about 8-10GB of VRAM.
- During fine-tuning, memory usage when specifying the same number of blocks has increased slightly, but training speed with block swap has improved significantly.
- There may be bugs due to the significant changes. Feedback is welcome.

## FLUX.1 training
@@ -759,8 +739,6 @@ Not available yet.

[__Change History__](#change-history) is moved to the bottom of the page.
更新履歴は[ページ末尾](#change-history)に移しました。

Latest update: 2025-03-21 (Version 0.9.1)

[日本語版READMEはこちら](./README-ja.md)

The development version is in the `dev` branch. Please check the dev branch for the latest changes.
@@ -868,14 +846,6 @@ Note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o

(A single GPU with id `0` will be used.)

## DeepSpeed installation (experimental, Linux or WSL2 only)

To install DeepSpeed, run the following command in your activated virtual environment:

```bash
pip install deepspeed==0.16.7
```

## Upgrade

When a new release comes out, you can upgrade your repo with the following command:
@@ -912,11 +882,6 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser

## Change History

### Mar 21, 2025 / 2025-03-21 Version 0.9.1

- Fixed a bug where some of the LoRA modules for the CLIP Text Encoder were not trained. Thank you Nekotekina for PR [#1964](https://github.com/kohya-ss/sd-scripts/pull/1964)
- The LoRA modules for the CLIP Text Encoder are now 264 modules, the same as before. Only 88 modules were trained in the previous version.

### Jan 17, 2025 / 2025-01-17 Version 0.9.0

- __important__ The dependent libraries have been updated. Please see [Upgrade](#upgrade) and update them.
@@ -1350,13 +1315,11 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b

Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used:

* `--n` Negative prompt up to the next option. Ignored when the CFG scale is `1.0`.
* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
  * In guidance-distilled models like FLUX.1, this value is used as the embedded guidance scale for backward compatibility.
* `--g` Specifies the CFG scale for models with an embedded guidance scale. The default is `1.0`, which means no CFG. In general, this should not be changed unless you train un-distilled FLUX.1 models.
* `--s` Specifies the number of steps in the generation.
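
For example, a prompt-file line combining the options above might look like this (a hypothetical illustration using the sample prompt from this README; the option values are examples, not recommendations):

```
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality --w 768 --h 1344 --d 1 --l 3.5 --s 28
```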

Prompt weighting such as `( )` and `[ ]` works.
@@ -152,7 +152,6 @@ These options are related to subset configuration.

| `keep_tokens_separator` | `“|||”` | o | o | o |
| `secondary_separator` | `“;;;”` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
| `resize_interpolation` | (not specified) | o | o | o |

* `num_repeats`
  * Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method.

@@ -166,8 +165,6 @@ These options are related to subset configuration.

  * Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together.
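
That separator behavior can be sketched as follows (a toy illustration of the rule described above, not the actual caption-processing code in sd-scripts):

```python
import random

# Toy illustration of secondary_separator handling: a ";;;"-joined group is
# treated as one tag for keep/drop, then rejoined with caption_separator.
def process_caption(caption, secondary=";;;", caption_separator=",", drop_rate=0.0):
    tags = caption.split(caption_separator)
    kept = [t for t in tags if random.random() >= drop_rate]  # group kept or dropped whole
    return caption_separator.join(t.replace(secondary, caption_separator) for t in kept)

assert process_caption("aaa;;;bbb;;;ccc") == "aaa,bbb,ccc"   # replaced together
assert process_caption("aaa;;;bbb;;;ccc", drop_rate=1.0) == ""  # dropped together
```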

* `enable_wildcard`
  * Enables wildcard notation. This will be explained later.
* `resize_interpolation`
  * Specifies the interpolation method used when resizing images. Normally, there is no need to specify this. The following options can be specified: `lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box`. By default (when not specified), `area` is used for downscaling and `lanczos` for upscaling. If this option is specified, the same interpolation method is used for both upscaling and downscaling. When `lanczos` or `box` is specified, PIL is used; for the other options, OpenCV is used.
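
The default selection rule described above can be sketched as (an illustrative helper, not the actual implementation in sd-scripts):

```python
# Pick an interpolation method name per the documented defaults:
# "area" for downscaling, "lanczos" for upscaling, or the user's choice
# for both directions when explicitly specified.
def pick_interpolation(src_size, dst_size, requested=None):
    if requested is not None:
        return requested  # same method for both up- and downscaling
    return "area" if dst_size < src_size else "lanczos"

assert pick_interpolation(1024, 512) == "area"       # downscale default
assert pick_interpolation(512, 1024) == "lanczos"    # upscale default
assert pick_interpolation(512, 1024, "bicubic") == "bicubic"
```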

### DreamBooth-specific options
@@ -144,7 +144,6 @@ (Options usable with both the DreamBooth and fine-tuning methods)

| `keep_tokens_separator` | `“|||”` | o | o | o |
| `secondary_separator` | `“;;;”` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
| `resize_interpolation` | (not normally specified) | o | o | o |

* `num_repeats`
  * Specifies the number of repeats for images in the subset. This is equivalent to `--dataset_repeats` in fine-tuning, but `num_repeats` can be specified for any training method.

@@ -163,9 +162,6 @@ (Options usable with both the DreamBooth and fine-tuning methods)

* `enable_wildcard`
  * Enables wildcard notation and multi-line captions. Both are explained later.

* `resize_interpolation`
  * Specifies the interpolation method used when resizing images. Normally there is no need to specify this. `lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, and `box` can be specified. By default (when not specified), `area` is used for downscaling and `lanczos` for upscaling. If this option is specified, the same interpolation method is used for both. When `lanczos` or `box` is specified, PIL is used; otherwise, OpenCV is used.

### DreamBooth-specific options

For the DreamBooth method, only subset-level options exist.
@@ -1,302 +0,0 @@

Status: reviewed

# LoRA Training Guide for Lumina Image 2.0 using `lumina_train_network.py` / `lumina_train_network.py` を用いたLumina Image 2.0モデルのLoRA学習ガイド

This document explains how to train LoRA (Low-Rank Adaptation) models for Lumina Image 2.0 using `lumina_train_network.py` in the `sd-scripts` repository.

## 1. Introduction / はじめに

`lumina_train_network.py` trains additional networks such as LoRA for Lumina Image 2.0 models. Lumina Image 2.0 adopts the Next-DiT (Next-generation Diffusion Transformer) architecture, which differs from previous Stable Diffusion models. It uses a single text encoder (Gemma2) and a dedicated AutoEncoder (AE).

This guide assumes you already understand the basics of LoRA training. For common usage and options, see the `train_network.py` guide (to be documented). Some parameters are similar to those in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md).

**Prerequisites:**

* The `sd-scripts` repository has been cloned and the Python environment is ready.
* A training dataset has been prepared. See the [Dataset Configuration Guide](./config_README-en.md).
* Lumina Image 2.0 model files for training are available.
<details>
<summary>日本語 (Japanese)</summary>

`lumina_train_network.py` is a script for training additional networks such as LoRA on Lumina Image 2.0 models. Lumina Image 2.0 adopts a new architecture called Next-DiT (Next-generation Diffusion Transformer), which differs structurally from previous Stable Diffusion models. It uses Gemma2 alone as the text encoder and a dedicated AutoEncoder (AE).

This guide targets users who already understand the basic LoRA training procedure. For basic usage and common options, see the `train_network.py` guide (in preparation). Some parameters are the same as in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md), so refer to those as well.

**Prerequisites:**

* The `sd-scripts` repository has been cloned and the Python environment set up.
* The training dataset is ready. (For dataset preparation, see the [Dataset Configuration Guide](./config_README-en.md).)
* The Lumina Image 2.0 model files to be trained are available.

</details>
## 2. Differences from `train_network.py` / `train_network.py` との違い

`lumina_train_network.py` is based on `train_network.py` but modified for Lumina Image 2.0. The main differences are:

* **Target models:** Lumina Image 2.0 models.
* **Model structure:** Uses Next-DiT (Transformer-based) instead of U-Net, a single text encoder (Gemma2), and a dedicated AutoEncoder (AE) rather than the SD/SDXL VAE.
* **Arguments:** Options exist to specify the Lumina Image 2.0 model, the Gemma2 text encoder, and the AE; these components are typically provided as separate `.safetensors` files.
* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization`, and `--clip_skip` are not used.
* **Lumina-specific options:** Additional parameters for timestep sampling, model prediction type, discrete flow shift, and the system prompt.
<details>
<summary>日本語 (Japanese)</summary>

`lumina_train_network.py` is based on `train_network.py` with changes to support Lumina Image 2.0 models. The main differences are:

* **Target models:** Lumina Image 2.0 models.
* **Model structure:** Uses Next-DiT (Transformer-based) instead of U-Net, Gemma2 alone as the text encoder, and a dedicated AutoEncoder (AE).
* **Arguments:** There are arguments to specify the Lumina Image 2.0 model, the Gemma2 text encoder, and the AE. These components are usually provided separately.
* **Incompatible arguments:** Arguments for Stable Diffusion v1/v2 (e.g., `--v2`, `--v_parameterization`, `--clip_skip`) are not used for Lumina Image 2.0 training.
* **Lumina-specific arguments:** Arguments for timestep sampling, model prediction type, discrete flow shift, and the system prompt have been added.

</details>
## 3. Preparation / 準備

The following files are required before starting training:

1. **Training script:** `lumina_train_network.py`
2. **Lumina Image 2.0 model file:** The `.safetensors` file for the base model.
3. **Gemma2 text encoder file:** The `.safetensors` file for the text encoder.
4. **AutoEncoder (AE) file:** The `.safetensors` file for the AE.
5. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](./config_README-en.md).) In this document, `my_lumina_dataset_config.toml` is used as an example.

**Model Files:**

* Lumina Image 2.0: `lumina-image-2.safetensors` ([full precision link](https://huggingface.co/rockerBOO/lumina-image-2/blob/main/lumina-image-2.safetensors)) or `lumina_2_model_bf16.safetensors` ([bf16 link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors))
* Gemma2 2B (fp16): `gemma-2-2b.safetensors` ([link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/text_encoders/gemma_2_2b_fp16.safetensors))
* AutoEncoder: `ae.safetensors` ([link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/vae/ae.safetensors)) (same as FLUX)
<details>
<summary>日本語 (Japanese)</summary>

The following files are required before starting training:

1. **Training script:** `lumina_train_network.py`
2. **Lumina Image 2.0 model file:** The `.safetensors` file of the base Lumina Image 2.0 model.
3. **Gemma2 text encoder file:** The `.safetensors` file of the Gemma2 text encoder.
4. **AutoEncoder (AE) file:** The `.safetensors` file of the AE.
5. **Dataset definition file (.toml):** A TOML file describing the training dataset settings. (For details, see the [Dataset Configuration Guide](./config_README-en.md).)
   * `my_lumina_dataset_config.toml` is used as an example.

**Model files** are as listed in the English section.

</details>
## 4. Running the Training / 学習の実行

Execute `lumina_train_network.py` from the terminal to start training. The overall command-line format is the same as for `train_network.py`, but Lumina Image 2.0-specific options must be supplied.

Example command:

```bash
accelerate launch --num_cpu_threads_per_process 1 lumina_train_network.py \
  --pretrained_model_name_or_path="lumina-image-2.safetensors" \
  --gemma2="gemma-2-2b.safetensors" \
  --ae="ae.safetensors" \
  --dataset_config="my_lumina_dataset_config.toml" \
  --output_dir="./output" \
  --output_name="my_lumina_lora" \
  --save_model_as=safetensors \
  --network_module=networks.lora_lumina \
  --network_dim=8 \
  --network_alpha=8 \
  --learning_rate=1e-4 \
  --optimizer_type="AdamW" \
  --lr_scheduler="constant" \
  --timestep_sampling="nextdit_shift" \
  --discrete_flow_shift=6.0 \
  --model_prediction_type="raw" \
  --system_prompt="You are an assistant designed to generate high-quality images based on user prompts." \
  --max_train_epochs=10 \
  --save_every_n_epochs=1 \
  --mixed_precision="bf16" \
  --gradient_checkpointing \
  --cache_latents \
  --cache_text_encoder_outputs
```

*(Write the command on one line, or use `\` (or `^` on Windows) for line breaks.)*
<details>
<summary>日本語 (Japanese)</summary>

Training starts by running `lumina_train_network.py` from the terminal. The basic command-line structure is the same as for `train_network.py`, but Lumina Image 2.0-specific arguments must be specified.

The example command is the same as in the English section above.

*(In practice, write it on one line or use the appropriate line-continuation character (`\` or `^`).)*

</details>

### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説

Besides the arguments explained in the [train_network.py guide](train_network.md), specify the following Lumina Image 2.0 options. For shared options (`--output_dir`, `--output_name`, etc.), see that guide.

#### Model Options / モデル関連

* `--pretrained_model_name_or_path="<path to Lumina model>"` **[required]** – Path to the Lumina Image 2.0 model.
* `--gemma2="<path to Gemma2 model>"` **[required]** – Path to the Gemma2 text encoder `.safetensors` file.
* `--ae="<path to AE model>"` **[required]** – Path to the AutoEncoder `.safetensors` file.

#### Lumina Image 2.0 Training Parameters / Lumina Image 2.0 学習パラメータ

* `--gemma2_max_token_length=<integer>` – Maximum token length for Gemma2. Default: 256.
* `--timestep_sampling=<choice>` – Timestep sampling method. Options: `sigma`, `uniform`, `sigmoid`, `shift`, `nextdit_shift`. Default: `shift`. **Recommended: `nextdit_shift`**
* `--discrete_flow_shift=<float>` – Discrete flow shift for the Euler Discrete Scheduler. Default: `6.0`.
* `--model_prediction_type=<choice>` – Model prediction processing method. Options: `raw`, `additive`, `sigma_scaled`. Default: `raw`. **Recommended: `raw`**
* `--system_prompt=<string>` – System prompt to prepend to all prompts. Recommended: `"You are an assistant designed to generate high-quality images based on user prompts."` or `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`
* `--use_flash_attn` – Use Flash Attention. Requires `pip install flash-attn` (may not be supported in all environments). When installed correctly, it speeds up training.
* `--sigmoid_scale=<float>` – Scale factor for sigmoid timestep sampling. Default: `1.0`.

#### Memory and Speed / メモリ・速度関連

* `--blocks_to_swap=<integer>` **[experimental]** – Swap a number of Transformer blocks between CPU and GPU. More blocks reduce VRAM usage but slow down training. Cannot be used with `--cpu_offload_checkpointing`.
* `--cache_text_encoder_outputs` – Cache Gemma2 outputs to reduce memory usage.
* `--cache_latents`, `--cache_latents_to_disk` – Cache AE outputs.
* `--fp8_base` – Use FP8 precision for the base model.

#### Network Arguments / ネットワーク引数

For Lumina Image 2.0, you can specify different dimensions for the various components:

* `--network_args` can include:
  * `"attn_dim=4"` – Attention dimension
  * `"mlp_dim=4"` – MLP dimension
  * `"mod_dim=4"` – Modulation dimension
  * `"refiner_dim=4"` – Refiner block dimension
  * `"embedder_dims=[4,4,4]"` – Embedder dimensions for the x, t, and caption embedders
|
||||
|
||||
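The `--network_args` values above are plain `key=value` strings. As a rough illustration of how a trainer might turn them into typed Python values, here is a minimal sketch; `parse_network_args` is a hypothetical helper for this document, not a function from the repository:

```python
import ast


def parse_network_args(pairs: list[str]) -> dict:
    """Parse ["attn_dim=4", "embedder_dims=[4,4,4]"] into typed values."""
    parsed = {}
    for pair in pairs:
        key, _, value = pair.partition("=")
        try:
            # literal_eval handles ints and lists such as [4,4,4]
            parsed[key] = ast.literal_eval(value)
        except (ValueError, SyntaxError):
            parsed[key] = value  # fall back to the raw string
    return parsed


print(parse_network_args(["attn_dim=4", "mlp_dim=4", "embedder_dims=[4,4,4]"]))
# → {'attn_dim': 4, 'mlp_dim': 4, 'embedder_dims': [4, 4, 4]}
```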
#### Incompatible or Deprecated Options / 非互換・非推奨の引数

* `--v2`, `--v_parameterization`, `--clip_skip` – Options for Stable Diffusion v1/v2 that are not used for Lumina Image 2.0.
<details>
<summary>日本語</summary>

[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のLumina Image 2.0特有の引数を指定します。共通の引数については、上記ガイドを参照してください。

#### モデル関連

* `--pretrained_model_name_or_path="<path to Lumina model>"` **[必須]** – 学習のベースとなるLumina Image 2.0モデルの`.safetensors`ファイルのパスを指定します。
* `--gemma2="<path to Gemma2 model>"` **[必須]** – Gemma2テキストエンコーダーの`.safetensors`ファイルのパスを指定します。
* `--ae="<path to AE model>"` **[必須]** – AutoEncoderの`.safetensors`ファイルのパスを指定します。
#### Lumina Image 2.0 学習パラメータ

* `--gemma2_max_token_length=<integer>` – Gemma2で使用するトークンの最大長を指定します。デフォルトは256です。
* `--timestep_sampling=<choice>` – タイムステップのサンプリング方法を指定します。`sigma`, `uniform`, `sigmoid`, `shift`, `nextdit_shift`から選択します。デフォルトは`shift`です。**推奨: `nextdit_shift`**
* `--discrete_flow_shift=<float>` – Euler Discrete Schedulerの離散フローシフトを指定します。デフォルトは`6.0`です。
* `--model_prediction_type=<choice>` – モデル予測の処理方法を指定します。`raw`, `additive`, `sigma_scaled`から選択します。デフォルトは`raw`です。**推奨: `raw`**
* `--system_prompt=<string>` – 全てのプロンプトに前置するシステムプロンプトを指定します。推奨: `"You are an assistant designed to generate high-quality images based on user prompts."` または `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`
* `--use_flash_attn` – Flash Attentionを使用します。`pip install flash-attn`でインストールが必要です(環境によってはサポートされていません)。正しくインストールされている場合は、指定すると学習が高速化されます。
* `--sigmoid_scale=<float>` – sigmoidタイムステップサンプリングのスケール係数を指定します。デフォルトは`1.0`です。
#### メモリ・速度関連

* `--blocks_to_swap=<integer>` **[実験的機能]** – TransformerブロックをCPUとGPUでスワップしてVRAMを節約します。`--cpu_offload_checkpointing`とは併用できません。
* `--cache_text_encoder_outputs` – Gemma2の出力をキャッシュしてメモリ使用量を削減します。
* `--cache_latents`, `--cache_latents_to_disk` – AEの出力をキャッシュします。
* `--fp8_base` – ベースモデルにFP8精度を使用します。
#### ネットワーク引数

Lumina Image 2.0では、各コンポーネントに対して異なる次元を指定できます:

* `--network_args` には以下を含めることができます:
  * `"attn_dim=4"` – アテンション次元
  * `"mlp_dim=4"` – MLP次元
  * `"mod_dim=4"` – モジュレーション次元
  * `"refiner_dim=4"` – リファイナーブロック次元
  * `"embedder_dims=[4,4,4]"` – x、t、キャプションエンベッダーのエンベッダー次元
#### 非互換・非推奨の引数

* `--v2`, `--v_parameterization`, `--clip_skip` – Stable Diffusion v1/v2向けの引数のため、Lumina Image 2.0学習では使用されません。

</details>
### 4.2. Starting Training / 学習の開始

After setting the required arguments, run the command to begin training. The overall flow and how to check logs are the same as in the [train_network.py guide](train_network.md#32-starting-the-training--学習の開始).

## 5. Using the Trained Model / 学習済みモデルの利用

When training finishes, a LoRA model file (e.g. `my_lumina_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support Lumina Image 2.0, such as ComfyUI with appropriate nodes.

## 6. Others / その他

`lumina_train_network.py` shares many features with `train_network.py`, such as sample image generation (`--sample_prompts`, etc.) and detailed optimizer settings. For these, see the [train_network.py guide](train_network.md#5-other-features--その他の機能) or run `python lumina_train_network.py --help`.

### 6.1. Recommended Settings / 推奨設定

Based on the contributor's recommendations, the following settings are suggested for optimal training:

**Key Parameters:**

* `--timestep_sampling="nextdit_shift"`
* `--discrete_flow_shift=6.0`
* `--model_prediction_type="raw"`
* `--mixed_precision="bf16"`

**System Prompts:**

* General purpose: `"You are an assistant designed to generate high-quality images based on user prompts."`
* High image-text alignment: `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`

**Sample Prompts:**

Sample prompts can include CFG truncate (`--ctr`) and Renorm CFG (`--rcfg`) parameters:

* `--ctr 0.25 --rcfg 1.0` (default values)
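The recommended flags above can be collected into a launch command programmatically. A minimal sketch, assuming the flag names listed above; the paths and any dataset configuration are intentionally omitted here, and this is illustrative rather than a complete invocation:

```python
# Assemble the recommended Lumina Image 2.0 flags into an argument list.
recommended = {
    "timestep_sampling": "nextdit_shift",
    "discrete_flow_shift": "6.0",
    "model_prediction_type": "raw",
    "mixed_precision": "bf16",
}

cmd = ["python", "lumina_train_network.py"]
cmd += [f"--{key}={value}" for key, value in recommended.items()]
# The system prompt is passed as a single argument, so spaces are safe here.
cmd += ["--system_prompt=You are an assistant designed to generate high-quality images based on user prompts."]

print(" ".join(cmd))
```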
<details>
<summary>日本語</summary>

必要な引数を設定し、コマンドを実行すると学習が開始されます。基本的な流れやログの確認方法は[`train_network.py`のガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。

学習が完了すると、指定した`output_dir`にLoRAモデルファイル(例: `my_lumina_lora.safetensors`)が保存されます。このファイルは、Lumina Image 2.0モデルに対応した推論環境(例: ComfyUI + 適切なノード)で使用できます。

`lumina_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python lumina_train_network.py --help`) を参照してください。

### 6.1. 推奨設定

コントリビューターの推奨に基づく、最適な学習のための推奨設定:

**主要パラメータ:**

* `--timestep_sampling="nextdit_shift"`
* `--discrete_flow_shift=6.0`
* `--model_prediction_type="raw"`
* `--mixed_precision="bf16"`

**システムプロンプト:**

* 汎用目的: `"You are an assistant designed to generate high-quality images based on user prompts."`
* 高い画像-テキスト整合性: `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."`

**サンプルプロンプト:**

サンプルプロンプトには CFG truncate (`--ctr`) と Renorm CFG (`--rcfg`) パラメータを含めることができます:

* `--ctr 0.25 --rcfg 1.0` (デフォルト値)

</details>
@@ -178,7 +178,7 @@ def train(args):
     vae.requires_grad_(False)
     vae.eval()

-    train_dataset_group.new_cache_latents(vae, accelerator)
+    train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)

     vae.to("cpu")
     clean_memory_on_device(accelerator.device)
finetune/caption_images_by_florence2.py (new file, 232 lines)
@@ -0,0 +1,232 @@
# add caption to images by Florence-2

import argparse
import json
import os
import glob
from pathlib import Path
from typing import Any, Optional

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForCausalLM

from library import device_utils, train_util, dataset_metadata_utils
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

import tagger_utils

TASK_PROMPT = "<MORE_DETAILED_CAPTION>"


def main(args):
    assert args.load_archive == (
        args.metadata is not None
    ), "load_archive must be used with metadata / load_archiveはmetadataと一緒に使う必要があります"

    device = args.device if args.device is not None else device_utils.get_preferred_device()
    if type(device) is str:
        device = torch.device(device)
    torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
    logger.info(f"device: {device}, dtype: {torch_dtype}")

    logger.info("Loading Florence-2-large model / Florence-2-largeモデルをロード中")

    support_flash_attn = False
    try:
        import flash_attn

        support_flash_attn = True
    except ImportError:
        pass

    if support_flash_attn:
        model = AutoModelForCausalLM.from_pretrained(
            "microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True
        ).to(device)
    else:
        logger.info(
            "flash_attn is not available. Trying to load without it / flash_attnが利用できません。flash_attnを使わずにロードを試みます"
        )

        # https://github.com/huggingface/transformers/issues/31793#issuecomment-2295797330
        # Removing the unnecessary flash_attn import which causes issues on CPU or MPS backends
        from transformers.dynamic_module_utils import get_imports
        from unittest.mock import patch

        def fixed_get_imports(filename) -> list[str]:
            if not str(filename).endswith("modeling_florence2.py"):
                return get_imports(filename)
            imports = get_imports(filename)
            imports.remove("flash_attn")
            return imports

        # workaround for unnecessary flash_attn requirement
        with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
            model = AutoModelForCausalLM.from_pretrained(
                "microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True
            ).to(device)

    model.eval()
    processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)

    # 画像を読み込む
    if not args.load_archive:
        train_data_dir_path = Path(args.train_data_dir)
        image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
        logger.info(f"found {len(image_paths)} images.")
    else:
        archive_files = glob.glob(os.path.join(args.train_data_dir, "*.zip")) + glob.glob(
            os.path.join(args.train_data_dir, "*.tar")
        )
        image_paths = [Path(archive_file) for archive_file in archive_files]

    # load metadata if needed
    if args.metadata is not None:
        metadata = dataset_metadata_utils.load_metadata(args.metadata, create_new=True)
        images_metadata = metadata["images"]
    else:
        images_metadata = metadata = None

    # define preprocess_image function
    def preprocess_image(image: Image.Image):
        inputs = processor(text=TASK_PROMPT, images=image, return_tensors="pt").to(device, torch_dtype)
        return inputs

    # prepare DataLoader or something similar :)
    # Loader returns: list of (image_path, processed_image_or_something, image_size)
    if args.load_archive:
        loader = tagger_utils.ArchiveImageLoader([str(p) for p in image_paths], args.batch_size, preprocess_image, args.debug)
    else:
        # we cannot use DataLoader with ImageLoadingPrepDataset because processor is not pickleable
        loader = tagger_utils.ImageLoader(image_paths, args.batch_size, preprocess_image, args.debug)

    def run_batch(
        list_of_path_inputs_size: list[tuple[str, dict[str, torch.Tensor], tuple[int, int]]],
        images_metadata: Optional[dict[str, Any]],
        caption_index: Optional[int] = None,
    ):
        input_ids = torch.cat([inputs["input_ids"] for _, inputs, _ in list_of_path_inputs_size])
        pixel_values = torch.cat([inputs["pixel_values"] for _, inputs, _ in list_of_path_inputs_size])

        if args.debug:
            logger.info(f"input_ids: {input_ids.shape}, pixel_values: {pixel_values.shape}")
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                pixel_values=pixel_values,
                max_new_tokens=args.max_new_tokens,
                num_beams=args.num_beams,
            )
        if args.debug:
            logger.info(f"generate done: {generated_ids.shape}")
        generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=False)
        if args.debug:
            logger.info(f"decode done: {len(generated_texts)}")

        for generated_text, (image_path, _, image_size) in zip(generated_texts, list_of_path_inputs_size):
            parsed_answer = processor.post_process_generation(generated_text, task=TASK_PROMPT, image_size=image_size)
            caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"]

            caption_text = caption_text.strip().replace("<pad>", "")
            original_caption_text = caption_text

            if args.remove_mood:
                p = caption_text.find("The overall ")
                if p != -1:
                    caption_text = caption_text[:p].strip()

            caption_file = os.path.splitext(image_path)[0] + args.caption_extension

            if images_metadata is None:
                with open(caption_file, "wt", encoding="utf-8") as f:
                    f.write(caption_text + "\n")
            else:
                image_md = images_metadata.get(image_path, None)
                if image_md is None:
                    image_md = {"image_size": list(image_size)}
                    images_metadata[image_path] = image_md
                if "caption" not in image_md:
                    image_md["caption"] = []
                if caption_index is None:
                    image_md["caption"].append(caption_text)
                else:
                    while len(image_md["caption"]) <= caption_index:
                        image_md["caption"].append("")
                    image_md["caption"][caption_index] = caption_text

            if args.debug:
                logger.info("")
                logger.info(f"{image_path}:")
                logger.info(f"\tCaption: {caption_text}")
                if args.remove_mood and original_caption_text != caption_text:
                    logger.info(f"\tCaption (prior to removing mood): {original_caption_text}")

    for data_entry in tqdm(loader, smoothing=0.0):
        b_imgs = data_entry
        b_imgs = [(str(image_path), image, size) for image_path, image, size in b_imgs]  # Convert image_path to string
        run_batch(b_imgs, images_metadata, args.caption_index)

    if args.metadata is not None:
        logger.info(f"saving metadata file: {args.metadata}")
        with open(args.metadata, "wt", encoding="utf-8") as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)

    logger.info("done!")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
    parser.add_argument(
        "--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子"
    )
    parser.add_argument("--recursive", action="store_true", help="search images recursively / 画像を再帰的に検索する")
    parser.add_argument(
        "--remove_mood", action="store_true", help="remove mood from the caption / キャプションからムードを削除する"
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=1024,
        help="maximum number of tokens to generate. default is 1024 / 生成するトークンの最大数。デフォルトは1024",
    )
    parser.add_argument(
        "--num_beams",
        type=int,
        default=3,
        help="number of beams for beam search. default is 3 / ビームサーチのビーム数。デフォルトは3",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="device for model. default is None, which means using an appropriate device / モデルのデバイス。デフォルトはNoneで、適切なデバイスを使用する",
    )
    parser.add_argument(
        "--caption_index",
        type=int,
        default=None,
        help="index of the caption in the metadata file. default is None, which means adding caption to the existing captions. 0>= to replace the caption"
        " / メタデータファイル内のキャプションのインデックス。デフォルトはNoneで、新しく追加する。0以上でキャプションを置き換える",
    )
    parser.add_argument("--debug", action="store_true", help="debug mode")
    tagger_utils.add_archive_arguments(parser)

    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    main(args)
@@ -180,7 +180,7 @@ def main(args):

         # バッチへ追加
         image_info = train_util.ImageInfo(image_key, 1, "", False, image_path)
-        image_info.latents_npz = npz_file_name
+        image_info.latents_cache_path = npz_file_name
         image_info.bucket_reso = reso
         image_info.resized_size = resized_size
         image_info.image = image
@@ -1,7 +1,10 @@
 import argparse
 import csv
+import glob
+import json
 import os
 from pathlib import Path
+from typing import Any, Optional

 import cv2
 import numpy as np
@@ -10,14 +13,18 @@ from huggingface_hub import hf_hub_download
 from PIL import Image
 from tqdm import tqdm

-import library.train_util as train_util
-from library.utils import setup_logging, resize_image
+from library import dataset_metadata_utils
+from library.utils import setup_logging

 setup_logging()
 import logging

 logger = logging.getLogger(__name__)

+import library.train_util as train_util
+from library.utils import pil_resize
+import tagger_utils

 # from wd14 tagger
 IMAGE_SIZE = 448
@@ -42,7 +49,10 @@ def preprocess_image(image):
     pad_t = pad_y // 2
     image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)

-    image = resize_image(image, image.shape[0], image.shape[1], IMAGE_SIZE, IMAGE_SIZE)
+    if size > IMAGE_SIZE:
+        image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
+    else:
+        image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))

     image = image.astype(np.float32)
     return image
@@ -60,13 +70,14 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):

         try:
             image = Image.open(img_path).convert("RGB")
+            size = image.size
             image = preprocess_image(image)
             # tensor = torch.tensor(image) # これ Tensor に変換する必要ないな……(;・∀・)
         except Exception as e:
             logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
             return None

-        return (image, img_path)
+        return (image, img_path, size)


 def collate_fn_remove_corrupted(batch):
@@ -80,6 +91,10 @@ def collate_fn_remove_corrupted(batch):


 def main(args):
+    assert args.load_archive == (
+        args.metadata is not None
+    ), "load_archive must be used with metadata / load_archiveはmetadataと一緒に使う必要があります"
+
     # model location is model_dir + repo_id
     # repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash
     model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_"))
@@ -97,19 +112,15 @@
         else:
             for file in SUB_DIR_FILES:
                 hf_hub_download(
-                    repo_id=args.repo_id,
-                    filename=file,
+                    args.repo_id,
+                    file,
                     subfolder=SUB_DIR,
-                    local_dir=os.path.join(model_location, SUB_DIR),
+                    cache_dir=os.path.join(model_location, SUB_DIR),
                     force_download=True,
+                    force_filename=file,
                 )
         for file in files:
-            hf_hub_download(
-                repo_id=args.repo_id,
-                filename=file,
-                local_dir=model_location,
-                force_download=True,
-            )
+            hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file)
     else:
         logger.info("using existing wd14 tagger model")
@@ -150,15 +161,19 @@
            ort_sess = ort.InferenceSession(
                onnx_path,
                providers=(["OpenVINOExecutionProvider"]),
-               provider_options=[{'device_type' : "GPU", "precision": "FP32"}],
+               provider_options=[{"device_type": "GPU_FP32"}],
            )
        else:
            ort_sess = ort.InferenceSession(
                onnx_path,
                providers=(
-                   ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
-                   ["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
-                   ["CPUExecutionProvider"]
+                   ["CUDAExecutionProvider"]
+                   if "CUDAExecutionProvider" in ort.get_available_providers()
+                   else (
+                       ["ROCMExecutionProvider"]
+                       if "ROCMExecutionProvider" in ort.get_available_providers()
+                       else ["CPUExecutionProvider"]
+                   )
                ),
            )
    else:
@@ -204,7 +219,9 @@
        tag_replacements = escaped_tag_replacements.split(";")
        for tag_replacement in tag_replacements:
            tags = tag_replacement.split(",")  # source, target
-           assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
+           assert (
+               len(tags) == 2
+           ), f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"

            source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags]
            logger.info(f"replacing tag: {source} -> {target}")
@@ -217,9 +234,15 @@
                rating_tags[rating_tags.index(source)] = target

    # 画像を読み込む
-   train_data_dir_path = Path(args.train_data_dir)
-   image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
-   logger.info(f"found {len(image_paths)} images.")
+   if not args.load_archive:
+       train_data_dir_path = Path(args.train_data_dir)
+       image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
+       logger.info(f"found {len(image_paths)} images.")
+   else:
+       archive_files = glob.glob(os.path.join(args.train_data_dir, "*.zip")) + glob.glob(
+           os.path.join(args.train_data_dir, "*.tar")
+       )
+       image_paths = [Path(archive_file) for archive_file in archive_files]

    tag_freq = {}
@@ -232,19 +255,23 @@
    if args.always_first_tags is not None:
        always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""]

-   def run_batch(path_imgs):
-       imgs = np.array([im for _, im in path_imgs])
+   def run_batch(
+       list_of_path_img_size: list[tuple[str, np.ndarray, tuple[int, int]]],
+       images_metadata: Optional[dict[str, Any]],
+       tags_index: Optional[int] = None,
+   ):
+       imgs = np.array([im for _, im, _ in list_of_path_img_size])

        if args.onnx:
            # if len(imgs) < args.batch_size:
            #     imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
            probs = ort_sess.run(None, {input_name: imgs})[0]  # onnx output numpy
-           probs = probs[: len(path_imgs)]
+           probs = probs[: len(list_of_path_img_size)]
        else:
            probs = model(imgs, training=False)
            probs = probs.numpy()

-       for (image_path, _), prob in zip(path_imgs, probs):
+       for (image_path, _, image_size), prob in zip(list_of_path_img_size, probs):
            combined_tags = []
            rating_tag_text = ""
            character_tag_text = ""
@@ -266,7 +293,7 @@
                if tag_name not in undesired_tags:
                    tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
                    character_tag_text += caption_separator + tag_name
-                   if args.character_tags_first: # insert to the beginning
+                   if args.character_tags_first:  # insert to the beginning
                        combined_tags.insert(0, tag_name)
                    else:
                        combined_tags.append(tag_name)
@@ -282,7 +309,7 @@
                tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
                rating_tag_text = found_rating
                if args.use_rating_tags:
-                   combined_tags.insert(0, found_rating) # insert to the beginning
+                   combined_tags.insert(0, found_rating)  # insert to the beginning
                else:
                    combined_tags.append(found_rating)
@@ -305,12 +332,24 @@
            tag_text = caption_separator.join(combined_tags)

            if args.append_tags:
-               # Check if file exists
-               if os.path.exists(caption_file):
-                   with open(caption_file, "rt", encoding="utf-8") as f:
-                       # Read file and remove new lines
-                       existing_content = f.read().strip("\n")  # Remove newlines
+               existing_content = None
+               if images_metadata is None:
+                   # Check if file exists
+                   if os.path.exists(caption_file):
+                       with open(caption_file, "rt", encoding="utf-8") as f:
+                           # Read file and remove new lines
+                           existing_content = f.read().strip("\n")  # Remove newlines
+               else:
+                   image_md = images_metadata.get(image_path, None)
+                   if image_md is not None:
+                       tags = image_md.get("tags", None)
+                       if tags is not None:
+                           if tags_index is None and len(tags) > 0:
+                               existing_content = tags[-1]
+                           elif tags_index is not None and tags_index < len(tags):
+                               existing_content = tags[tags_index]

+               if existing_content is not None:
                    # Split the content into tags and store them in a list
                    existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()]
@@ -320,19 +359,46 @@
                # Create new tag_text
                tag_text = caption_separator.join(existing_tags + new_tags)

-           with open(caption_file, "wt", encoding="utf-8") as f:
-               f.write(tag_text + "\n")
-               if args.debug:
-                   logger.info("")
-                   logger.info(f"{image_path}:")
-                   logger.info(f"\tRating tags: {rating_tag_text}")
-                   logger.info(f"\tCharacter tags: {character_tag_text}")
-                   logger.info(f"\tGeneral tags: {general_tag_text}")
+           if images_metadata is None:
+               with open(caption_file, "wt", encoding="utf-8") as f:
+                   f.write(tag_text + "\n")
+           else:
+               image_md = images_metadata.get(image_path, None)
+               if image_md is None:
+                   image_md = {"image_size": list(image_size)}
+                   images_metadata[image_path] = image_md
+               if "tags" not in image_md:
+                   image_md["tags"] = []
+               if tags_index is None:
+                   image_md["tags"].append(tag_text)
+               else:
+                   while len(image_md["tags"]) <= tags_index:
+                       image_md["tags"].append("")
+                   image_md["tags"][tags_index] = tag_text

-   # 読み込みの高速化のためにDataLoaderを使うオプション
-   if args.max_data_loader_n_workers is not None:
+           if args.debug:
+               logger.info("")
+               logger.info(f"{image_path}:")
+               logger.info(f"\tRating tags: {rating_tag_text}")
+               logger.info(f"\tCharacter tags: {character_tag_text}")
+               logger.info(f"\tGeneral tags: {general_tag_text}")
+
+   # load metadata if needed
+   if args.metadata is not None:
+       metadata = dataset_metadata_utils.load_metadata(args.metadata, create_new=True)
+       images_metadata = metadata["images"]
+   else:
+       images_metadata = metadata = None
+
+   # prepare DataLoader or something similar :)
+   use_loader = False
+   if args.load_archive:
+       loader = tagger_utils.ArchiveImageLoader([str(p) for p in image_paths], args.batch_size, preprocess_image, args.debug)
+       use_loader = True
+   elif args.max_data_loader_n_workers is not None:
+       # 読み込みの高速化のためにDataLoaderを使うオプション
        dataset = ImageLoadingPrepDataset(image_paths)
-       data = torch.utils.data.DataLoader(
+       loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
@@ -340,35 +406,37 @@
            collate_fn=collate_fn_remove_corrupted,
            drop_last=False,
        )
+       use_loader = True
    else:
-       data = [[(None, ip)] for ip in image_paths]
+       # make batch of image paths
+       loader = []
+       for i in range(0, len(image_paths), args.batch_size):
+           loader.append(image_paths[i : i + args.batch_size])

-   b_imgs = []
-   for data_entry in tqdm(data, smoothing=0.0):
-       for data in data_entry:
-           if data is None:
-               continue
-
-           image, image_path = data
-           if image is None:
-               continue
-           b_imgs.append((image_path, image))
-
-           if len(b_imgs) >= args.batch_size:
-               b_imgs = [(str(image_path), image) for image_path, image in b_imgs]  # Convert image_path to string
-               run_batch(b_imgs)
-               b_imgs.clear()
-
-   if len(b_imgs) > 0:
-       b_imgs = [(str(image_path), image) for image_path, image in b_imgs]  # Convert image_path to string
-       run_batch(b_imgs)
+   for data_entry in tqdm(loader, smoothing=0.0):
+       if use_loader:
+           b_imgs = data_entry
+       else:
+           b_imgs = []
+           for image_path in data_entry:
+               try:
+                   image = Image.open(image_path)
+                   if image.mode != "RGB":
+                       image = image.convert("RGB")
+                   size = image.size
+                   image = preprocess_image(image)
+               except Exception as e:
+                   logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+                   continue
+               b_imgs.append((image_path, image, size))
+
+       b_imgs = [(str(image_path), image, size) for image_path, image, size in b_imgs]  # Convert image_path to string
+       run_batch(b_imgs, images_metadata, args.tags_index)
+
+   if args.metadata is not None:
+       logger.info(f"saving metadata file: {args.metadata}")
+       with open(args.metadata, "wt", encoding="utf-8") as f:
+           json.dump(metadata, f, ensure_ascii=False, indent=2)

    if args.frequency_tags:
        sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
@@ -381,9 +449,7 @@

def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
-   parser.add_argument(
-       "train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
-   )
+   parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
    parser.add_argument(
        "--repo_id",
        type=str,
@@ -401,9 +467,7 @@
        action="store_true",
        help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
    )
-   parser.add_argument(
-       "--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
-   )
+   parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
    parser.add_argument(
        "--max_data_loader_n_workers",
        type=int,
@@ -442,9 +506,7 @@
        action="store_true",
        help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
    )
-   parser.add_argument(
-       "--debug", action="store_true", help="debug mode"
-   )
+   parser.add_argument("--debug", action="store_true", help="debug mode")
    parser.add_argument(
        "--undesired_tags",
        type=str,
@@ -454,20 +516,24 @@
    parser.add_argument(
        "--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する"
    )
-   parser.add_argument(
-       "--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
-   )
+   parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
    parser.add_argument(
        "--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
    )
    parser.add_argument(
-       "--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
+       "--use_rating_tags",
+       action="store_true",
+       help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
    )
    parser.add_argument(
-       "--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
+       "--use_rating_tags_as_last_tag",
+       action="store_true",
+       help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
    )
    parser.add_argument(
-       "--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
+       "--character_tags_first",
+       action="store_true",
+       help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
    )
    parser.add_argument(
        "--always_first_tags",
@@ -496,6 +562,15 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
+ " / キャラクタタグの末尾の括弧を別のタグに展開する。`chara_name_(series)` は `chara_name, series` になる",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tags_index",
|
||||
type=int,
|
||||
default=None,
|
||||
help="index of the tags in the metadata file. default is None, which means adding tags to the existing tags. 0>= to replace the tags"
|
||||
" / メタデータファイル内のタグのインデックス。デフォルトはNoneで、既存のタグにタグを追加する。0以上でタグを置き換える",
|
||||
)
|
||||
tagger_utils.add_archive_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
finetune/tagger_utils.py (new file, 150 lines)
@@ -0,0 +1,150 @@
import argparse
import json
import math
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Union
import zipfile
import tarfile

from PIL import Image

from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

from library import dataset_metadata_utils, train_util


class ArchiveImageLoader:
    def __init__(self, archive_paths: list[str], batch_size: int, preprocess: Callable, debug: bool = False):
        self.archive_paths = archive_paths
        self.batch_size = batch_size
        self.preprocess = preprocess
        self.debug = debug
        self.current_archive = None
        self.archive_index = 0
        self.image_index = 0
        self.files = None
        self.executor = ThreadPoolExecutor()
        self.image_exts = set(train_util.IMAGE_EXTENSIONS)

    def __iter__(self):
        return self

    def __next__(self):
        images = []
        while len(images) < self.batch_size:
            if self.current_archive is None:
                if self.archive_index >= len(self.archive_paths):
                    if len(images) == 0:
                        raise StopIteration
                    else:
                        break  # return the remaining images

                if self.debug:
                    logger.info(f"loading archive: {self.archive_paths[self.archive_index]}")

                current_archive_path = self.archive_paths[self.archive_index]
                if current_archive_path.endswith(".zip"):
                    self.current_archive = zipfile.ZipFile(current_archive_path)
                    self.files = self.current_archive.namelist()
                elif current_archive_path.endswith(".tar"):
                    self.current_archive = tarfile.open(current_archive_path, "r")
                    self.files = self.current_archive.getnames()
                else:
                    raise ValueError(f"unsupported archive file: {current_archive_path}")

                self.image_index = 0

                # filter by image extensions
                self.files = [file for file in self.files if os.path.splitext(file)[1].lower() in self.image_exts]

                if self.debug:
                    logger.info(f"found {len(self.files)} images in the archive")

            new_images = []
            while len(images) + len(new_images) < self.batch_size:
                if self.image_index >= len(self.files):
                    break

                file = self.files[self.image_index]
                archive_and_image_path = (
                    f"{self.archive_paths[self.archive_index]}{dataset_metadata_utils.ARCHIVE_PATH_SEPARATOR}{file}"
                )
                self.image_index += 1

                def load_image(file, archive: Union[zipfile.ZipFile, tarfile.TarFile]):
                    # ZipFile exposes open(), but TarFile reads members via extractfile()
                    open_fn = archive.extractfile if isinstance(archive, tarfile.TarFile) else archive.open
                    with open_fn(file) as f:
                        image = Image.open(f).convert("RGB")
                        size = image.size
                        image = self.preprocess(image)
                    return image, size

                new_images.append((archive_and_image_path, self.executor.submit(load_image, file, self.current_archive)))

            # wait for all new_images to load to close the archive
            new_images = [(image_path, future.result()) for image_path, future in new_images]

            if self.image_index >= len(self.files):
                self.current_archive.close()
                self.current_archive = None
                self.archive_index += 1

            images.extend(new_images)

        return [(image_path, image, size) for image_path, (image, size) in images]


class ImageLoader:
    def __init__(self, image_paths: list[str], batch_size: int, preprocess: Callable, debug: bool = False):
        self.image_paths = image_paths
        self.batch_size = batch_size
        self.preprocess = preprocess
        self.debug = debug
        self.image_index = 0
        self.executor = ThreadPoolExecutor()

    def __len__(self):
        return math.ceil(len(self.image_paths) / self.batch_size)

    def __iter__(self):
        return self

    def __next__(self):
        if self.image_index >= len(self.image_paths):
            raise StopIteration

        images = []
        while len(images) < self.batch_size and self.image_index < len(self.image_paths):

            def load_image(file):
                image = Image.open(file).convert("RGB")
                size = image.size
                image = self.preprocess(image)
                return image, size

            image_path = self.image_paths[self.image_index]
            images.append((image_path, self.executor.submit(load_image, image_path)))
            self.image_index += 1

        images = [(image_path, future.result()) for image_path, future in images]
        return [(image_path, image, size) for image_path, (image, size) in images]


def add_archive_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--metadata",
        type=str,
        default=None,
        help="metadata file for the dataset. write tags to this file instead of the caption file / データセットのメタデータファイル。キャプションファイルの代わりにこのファイルにタグを書き込む",
    )
    parser.add_argument(
        "--load_archive",
        action="store_true",
        help="load archive file such as .zip instead of image files. currently .zip and .tar are supported. must be used with --metadata"
        " / 画像ファイルではなく.zipなどのアーカイブファイルを読み込む。現在.zipと.tarをサポート。--metadataと一緒に使う必要があります",
    )
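The two loader classes above share one pattern: submit each (potentially slow) decode to a `ThreadPoolExecutor`, collect futures into fixed-size batches, and resolve the whole batch before returning it. A minimal, self-contained sketch of that pattern follows; `BatchedLoader` and `fake_load` are illustrative names, not part of the repository, and `fake_load` stands in for the PIL decode plus preprocess step.

```python
import math
from concurrent.futures import ThreadPoolExecutor


class BatchedLoader:
    def __init__(self, items, batch_size):
        self.items = items
        self.batch_size = batch_size
        self.index = 0
        self.executor = ThreadPoolExecutor()

    def __len__(self):
        return math.ceil(len(self.items) / self.batch_size)

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= len(self.items):
            raise StopIteration
        batch = []
        while len(batch) < self.batch_size and self.index < len(self.items):
            item = self.items[self.index]
            # submit the slow load; the future resolves below
            batch.append((item, self.executor.submit(self.fake_load, item)))
            self.index += 1
        # block until every load in this batch is done, then return it
        return [(name, future.result()) for name, future in batch]

    @staticmethod
    def fake_load(item):
        # placeholder for Image.open(...).convert("RGB") + preprocess
        return item.upper()


loader = BatchedLoader(["a", "b", "c", "d", "e"], batch_size=2)
batches = list(loader)
print(batches)  # three batches of sizes 2, 2, 1
```

The last batch is allowed to be short, matching `ImageLoader.__next__`, which stops filling a batch when the path list runs out.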
@@ -152,15 +152,20 @@ def train(args):

     _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
     if args.debug_dataset:
-        if args.cache_text_encoder_outputs:
-            strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
-                strategy_flux.FluxTextEncoderOutputsCachingStrategy(
-                    args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
-                )
-            )
+        t5xxl_max_token_length = (
+            args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
+        )
+        if args.cache_text_encoder_outputs:
+            strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
+                strategy_flux.FluxTextEncoderOutputsCachingStrategy(
+                    args.cache_text_encoder_outputs_to_disk,
+                    args.text_encoder_batch_size,
+                    args.skip_cache_check,
+                    t5xxl_max_token_length,
+                    args.apply_t5_attn_mask,
+                    False,
+                )
+            )
         strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))

         train_dataset_group.set_current_strategies()
@@ -199,7 +204,7 @@ def train(args):
         ae.requires_grad_(False)
         ae.eval()

-        train_dataset_group.new_cache_latents(ae, accelerator)
+        train_dataset_group.new_cache_latents(ae, accelerator, args.force_cache_precision)

         ae.to("cpu")  # if no sampling, vae can be deleted
         clean_memory_on_device(accelerator.device)
@@ -237,7 +242,12 @@ def train(args):
         t5xxl.to(accelerator.device)

         text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
-            args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask
+            args.cache_text_encoder_outputs_to_disk,
+            args.text_encoder_batch_size,
+            args.skip_cache_check,
+            t5xxl_max_token_length,
+            args.apply_t5_attn_mask,
+            False,
         )
         strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
@@ -11,16 +11,6 @@ from library.device_utils import clean_memory_on_device, init_ipex

 init_ipex()

-import train_network
-from library import (
-    flux_models,
-    flux_train_utils,
-    flux_utils,
-    sd3_train_utils,
-    strategy_base,
-    strategy_flux,
-    train_util,
-)
 from library.utils import setup_logging

 setup_logging()
@@ -28,6 +18,9 @@ import logging

 logger = logging.getLogger(__name__)

+from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
+import train_network
+

 class FluxNetworkTrainer(train_network.NetworkTrainer):
     def __init__(self):
@@ -36,12 +29,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         self.is_schnell: Optional[bool] = None
         self.is_swapping_blocks: bool = False

-    def assert_extra_args(
-        self,
-        args,
-        train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
-        val_dataset_group: Optional[train_util.DatasetGroup],
-    ):
+    def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
         super().assert_extra_args(args, train_dataset_group, val_dataset_group)
         # sdxl_train_util.verify_sdxl_training_args(args)
@@ -190,13 +178,17 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):

     def get_text_encoder_outputs_caching_strategy(self, args):
         if args.cache_text_encoder_outputs:
+            fluxTokenizeStrategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
+            t5xxl_max_token_length = fluxTokenizeStrategy.t5xxl_max_length
+
             # if the text encoders is trained, we need tokenization, so is_partial is True
             return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
                 args.cache_text_encoder_outputs_to_disk,
                 args.text_encoder_batch_size,
+                args.skip_cache_check,
+                t5xxl_max_token_length,
+                args.apply_t5_attn_mask,
                 is_partial=self.train_clip_l or self.train_t5xxl,
-                apply_t5_attn_mask=args.apply_t5_attn_mask,
             )
         else:
             return None
@@ -328,7 +320,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
         return noise_scheduler

-    def encode_images_to_latents(self, args, vae, images):
+    def encode_images_to_latents(self, args, accelerator, vae, images):
         return vae.encode(images)

     def shift_scale_latents(self, args, latents):
@@ -346,7 +338,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         network,
         weight_dtype,
         train_unet,
-        is_train=True,
+        is_train=True
     ):
         # Sample noise that we'll add to the latents
         noise = torch.randn_like(latents)
@@ -381,7 +373,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
             t5_attn_mask = None

         def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
             # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
+            # if not args.split_mode:
             # normal forward
             with torch.set_grad_enabled(is_train), accelerator.autocast():
                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                 model_pred = unet(
@@ -394,6 +387,44 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
                     guidance=guidance_vec,
                     txt_attention_mask=t5_attn_mask,
                 )
+            """
+            else:
+                # split forward to reduce memory usage
+                assert network.train_blocks == "single", "train_blocks must be single for split mode"
+                with accelerator.autocast():
+                    # move flux lower to cpu, and then move flux upper to gpu
+                    unet.to("cpu")
+                    clean_memory_on_device(accelerator.device)
+                    self.flux_upper.to(accelerator.device)
+
+                    # upper model does not require grad
+                    with torch.no_grad():
+                        intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
+                            img=packed_noisy_model_input,
+                            img_ids=img_ids,
+                            txt=t5_out,
+                            txt_ids=txt_ids,
+                            y=l_pooled,
+                            timesteps=timesteps / 1000,
+                            guidance=guidance_vec,
+                            txt_attention_mask=t5_attn_mask,
+                        )
+
+                    # move flux upper back to cpu, and then move flux lower to gpu
+                    self.flux_upper.to("cpu")
+                    clean_memory_on_device(accelerator.device)
+                    unet.to(accelerator.device)
+
+                    # lower model requires grad
+                    intermediate_img.requires_grad_(True)
+                    intermediate_txt.requires_grad_(True)
+                    vec.requires_grad_(True)
+                    pe.requires_grad_(True)
+
+                    with torch.set_grad_enabled(is_train and train_unet):
+                        model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
+            """

             return model_pred

         model_pred = call_dit(
@@ -512,11 +543,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
             text_encoder.to(te_weight_dtype)  # fp8
             prepare_fp8(text_encoder, weight_dtype)

-    def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
-        if self.is_swapping_blocks:
-            # prepare for next forward: because backward pass is not called, we need to prepare it here
-            accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
-
     def prepare_unet_with_accelerator(
         self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
     ) -> torch.nn.Module:
@@ -1,614 +0,0 @@
# copy from the official repo: https://github.com/lodestone-rock/flow/blob/master/src/models/chroma/model.py
# and modified
# licensed under Apache License 2.0

import math
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn
import torch.nn.functional as F
import torch.utils.checkpoint as ckpt

from .flux_models import (
    attention,
    rope,
    apply_rope,
    EmbedND,
    timestep_embedding,
    MLPEmbedder,
    RMSNorm,
    QKNorm,
)


def distribute_modulations(tensor: torch.Tensor, depth_single_blocks, depth_double_blocks):
    """
    Distributes slices of the tensor into the block_dict as ModulationOut objects.

    Args:
        tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim].
    """
    batch_size, vectors, dim = tensor.shape

    block_dict = {}

    # HARD CODED VALUES! lookup table for the generated vectors
    # TODO: move this into chroma config!
    # Add 38 single mod blocks
    for i in range(depth_single_blocks):
        key = f"single_blocks.{i}.modulation.lin"
        block_dict[key] = None

    # Add 19 image double blocks
    for i in range(depth_double_blocks):
        key = f"double_blocks.{i}.img_mod.lin"
        block_dict[key] = None

    # Add 19 text double blocks
    for i in range(depth_double_blocks):
        key = f"double_blocks.{i}.txt_mod.lin"
        block_dict[key] = None

    # Add the final layer
    block_dict["final_layer.adaLN_modulation.1"] = None
    # 6.2b version
    # block_dict["lite_double_blocks.4.img_mod.lin"] = None
    # block_dict["lite_double_blocks.4.txt_mod.lin"] = None

    idx = 0  # Index to keep track of the vector slices

    for key in block_dict.keys():
        if "single_blocks" in key:
            # Single block: 1 ModulationOut
            block_dict[key] = ModulationOut(
                shift=tensor[:, idx : idx + 1, :],
                scale=tensor[:, idx + 1 : idx + 2, :],
                gate=tensor[:, idx + 2 : idx + 3, :],
            )
            idx += 3  # Advance by 3 vectors

        elif "img_mod" in key:
            # Double block: List of 2 ModulationOut
            double_block = []
            for _ in range(2):  # Create 2 ModulationOut objects
                double_block.append(
                    ModulationOut(
                        shift=tensor[:, idx : idx + 1, :],
                        scale=tensor[:, idx + 1 : idx + 2, :],
                        gate=tensor[:, idx + 2 : idx + 3, :],
                    )
                )
                idx += 3  # Advance by 3 vectors per ModulationOut
            block_dict[key] = double_block

        elif "txt_mod" in key:
            # Double block: List of 2 ModulationOut
            double_block = []
            for _ in range(2):  # Create 2 ModulationOut objects
                double_block.append(
                    ModulationOut(
                        shift=tensor[:, idx : idx + 1, :],
                        scale=tensor[:, idx + 1 : idx + 2, :],
                        gate=tensor[:, idx + 2 : idx + 3, :],
                    )
                )
                idx += 3  # Advance by 3 vectors per ModulationOut
            block_dict[key] = double_block

        elif "final_layer" in key:
            # Final layer: 1 ModulationOut
            block_dict[key] = [
                tensor[:, idx : idx + 1, :],
                tensor[:, idx + 1 : idx + 2, :],
            ]
            idx += 2  # Advance by 2 vectors

    return block_dict


class Approximator(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=4):
        super().__init__()
        self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
        self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim) for x in range(n_layers)])
        self.norms = nn.ModuleList([RMSNorm(hidden_dim) for x in range(n_layers)])
        self.out_proj = nn.Linear(hidden_dim, out_dim)

    @property
    def device(self):
        # Get the device of the module (assumes all parameters are on the same device)
        return next(self.parameters()).device

    def forward(self, x: Tensor) -> Tensor:
        x = self.in_proj(x)

        for layer, norms in zip(self.layers, self.norms):
            x = x + layer(norms(x))

        x = self.out_proj(x)

        return x


class SelfAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x


@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


def _modulation_shift_scale_fn(x, scale, shift):
    return (1 + scale) * x + shift


def _modulation_gate_fn(x, gate, gate_params):
    return x + gate * gate_params


class DoubleStreamBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool = False,
    ):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(
            dim=hidden_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    @property
    def device(self):
        # Get the device of the module (assumes all parameters are on the same device)
        return next(self.parameters()).device

    def modulation_shift_scale_fn(self, x, scale, shift):
        return _modulation_shift_scale_fn(x, scale, shift)

    def modulation_gate_fn(self, x, gate, gate_params):
        return _modulation_gate_fn(x, gate, gate_params)

    def forward(
        self,
        img: Tensor,
        txt: Tensor,
        pe: Tensor,
        distill_vec: list[ModulationOut],
        mask: Tensor,
    ) -> tuple[Tensor, Tensor]:
        (img_mod1, img_mod2), (txt_mod1, txt_mod2) = distill_vec

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        # replaced with compiled fn
        # img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_modulated = self.modulation_shift_scale_fn(img_modulated, img_mod1.scale, img_mod1.shift)
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        # replaced with compiled fn
        # txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_modulated = self.modulation_shift_scale_fn(txt_modulated, txt_mod1.scale, txt_mod1.shift)
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe, mask=mask)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        # replaced with compiled fn
        # img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        # img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
        img = self.modulation_gate_fn(img, img_mod1.gate, self.img_attn.proj(img_attn))
        img = self.modulation_gate_fn(
            img,
            img_mod2.gate,
            self.img_mlp(self.modulation_shift_scale_fn(self.img_norm2(img), img_mod2.scale, img_mod2.shift)),
        )

        # calculate the txt blocks
        # replaced with compiled fn
        # txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        # txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        txt = self.modulation_gate_fn(txt, txt_mod1.gate, self.txt_attn.proj(txt_attn))
        txt = self.modulation_gate_fn(
            txt,
            txt_mod2.gate,
            self.txt_mlp(self.modulation_shift_scale_fn(self.txt_norm2(txt), txt_mod2.scale, txt_mod2.shift)),
        )

        return img, txt


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")

    @property
    def device(self):
        # Get the device of the module (assumes all parameters are on the same device)
        return next(self.parameters()).device

    def modulation_shift_scale_fn(self, x, scale, shift):
        return _modulation_shift_scale_fn(x, scale, shift)

    def modulation_gate_fn(self, x, gate, gate_params):
        return _modulation_gate_fn(x, gate, gate_params)

    def forward(self, x: Tensor, pe: Tensor, distill_vec: list[ModulationOut], mask: Tensor) -> Tensor:
        mod = distill_vec
        # replaced with compiled fn
        # x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        x_mod = self.modulation_shift_scale_fn(self.pre_norm(x), mod.scale, mod.shift)
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe, mask=mask)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        # replaced with compiled fn
        # return x + mod.gate * output
        return self.modulation_gate_fn(x, mod.gate, output)


class LastLayer(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        patch_size: int,
        out_channels: int,
    ):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)

    @property
    def device(self):
        # Get the device of the module (assumes all parameters are on the same device)
        return next(self.parameters()).device

    def modulation_shift_scale_fn(self, x, scale, shift):
        return _modulation_shift_scale_fn(x, scale, shift)

    def forward(self, x: Tensor, distill_vec: list[Tensor]) -> Tensor:
        shift, scale = distill_vec
        shift = shift.squeeze(1)
        scale = scale.squeeze(1)
        # replaced with compiled fn
        # x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.modulation_shift_scale_fn(self.norm_final(x), scale[:, None, :], shift[:, None, :])
        x = self.linear(x)
        return x


@dataclass
class ChromaParams:
    in_channels: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool
    approximator_in_dim: int
    approximator_depth: int
    approximator_hidden_size: int
    _use_compiled: bool


chroma_params = ChromaParams(
    in_channels=64,
    context_in_dim=4096,
    hidden_size=3072,
    mlp_ratio=4.0,
    num_heads=24,
    depth=19,
    depth_single_blocks=38,
    axes_dim=[16, 56, 56],
    theta=10_000,
    qkv_bias=True,
    guidance_embed=True,
    approximator_in_dim=64,
    approximator_depth=5,
    approximator_hidden_size=5120,
    _use_compiled=False,
)


def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8):
    """
    Modifies attention mask to allow attention to a few extra padding tokens.

    Args:
        mask: Original attention mask (1 for tokens to attend to, 0 for masked tokens)
        max_seq_length: Maximum sequence length of the model
        num_extra_padding: Number of padding tokens to unmask

    Returns:
        Modified mask
    """
    # Get the actual sequence length from the mask
    seq_length = mask.sum(dim=-1)
    batch_size = mask.shape[0]

    modified_mask = mask.clone()

    for i in range(batch_size):
        current_seq_len = int(seq_length[i].item())

        # Only add extra padding tokens if there's room
        if current_seq_len < max_seq_length:
            # Calculate how many padding tokens we can unmask
            available_padding = max_seq_length - current_seq_len
            tokens_to_unmask = min(num_extra_padding, available_padding)

            # Unmask the specified number of padding tokens right after the sequence
            modified_mask[i, current_seq_len : current_seq_len + tokens_to_unmask] = 1

    return modified_mask


class Chroma(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, params: ChromaParams):
        super().__init__()
        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)

        # TODO: need proper mapping for this approximator output!
        # currently the mapping is hardcoded in distribute_modulations function
        self.distilled_guidance_layer = Approximator(
            params.approximator_in_dim,
            self.hidden_size,
            params.approximator_hidden_size,
            params.approximator_depth,
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                )
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(
            self.hidden_size,
            1,
            self.out_channels,
        )

        # TODO: move this hardcoded value to config
        # single layer has 3 modulation vectors
        # double layer has 6 modulation vectors for each expert
        # final layer has 2 modulation vectors
        self.mod_index_length = 3 * params.depth_single_blocks + 2 * 6 * params.depth + 2
        self.depth_single_blocks = params.depth_single_blocks
        self.depth_double_blocks = params.depth
        # self.mod_index = torch.tensor(list(range(self.mod_index_length)), device=0)
        self.register_buffer(
            "mod_index",
            torch.tensor(list(range(self.mod_index_length)), device="cpu"),
            persistent=False,
        )
        self.approximator_in_dim = params.approximator_in_dim

    @property
    def device(self):
        # Get the device of the module (assumes all parameters are on the same device)
        return next(self.parameters()).device

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        txt_mask: Tensor,
        timesteps: Tensor,
        guidance: Tensor,
        attn_padding: int = 1,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        txt = self.txt_in(txt)

        # TODO:
        # need to fix grad accumulation issue here for now it's in no grad mode
        # besides, i don't want to wash out the PFP that's trained on this model weights anyway
        # the fan out operation here is deleting the backward graph
        # alternatively doing forward pass for every block manually is doable but slow
        # custom backward probably be better
|
||||
with torch.no_grad():
|
||||
distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4)
|
||||
# TODO: need to add toggle to omit this from schnell but that's not a priority
|
||||
distil_guidance = timestep_embedding(guidance, self.approximator_in_dim // 4)
|
||||
# get all modulation index
|
||||
modulation_index = timestep_embedding(self.mod_index, self.approximator_in_dim // 2)
|
||||
# we need to broadcast the modulation index here so each batch has all of the index
|
||||
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1)
|
||||
# and we need to broadcast timestep and guidance along too
|
||||
timestep_guidance = (
|
||||
torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, self.mod_index_length, 1)
|
||||
)
|
||||
# then and only then we could concatenate it together
|
||||
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)
|
||||
mod_vectors = self.distilled_guidance_layer(input_vec.requires_grad_(True))
|
||||
mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
# compute mask
|
||||
# assume max seq length from the batched input
|
||||
|
||||
max_len = txt.shape[1]
|
||||
|
||||
# mask
|
||||
with torch.no_grad():
|
||||
txt_mask_w_padding = modify_mask_to_attend_padding(txt_mask, max_len, attn_padding)
|
||||
txt_img_mask = torch.cat(
|
||||
[
|
||||
txt_mask_w_padding,
|
||||
torch.ones([img.shape[0], img.shape[1]], device=txt_mask.device),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
txt_img_mask = txt_img_mask.float().T @ txt_img_mask.float()
|
||||
txt_img_mask = txt_img_mask[None, None, ...].repeat(txt.shape[0], self.num_heads, 1, 1).int().bool()
|
||||
# txt_mask_w_padding[txt_mask_w_padding==False] = True
|
||||
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
# the guidance replaced by FFN output
|
||||
img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"]
|
||||
txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"]
|
||||
double_mod = [img_mod, txt_mod]
|
||||
|
||||
# just in case in different GPU for simple pipeline parallel
|
||||
if self.training:
|
||||
img, txt = ckpt.checkpoint(block, img, txt, pe, double_mod, txt_img_mask)
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, pe=pe, distill_vec=double_mod, mask=txt_img_mask)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
|
||||
if self.training:
|
||||
img = ckpt.checkpoint(block, img, pe, single_mod, txt_img_mask)
|
||||
else:
|
||||
img = block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask)
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"]
|
||||
img = self.final_layer(img, distill_vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
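The truncated `modify_mask_to_attend_padding` helper whose loop tail opens this section can be read as the following standalone sketch. The function head and the per-row `current_seq_len` computation are assumptions reconstructed from the surviving body; the sketch assumes each mask row is 1s for real tokens followed by 0s for padding, which is how `Chroma.forward` uses it for `txt_mask`:

```python
import torch

def modify_mask_to_attend_padding(mask: torch.Tensor, max_seq_length: int, num_extra_padding: int) -> torch.Tensor:
    # For each row, unmask up to num_extra_padding padding tokens right after the real tokens.
    modified_mask = mask.clone()
    for i in range(mask.shape[0]):
        current_seq_len = int(mask[i].sum().item())
        available_padding = max_seq_length - current_seq_len
        tokens_to_unmask = min(num_extra_padding, available_padding)
        # Unmask the specified number of padding tokens right after the sequence
        modified_mask[i, current_seq_len : current_seq_len + tokens_to_unmask] = 1
    return modified_mask

mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1]])
print(modify_mask_to_attend_padding(mask, 6, 2).tolist())  # [[1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1]]
```

Note that the second row, with no padding left, is returned unchanged, and the input mask itself is never mutated.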
@@ -75,7 +75,6 @@ class BaseSubsetParams:
    custom_attributes: Optional[Dict[str, Any]] = None
    validation_seed: int = 0
    validation_split: float = 0.0
    resize_interpolation: Optional[str] = None


@dataclass
@@ -107,7 +106,7 @@ class BaseDatasetParams:
    debug_dataset: bool = False
    validation_seed: Optional[int] = None
    validation_split: float = 0.0
    resize_interpolation: Optional[str] = None


@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
@@ -197,7 +196,6 @@ class ConfigSanitizer:
        "caption_prefix": str,
        "caption_suffix": str,
        "custom_attributes": dict,
        "resize_interpolation": str,
    }
    # DO means DropOut
    DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -243,7 +241,6 @@ class ConfigSanitizer:
        "validation_split": float,
        "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
        "network_multiplier": float,
        "resize_interpolation": str,
    }

    # options handled by argparse but not handled by user config
@@ -528,7 +525,6 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
      [{dataset_type} {i}]
      batch_size: {dataset.batch_size}
      resolution: {(dataset.width, dataset.height)}
      resize_interpolation: {dataset.resize_interpolation}
      enable_bucket: {dataset.enable_bucket}
    """)

@@ -562,7 +558,6 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
        token_warmup_min: {subset.token_warmup_min},
        token_warmup_step: {subset.token_warmup_step},
        alpha_mask: {subset.alpha_mask}
        resize_interpolation: {subset.resize_interpolation}
        custom_attributes: {subset.custom_attributes}
      """), " ")
@@ -1,6 +1,6 @@
from concurrent.futures import ThreadPoolExecutor
import time
from typing import Optional, Union, Callable, Tuple
from typing import Optional
import torch
import torch.nn as nn

@@ -19,7 +19,7 @@ def synchronize_device(device: torch.device):
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
    weight_swap_jobs = []

    # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
    # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
@@ -42,7 +42,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye

    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value

    stream = torch.Stream(device="cuda")
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        # cuda to cpu
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
@@ -66,24 +66,23 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
    weight_swap_jobs = []
    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

    # device to cpu
    for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
        module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)

    synchronize_device(device)
    synchronize_device()

    # cpu to device
    for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
        cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
        module_to_cuda.weight.data = cuda_data_view

    synchronize_device(device)
    synchronize_device()


def weighs_to_device(layer: nn.Module, device: torch.device):
@@ -149,16 +148,13 @@ class Offloader:
            print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")


# Gradient tensors
_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]

class ModelOffloader(Offloader):
    """
    supports forward offloading
    """

    def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False):
        super().__init__(len(blocks), blocks_to_swap, device, debug)
    def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
        super().__init__(num_blocks, blocks_to_swap, device, debug)

        # register backward hooks
        self.remove_handles = []
@@ -172,7 +168,7 @@ class ModelOffloader(Offloader):
        for handle in self.remove_handles:
            handle.remove()

    def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
    def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
        # -1 for 0-based index
        num_blocks_propagated = self.num_blocks - block_index - 1
        swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
@@ -186,7 +182,7 @@ class ModelOffloader(Offloader):
        block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
        block_idx_to_wait = block_index - 1

        def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t):
        def backward_hook(module, grad_input, grad_output):
            if self.debug:
                print(f"Backward hook for block {block_index}")

@@ -198,7 +194,7 @@ class ModelOffloader(Offloader):

        return backward_hook

    def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn.ModuleList]):
    def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return

@@ -211,7 +207,7 @@ class ModelOffloader(Offloader):

        for b in blocks[self.num_blocks - self.blocks_to_swap :]:
            b.to(self.device)  # move block to device first
            weighs_to_device(b, torch.device("cpu"))  # make sure weights are on cpu
            weighs_to_device(b, "cpu")  # make sure weights are on cpu

        synchronize_device(self.device)
        clean_memory_on_device(self.device)
@@ -221,7 +217,7 @@ class ModelOffloader(Offloader):
            return
        self._wait_blocks_move(block_idx)

    def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int):
    def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return
        if block_idx >= self.blocks_to_swap:
library/dataset_metadata_utils.py (new file, 58 lines)
@@ -0,0 +1,58 @@
import os
import json
from typing import Any, Optional


from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


METADATA_VERSION = [1, 0, 0]
VERSION_STRING = ".".join(str(v) for v in METADATA_VERSION)

ARCHIVE_PATH_SEPARATOR = "////"


def load_metadata(metadata_file: str, create_new: bool = False) -> Optional[dict[str, Any]]:
    if os.path.exists(metadata_file):
        logger.info(f"loading metadata file: {metadata_file}")
        with open(metadata_file, "rt", encoding="utf-8") as f:
            metadata = json.load(f)

        # version check
        major, minor, patch = metadata.get("format_version", "0.0.0").split(".")
        major, minor, patch = int(major), int(minor), int(patch)
        if major > METADATA_VERSION[0] or (major == METADATA_VERSION[0] and minor > METADATA_VERSION[1]):
            logger.warning(
                f"metadata format version {major}.{minor}.{patch} is higher than supported version {VERSION_STRING}. Some features may not work."
            )

        if "images" not in metadata:
            metadata["images"] = {}
    else:
        if not create_new:
            return None
        logger.info(f"metadata file not found: {metadata_file}, creating new metadata")
        metadata = {"format_version": VERSION_STRING, "images": {}}

    return metadata


def is_archive_path(archive_and_image_path: str) -> bool:
    return archive_and_image_path.count(ARCHIVE_PATH_SEPARATOR) == 1


def get_inner_path(archive_and_image_path: str) -> str:
    return archive_and_image_path.split(ARCHIVE_PATH_SEPARATOR, 1)[1]


def get_archive_digest(archive_and_image_path: str) -> str:
    """
    calculate an 8-digit hex digest for the archive path to avoid collisions for different archives with the same name.
    """
    archive_path = archive_and_image_path.split(ARCHIVE_PATH_SEPARATOR, 1)[0]
    return f"{hash(archive_path) & 0xFFFFFFFF:08x}"
|
||||
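One caveat worth noting about `get_archive_digest` above: Python's built-in `hash()` for strings is randomized per process unless `PYTHONHASHSEED` is pinned, so the digest is not stable across runs. A hashlib-based sketch of a run-stable alternative (the name `get_archive_digest_stable` is hypothetical, not part of the file):

```python
import hashlib

ARCHIVE_PATH_SEPARATOR = "////"

def get_archive_digest_stable(archive_and_image_path: str) -> str:
    # 8-hex-digit digest of the archive part of the path, stable across processes
    archive_path = archive_and_image_path.split(ARCHIVE_PATH_SEPARATOR, 1)[0]
    return hashlib.sha1(archive_path.encode("utf-8")).hexdigest()[:8]

print(get_archive_digest_stable("data/archive.zip////images/001.png"))
```

The digest depends only on the archive part before the `////` separator, so all images from the same archive share it.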
@@ -5,8 +5,6 @@ from accelerate import DeepSpeedPlugin, Accelerator

from .utils import setup_logging

from .device_utils import get_preferred_device

setup_logging()
import logging

@@ -96,7 +94,6 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
    deepspeed_plugin.deepspeed_config["train_batch_size"] = (
        args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
    )

    deepspeed_plugin.set_mixed_precision(args.mixed_precision)
    if args.mixed_precision.lower() == "fp16":
        deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0  # preventing overflow.
@@ -125,56 +122,18 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
    class DeepSpeedWrapper(torch.nn.Module):
        def __init__(self, **kw_models) -> None:
            super().__init__()

            self.models = torch.nn.ModuleDict()

            wrap_model_forward_with_torch_autocast = args.mixed_precision != "no"

            for key, model in kw_models.items():
                if isinstance(model, list):
                    model = torch.nn.ModuleList(model)

                if wrap_model_forward_with_torch_autocast:
                    model = self.__wrap_model_with_torch_autocast(model)

                assert isinstance(
                    model, torch.nn.Module
                ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"

                self.models.update(torch.nn.ModuleDict({key: model}))

        def __wrap_model_with_torch_autocast(self, model):
            if isinstance(model, torch.nn.ModuleList):
                model = torch.nn.ModuleList([self.__wrap_model_forward_with_torch_autocast(m) for m in model])
            else:
                model = self.__wrap_model_forward_with_torch_autocast(model)
            return model

        def __wrap_model_forward_with_torch_autocast(self, model):
            assert hasattr(model, "forward"), "model must have a forward method."

            forward_fn = model.forward

            def forward(*args, **kwargs):
                try:
                    device_type = model.device.type
                except AttributeError:
                    logger.warning(
                        "[DeepSpeed] model.device is not available. Using get_preferred_device() "
                        "to determine the device_type for torch.autocast()."
                    )
                    device_type = get_preferred_device().type

                with torch.autocast(device_type=device_type):
                    return forward_fn(*args, **kwargs)

            model.forward = forward
            return model

        def get_models(self):
            return self.models

    ds_model = DeepSpeedWrapper(**models)
    return ds_model
@@ -2,13 +2,6 @@ import functools
import gc

import torch
try:
    # intel gpu support for pytorch older than 2.5
    # ipex is not needed after pytorch 2.5
    import intel_extension_for_pytorch as ipex  # noqa
except Exception:
    pass


try:
    HAS_CUDA = torch.cuda.is_available()
@@ -21,6 +14,8 @@ except Exception:
    HAS_MPS = False

try:
    import intel_extension_for_pytorch as ipex  # noqa

    HAS_XPU = torch.xpu.is_available()
except Exception:
    HAS_XPU = False
@@ -74,7 +69,7 @@ def init_ipex():

    This function should run right after importing torch and before doing anything else.

    If xpu is not available, this function does nothing.
    If IPEX is not available, this function does nothing.
    """
    try:
        if HAS_XPU:
@@ -977,10 +977,10 @@ class Flux(nn.Module):
            )

            self.offloader_double = custom_offloading_utils.ModelOffloader(
                self.double_blocks, double_blocks_to_swap, device  # , debug=True
                self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device  # , debug=True
            )
            self.offloader_single = custom_offloading_utils.ModelOffloader(
                self.single_blocks, single_blocks_to_swap, device  # , debug=True
                self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device  # , debug=True
            )
            print(
                f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1219,10 +1219,10 @@ class ControlNetFlux(nn.Module):
            )

            self.offloader_double = custom_offloading_utils.ModelOffloader(
                self.double_blocks, double_blocks_to_swap, device  # , debug=True
                self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device  # , debug=True
            )
            self.offloader_single = custom_offloading_utils.ModelOffloader(
                self.single_blocks, single_blocks_to_swap, device  # , debug=True
                self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device  # , debug=True
            )
            print(
                f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1233,8 +1233,8 @@ class ControlNetFlux(nn.Module):
        if self.blocks_to_swap:
            save_double_blocks = self.double_blocks
            save_single_blocks = self.single_blocks
            self.double_blocks = nn.ModuleList()
            self.single_blocks = nn.ModuleList()
            self.double_blocks = None
            self.single_blocks = None

        self.to(device)
@@ -40,7 +40,7 @@ def sample_images(
    text_encoders,
    sample_prompts_te_outputs,
    prompt_replacement=None,
    controlnet=None,
    controlnet=None
):
    if steps == 0:
        if not args.sample_at_first:
@@ -67,7 +67,7 @@ def sample_images(
    # unwrap unet and text_encoder(s)
    flux = accelerator.unwrap_model(flux)
    if text_encoders is not None:
        text_encoders = [(accelerator.unwrap_model(te) if te is not None else None) for te in text_encoders]
    text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
    if controlnet is not None:
        controlnet = accelerator.unwrap_model(controlnet)
    # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
@@ -101,7 +101,7 @@ def sample_images(
                steps,
                sample_prompts_te_outputs,
                prompt_replacement,
                controlnet,
                controlnet
            )
    else:
        # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
@@ -125,7 +125,7 @@ def sample_images(
                    steps,
                    sample_prompts_te_outputs,
                    prompt_replacement,
                    controlnet,
                    controlnet
                )

    torch.set_rng_state(rng_state)
@@ -147,16 +147,14 @@ def sample_image_inference(
    steps,
    sample_prompts_te_outputs,
    prompt_replacement,
    controlnet,
    controlnet
):
    assert isinstance(prompt_dict, dict)
    negative_prompt = prompt_dict.get("negative_prompt")
    # negative_prompt = prompt_dict.get("negative_prompt")
    sample_steps = prompt_dict.get("sample_steps", 20)
    width = prompt_dict.get("width", 512)
    height = prompt_dict.get("height", 512)
    # TODO refactor variable names
    cfg_scale = prompt_dict.get("guidance_scale", 1.0)
    emb_guidance_scale = prompt_dict.get("scale", 3.5)
    scale = prompt_dict.get("scale", 3.5)
    seed = prompt_dict.get("seed")
    controlnet_image = prompt_dict.get("controlnet_image")
    prompt: str = prompt_dict.get("prompt", "")
@@ -164,8 +162,8 @@ def sample_image_inference(

    if prompt_replacement is not None:
        prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
        if negative_prompt is not None:
            negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
        # if negative_prompt is not None:
        #     negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

    if seed is not None:
        torch.manual_seed(seed)
@@ -175,21 +173,16 @@ def sample_image_inference(
        torch.seed()
        torch.cuda.seed()

    if negative_prompt is None:
        negative_prompt = ""
    # if negative_prompt is None:
    #     negative_prompt = ""
    height = max(64, height - height % 16)  # round to divisible by 16
    width = max(64, width - width % 16)  # round to divisible by 16
    logger.info(f"prompt: {prompt}")
    if cfg_scale != 1.0:
        logger.info(f"negative_prompt: {negative_prompt}")
    elif negative_prompt != "":
        logger.info(f"negative prompt is ignored because scale is 1.0")
    # logger.info(f"negative_prompt: {negative_prompt}")
    logger.info(f"height: {height}")
    logger.info(f"width: {width}")
    logger.info(f"sample_steps: {sample_steps}")
    logger.info(f"embedded guidance scale: {emb_guidance_scale}")
    if cfg_scale != 1.0:
        logger.info(f"CFG scale: {cfg_scale}")
    logger.info(f"scale: {scale}")
    # logger.info(f"sample_sampler: {sampler_name}")
    if seed is not None:
        logger.info(f"seed: {seed}")
@@ -198,37 +191,26 @@ def sample_image_inference(
    tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
    encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()

    def encode_prompt(prpt):
        text_encoder_conds = []
        if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
            text_encoder_conds = sample_prompts_te_outputs[prpt]
            print(f"Using cached text encoder outputs for prompt: {prpt}")
        if text_encoders is not None:
            print(f"Encoding prompt: {prpt}")
            tokens_and_masks = tokenize_strategy.tokenize(prpt)
            # strategy has apply_t5_attn_mask option
            encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
    text_encoder_conds = []
    if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
        text_encoder_conds = sample_prompts_te_outputs[prompt]
        print(f"Using cached text encoder outputs for prompt: {prompt}")
    if text_encoders is not None:
        print(f"Encoding prompt: {prompt}")
        tokens_and_masks = tokenize_strategy.tokenize(prompt)
        # strategy has apply_t5_attn_mask option
        encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)

            # if text_encoder_conds is not cached, use encoded_text_encoder_conds
            if len(text_encoder_conds) == 0:
                text_encoder_conds = encoded_text_encoder_conds
            else:
                # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
                for i in range(len(encoded_text_encoder_conds)):
                    if encoded_text_encoder_conds[i] is not None:
                        text_encoder_conds[i] = encoded_text_encoder_conds[i]
        return text_encoder_conds
        # if text_encoder_conds is not cached, use encoded_text_encoder_conds
        if len(text_encoder_conds) == 0:
            text_encoder_conds = encoded_text_encoder_conds
        else:
            # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
            for i in range(len(encoded_text_encoder_conds)):
                if encoded_text_encoder_conds[i] is not None:
                    text_encoder_conds[i] = encoded_text_encoder_conds[i]

    l_pooled, t5_out, txt_ids, t5_attn_mask = encode_prompt(prompt)
    # encode negative prompts
    if cfg_scale != 1.0:
        neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode_prompt(negative_prompt)
        neg_t5_attn_mask = (
            neg_t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask and neg_t5_attn_mask is not None else None
        )
        neg_cond = (cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask)
    else:
        neg_cond = None
    l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds

    # sample image
    weight_dtype = ae.dtype  # TODO give dtype as argument
@@ -253,20 +235,7 @@ def sample_image_inference(
        controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)

    with accelerator.autocast(), torch.no_grad():
        x = denoise(
            flux,
            noise,
            img_ids,
            t5_out,
            txt_ids,
            l_pooled,
            timesteps=timesteps,
            guidance=emb_guidance_scale,
            t5_attn_mask=t5_attn_mask,
            controlnet=controlnet,
            controlnet_img=controlnet_image,
            neg_cond=neg_cond,
        )
        x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)

    x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)

@@ -336,24 +305,22 @@ def denoise(
    model: flux_models.Flux,
    img: torch.Tensor,
    img_ids: torch.Tensor,
    txt: torch.Tensor,  # t5_out
    txt: torch.Tensor,
    txt_ids: torch.Tensor,
    vec: torch.Tensor,  # l_pooled
    vec: torch.Tensor,
    timesteps: list[float],
    guidance: float = 4.0,
    t5_attn_mask: Optional[torch.Tensor] = None,
    controlnet: Optional[flux_models.ControlNetFlux] = None,
    controlnet_img: Optional[torch.Tensor] = None,
    neg_cond: Optional[Tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]] = None,
):
    # this is ignored for schnell
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    do_cfg = neg_cond is not None

    for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        model.prepare_block_swap_before_forward()

        if controlnet is not None:
            block_samples, block_single_samples = controlnet(
                img=img,
@@ -369,48 +336,20 @@ def denoise(
        else:
            block_samples = None
            block_single_samples = None
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            block_controlnet_hidden_states=block_samples,
            block_controlnet_single_hidden_states=block_single_samples,
            timesteps=t_vec,
            guidance=guidance_vec,
            txt_attention_mask=t5_attn_mask,
        )

        if not do_cfg:
            pred = model(
                img=img,
                img_ids=img_ids,
                txt=txt,
                txt_ids=txt_ids,
                y=vec,
                block_controlnet_hidden_states=block_samples,
                block_controlnet_single_hidden_states=block_single_samples,
                timesteps=t_vec,
                guidance=guidance_vec,
                txt_attention_mask=t5_attn_mask,
            )

            img = img + (t_prev - t_curr) * pred
        else:
            cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask = neg_cond
            nc_c_t5_attn_mask = None if t5_attn_mask is None else torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)

            # TODO is it ok to use the same block samples for both cond and uncond?
            block_samples = None if block_samples is None else torch.cat([block_samples, block_samples], dim=0)
            block_single_samples = (
                None if block_single_samples is None else torch.cat([block_single_samples, block_single_samples], dim=0)
            )

            nc_c_pred = model(
                img=torch.cat([img, img], dim=0),
                img_ids=torch.cat([img_ids, img_ids], dim=0),
                txt=torch.cat([neg_t5_out, txt], dim=0),
                txt_ids=torch.cat([txt_ids, txt_ids], dim=0),
                y=torch.cat([neg_l_pooled, vec], dim=0),
                block_controlnet_hidden_states=block_samples,
                block_controlnet_single_hidden_states=block_single_samples,
                timesteps=t_vec,
                guidance=guidance_vec,
                txt_attention_mask=nc_c_t5_attn_mask,
            )
            neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0)
            pred = neg_pred + (pred - neg_pred) * cfg_scale

            img = img + (t_prev - t_curr) * pred
        img = img + (t_prev - t_curr) * pred

    model.prepare_block_swap_before_forward()
    return img
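The classifier-free-guidance combine step in the `denoise` loop above (`pred = neg_pred + (pred - neg_pred) * cfg_scale`) reduces to a one-liner worth seeing in isolation; a minimal sketch, with an illustrative function name:

```python
import torch

def cfg_combine(neg_pred: torch.Tensor, pred: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # Move from the unconditional prediction toward the conditional one;
    # cfg_scale == 1.0 returns the conditional prediction unchanged,
    # cfg_scale > 1.0 extrapolates past it.
    return neg_pred + (pred - neg_pred) * cfg_scale

print(cfg_combine(torch.zeros(3), torch.ones(3), 2.0).tolist())  # [2.0, 2.0, 2.0]
```

This is why the sampling code above skips negative-prompt encoding entirely when `cfg_scale == 1.0`: the unconditional branch would contribute nothing.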
@@ -427,6 +366,8 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32)
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < n_dim:
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
return sigma
|
||||
|
||||
|
||||
@@ -469,34 +410,42 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None)

def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
sigmas = None

if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random sigma-based noise sampling
# Simple random t-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
sigmas = torch.rand((bsz,), device=device)
t = torch.rand((bsz,), device=device)

timesteps = sigmas * num_timesteps
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * latents + t * noise
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_timesteps
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)

t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
elif args.timestep_sampling == "flux_shift":
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = sigmas * num_timesteps
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
timesteps = time_shift(mu, 1.0, timesteps)

t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -507,24 +456,12 @@ def get_noisy_model_input_and_timesteps(
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * num_timesteps).long()
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)

# Add noise according to flow matching.
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)

# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if args.ip_noise_gamma_random_strength:
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
else:
ip_noise_gamma = args.ip_noise_gamma
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
else:
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents

return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
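Both the `sigmoid` and `shift` samplers in the hunk above draw a scaled normal, squash it into (0, 1), and (for `shift`) bias it toward the noisy end with s' = s·shift / (1 + (shift − 1)·s); the noisy input is then the linear interpolation (1 − t)·latents + t·noise. A NumPy sketch of the idea (default `sigmoid_scale`/`shift` values are illustrative):

```python
import numpy as np

def sample_timesteps(bsz: int, sigmoid_scale: float = 1.0, shift: float = 3.0, rng=None) -> np.ndarray:
    rng = np.random.default_rng() if rng is None else rng
    # "sigmoid" sampling: squash a scaled normal into (0, 1)
    t = 1.0 / (1.0 + np.exp(-sigmoid_scale * rng.standard_normal(bsz)))
    # discrete flow shift: biases t toward 1 (more noise) when shift > 1
    t = (t * shift) / (1.0 + (shift - 1.0) * t)
    return t

def add_flow_noise(latents: np.ndarray, noise: np.ndarray, t: np.ndarray) -> np.ndarray:
    # rectified-flow forward process: linear interpolation data -> noise
    t = t.reshape(-1, 1, 1, 1)
    return (1.0 - t) * latents + t * noise
```

At t = 0 this returns the clean latents and at t = 1 pure noise, which is the invariant the training code relies on.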
@@ -630,7 +567,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
"--controlnet_model_name_or_path",
type=str,
default=None,
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)",
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)"
)
parser.add_argument(
"--t5xxl_max_token_length",
@@ -1,15 +1,10 @@
import os
import sys
import contextlib
import torch
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
has_ipex = True
except Exception:
has_ipex = False
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from .hijacks import ipex_hijacks

torch_version = float(torch.__version__[:3])

# pylint: disable=protected-access, missing-function-docstring, line-too-long

def ipex_init(): # pylint: disable=too-many-statements
@@ -17,16 +12,6 @@ def ipex_init(): # pylint: disable=too-many-statements
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
return True, "Skipping IPEX hijack"
else:
try:
# force xpu device on torch compile and triton
# import inductor utils to get around lazy import
from torch._inductor import utils as torch_inductor_utils # pylint: disable=import-error, unused-import # noqa: F401
torch._inductor.utils.GPU_TYPES = ["xpu"]
torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu"
from triton import backends as triton_backends # pylint: disable=import-error
triton_backends.backends["nvidia"].driver.is_active = lambda *args, **kwargs: False
except Exception:
pass
# Replace cuda with xpu:
torch.cuda.current_device = torch.xpu.current_device
torch.cuda.current_stream = torch.xpu.current_stream
@@ -39,103 +24,86 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.is_available = torch.xpu.is_available
torch.cuda.is_initialized = torch.xpu.is_initialized
torch.cuda.is_current_stream_capturing = lambda: False
torch.cuda.set_device = torch.xpu.set_device
torch.cuda.stream = torch.xpu.stream
torch.cuda.synchronize = torch.xpu.synchronize
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.nn.Module.cuda = torch.nn.Module.xpu
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda.Optional = torch.xpu.Optional
torch.cuda.__cached__ = torch.xpu.__cached__
torch.cuda.__loader__ = torch.xpu.__loader__
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.streams = torch.xpu.streams
torch.cuda._lazy_new = torch.xpu._lazy_new
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.Any = torch.xpu.Any
torch.cuda.__doc__ = torch.xpu.__doc__
torch.cuda.default_generators = torch.xpu.default_generators
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda._get_device_index = torch.xpu._get_device_index
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.Device = torch.xpu.Device
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.os = torch.xpu.os
torch.cuda.torch = torch.xpu.torch
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.Union = torch.xpu.Union
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.List = torch.xpu.List
torch.cuda._lazy_init = torch.xpu._lazy_init
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.classproperty = torch.xpu.classproperty
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.__file__ = torch.xpu.__file__
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing

if torch_version < 2.3:
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda._lazy_new = torch.xpu._lazy_new

torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
else:
torch.cuda._initialization_lock = torch.xpu._initialization_lock
torch.cuda._initialized = torch.xpu._initialized
torch.cuda._is_in_bad_fork = torch.xpu._is_in_bad_fork
torch.cuda._lazy_seed_tracker = torch.xpu._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu._queued_calls
torch.cuda._tls = torch.xpu._tls
torch.cuda.threading = torch.xpu.threading
torch.cuda.traceback = torch.xpu.traceback

if torch_version < 2.5:
torch.cuda.os = torch.xpu.os
torch.cuda.Device = torch.xpu.Device
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.classproperty = torch.xpu.classproperty
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu

if torch_version < 2.7:
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.List = torch.xpu.List


# Memory:
torch.cuda.memory = torch.xpu.memory
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
torch.cuda.empty_cache = torch.xpu.empty_cache

if has_ipex:
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory = torch.xpu.memory
torch.cuda.memory_stats = torch.xpu.memory_stats
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory_allocated = torch.xpu.memory_allocated
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
torch.cuda.memory_reserved = torch.xpu.memory_reserved
@@ -159,45 +127,53 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.seed_all = torch.xpu.seed_all
torch.cuda.initial_seed = torch.xpu.initial_seed

# AMP:
torch.cuda.amp = torch.xpu.amp
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype

if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False

try:
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
try:
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
gradscaler_init()
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler

# C
if torch_version < 2.3:
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 12
ipex._C._DeviceProperties.minor = 1
ipex._C._DeviceProperties.L2_cache_size = 16*1024*1024 # A770 and A750
else:
torch._C._cuda_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
torch._C._XpuDeviceProperties.multi_processor_count = torch._C._XpuDeviceProperties.gpu_subslice_count
torch._C._XpuDeviceProperties.major = 12
torch._C._XpuDeviceProperties.minor = 1
torch._C._XpuDeviceProperties.L2_cache_size = 16*1024*1024 # A770 and A750
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 2024
ipex._C._DeviceProperties.minor = 0

# Fix functions with ipex:
# torch.xpu.mem_get_info always returns the total memory as free memory
torch.xpu.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
torch.cuda.mem_get_info = torch.xpu.mem_get_info
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True
torch.cuda.has_half = True
torch.cuda.is_bf16_supported = getattr(torch.xpu, "is_bf16_supported", lambda *args, **kwargs: True)
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.backends.cuda.is_built = lambda *args, **kwargs: True
torch.version.cuda = "12.1"
torch.cuda.get_arch_list = getattr(torch.xpu, "get_arch_list", lambda: ["pvc", "dg2", "ats-m150"])
torch.cuda.get_device_capability = lambda *args, **kwargs: (12,1)
torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
torch.cuda.get_device_properties.major = 12
torch.cuda.get_device_properties.minor = 1
torch.cuda.get_device_properties.L2_cache_size = 16*1024*1024 # A770 and A750
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0

device_supports_fp64 = ipex_hijacks()
try:
from .diffusers import ipex_diffusers
ipex_diffusers(device_supports_fp64=device_supports_fp64)
except Exception: # pylint: disable=broad-exception-caught
pass
ipex_hijacks()
if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
try:
from .diffusers import ipex_diffusers
ipex_diffusers()
except Exception: # pylint: disable=broad-exception-caught
pass
torch.cuda.is_xpu_hijacked = True
except Exception as e:
return False, e
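The hijack in this file is plain attribute aliasing: every `torch.cuda.*` name is pointed at its `torch.xpu` counterpart so CUDA-only code runs unmodified on XPU. A minimal sketch of that pattern with stand-in namespaces (not the real torch modules):

```python
import types

def alias_api(src, dst, names):
    # point each name on the facade at the real backend's attribute,
    # so callers of dst.<name> transparently invoke src.<name>
    for name in names:
        setattr(dst, name, getattr(src, name))

# toy stand-ins for torch.xpu (real backend) and torch.cuda (facade)
xpu = types.SimpleNamespace(is_available=lambda: True, device_count=lambda: 1)
cuda = types.SimpleNamespace()
alias_api(xpu, cuda, ["is_available", "device_count"])
```

The version gates in the real file exist because the private attributes being aliased (`_initialized`, `_lazy_seed_tracker`, etc.) moved between torch releases.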
@@ -1,119 +1,177 @@
import os
import torch
from functools import cache, wraps
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from functools import cache

# pylint: disable=protected-access, missing-function-docstring, line-too-long

# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers

sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 1))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 0.5))
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))

# Find something divisible with the input_tokens
@cache
def find_split_size(original_size, slice_block_size, slice_rate=2):
split_size = original_size
while True:
if (split_size * slice_block_size) <= slice_rate and original_size % split_size == 0:
return split_size
split_size = split_size - 1
if split_size <= 1:
return 1
return split_size

def find_slice_size(slice_size, slice_block_size):
while (slice_size * slice_block_size) > attention_slice_rate:
slice_size = slice_size // 2
if slice_size <= 1:
slice_size = 1
break
return slice_size

# Find slice sizes for SDPA
@cache
def find_sdpa_slice_sizes(query_shape, key_shape, query_element_size, slice_rate=2, trigger_rate=3):
batch_size, attn_heads, query_len, _ = query_shape
_, _, key_len, _ = key_shape
def find_sdpa_slice_sizes(query_shape, query_element_size):
if len(query_shape) == 3:
batch_size_attention, query_tokens, shape_three = query_shape
shape_four = 1
else:
batch_size_attention, query_tokens, shape_three, shape_four = query_shape

slice_batch_size = attn_heads * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
block_size = batch_size_attention * slice_block_size

split_batch_size = batch_size
split_head_size = attn_heads
split_query_size = query_len
split_slice_size = batch_size_attention
split_2_slice_size = query_tokens
split_3_slice_size = shape_three

do_batch_split = False
do_head_split = False
do_query_split = False
do_split = False
do_split_2 = False
do_split_3 = False

if batch_size * slice_batch_size >= trigger_rate:
do_batch_split = True
split_batch_size = find_split_size(batch_size, slice_batch_size, slice_rate=slice_rate)
if block_size > sdpa_slice_trigger_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)

if split_batch_size * slice_batch_size > slice_rate:
slice_head_size = split_batch_size * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024
do_head_split = True
split_head_size = find_split_size(attn_heads, slice_head_size, slice_rate=slice_rate)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size

if split_head_size * slice_head_size > slice_rate:
slice_query_size = split_batch_size * split_head_size * (key_len) * query_element_size / 1024 / 1024 / 1024
do_query_split = True
split_query_size = find_split_size(query_len, slice_query_size, slice_rate=slice_rate)
# Find slice sizes for BMM
@cache
def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
block_size = batch_size_attention * slice_block_size

return do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size
split_slice_size = batch_size_attention
split_2_slice_size = input_tokens
split_3_slice_size = mat2_atten_shape

do_split = False
do_split_2 = False
do_split_3 = False

if block_size > attention_slice_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)

return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size


original_torch_bmm = torch.bmm
def torch_bmm_32_bit(input, mat2, *, out=None):
if input.device.type != "xpu":
return original_torch_bmm(input, mat2, out=out)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)

# Slice BMM
if do_split:
batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
out=out
)
else:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out
)
else:
hidden_states[start_idx:end_idx] = original_torch_bmm(
input[start_idx:end_idx],
mat2[start_idx:end_idx],
out=out
)
torch.xpu.synchronize(input.device)
else:
return original_torch_bmm(input, mat2, out=out)
return hidden_states

original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
@wraps(torch.nn.functional.scaled_dot_product_attention)
def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
if query.device.type != "xpu":
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
is_unsqueezed = False
if query.dim() == 3:
query = query.unsqueeze(0)
is_unsqueezed = True
if key.dim() == 3:
key = key.unsqueeze(0)
if value.dim() == 3:
value = value.unsqueeze(0)
do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size = find_sdpa_slice_sizes(query.shape, key.shape, query.element_size(), slice_rate=attention_slice_rate, trigger_rate=sdpa_slice_trigger_rate)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())

# Slice SDPA
if do_batch_split:
batch_size, attn_heads, query_len, _ = query.shape
_, _, _, head_dim = value.shape
hidden_states = torch.zeros((batch_size, attn_heads, query_len, head_dim), device=query.device, dtype=query.dtype)
if attn_mask is not None:
attn_mask = attn_mask.expand((query.shape[0], query.shape[1], query.shape[2], key.shape[-2]))
for ib in range(batch_size // split_batch_size):
start_idx = ib * split_batch_size
end_idx = (ib + 1) * split_batch_size
if do_head_split:
for ih in range(attn_heads // split_head_size): # pylint: disable=invalid-name
start_idx_h = ih * split_head_size
end_idx_h = (ih + 1) * split_head_size
if do_query_split:
for iq in range(query_len // split_query_size): # pylint: disable=invalid-name
start_idx_q = iq * split_query_size
end_idx_q = (iq + 1) * split_query_size
hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :],
key[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
value[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] if attn_mask is not None else attn_mask,
if do_split:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
else:
hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, :, :] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
key[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
value[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, :, :] if attn_mask is not None else attn_mask,
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2],
key[start_idx:end_idx, start_idx_2:end_idx_2],
value[start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
else:
hidden_states[start_idx:end_idx, :, :, :] = original_scaled_dot_product_attention(
query[start_idx:end_idx, :, :, :],
key[start_idx:end_idx, :, :, :],
value[start_idx:end_idx, :, :, :],
attn_mask=attn_mask[start_idx:end_idx, :, :, :] if attn_mask is not None else attn_mask,
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
torch.xpu.synchronize(query.device)
else:
hidden_states = original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
if is_unsqueezed:
hidden_states = hidden_states.squeeze(0)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
return hidden_states
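The `torch_bmm_32_bit` wrapper above works around the ~4 GB single-allocation limit on Arc GPUs by halving the batch slice until each chunk fits a memory budget, then computing the batched matmul chunk by chunk into a preallocated output. A NumPy sketch of that idea (the MB budget is an illustrative stand-in for the env-var rates, and this loop also handles a non-divisible remainder):

```python
import numpy as np

def find_slice_size(slice_size: int, block_mb: float, budget_mb: float = 4.0) -> int:
    # halve the slice until (slice * per-slice block size) fits the budget
    while slice_size * block_mb > budget_mb and slice_size > 1:
        slice_size //= 2
    return max(slice_size, 1)

def sliced_bmm(a: np.ndarray, b: np.ndarray, budget_mb: float = 4.0) -> np.ndarray:
    # batched matmul computed in batch chunks small enough for the budget
    batch, m, _ = a.shape
    n = b.shape[2]
    block_mb = m * n * a.itemsize / 1024 / 1024
    step = find_slice_size(batch, block_mb, budget_mb)
    out = np.zeros((batch, m, n), dtype=a.dtype)
    for start in range(0, batch, step):
        out[start:start + step] = np.matmul(a[start:start + step], b[start:start + step])
    return out
```

With a large budget the loop degenerates to a single full-batch matmul, so the slicing is a pure memory/speed trade-off with identical results.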
@@ -1,126 +1,312 @@
from functools import wraps
import os
import torch
import diffusers # pylint: disable=import-error
from diffusers.utils import torch_utils # pylint: disable=import-error, unused-import # noqa: F401
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import diffusers #0.24.0 # pylint: disable=import-error
from diffusers.models.attention_processor import Attention
from diffusers.utils import USE_PEFT_BACKEND
from functools import cache

# pylint: disable=protected-access, missing-function-docstring, line-too-long

attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))

# Diffusers FreeU
# Diffusers is imported before ipex hijacks so fourier_filter needs hijacking too
original_fourier_filter = diffusers.utils.torch_utils.fourier_filter
@wraps(diffusers.utils.torch_utils.fourier_filter)
def fourier_filter(x_in, threshold, scale):
    return_dtype = x_in.dtype
    return original_fourier_filter(x_in.to(dtype=torch.float32), threshold, scale).to(dtype=return_dtype)

@cache
def find_slice_size(slice_size, slice_block_size):
    while (slice_size * slice_block_size) > attention_slice_rate:
        slice_size = slice_size // 2
        if slice_size <= 1:
            slice_size = 1
            break
    return slice_size

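The halving rule in `find_slice_size` can be sketched standalone, with the slice rate passed in explicitly instead of read from the `IPEX_ATTENTION_SLICE_RATE` environment variable (function and parameter names here are illustrative, not part of the module):

```python
def find_slice_size_sketch(slice_size: int, slice_block_size: float, slice_rate: float = 4.0) -> int:
    # Halve the slice count until slice_size * slice_block_size fits under the rate cap,
    # clamping at 1 when even a single slice exceeds it.
    while (slice_size * slice_block_size) > slice_rate:
        slice_size = slice_size // 2
        if slice_size <= 1:
            slice_size = 1
            break
    return slice_size

print(find_slice_size_sketch(32, 1.0))   # halves 32 -> 16 -> 8 -> 4
print(find_slice_size_sketch(2, 16.0))   # cannot fit, clamps to 1
```

Because the real function is wrapped in `functools.cache`, repeated calls with the same shape resolve without re-running the loop.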
# fp64 error
class FluxPosEmbed(torch.nn.Module):
    def __init__(self, theta: int, axes_dim):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.float()
        for i in range(n_axes):
            cos, sin = diffusers.models.embeddings.get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=torch.float32,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin


def hidream_rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    assert dim % 2 == 0, "The dimension must be even."
    return_device = pos.device
    pos = pos.to("cpu")

    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)

    batch_size, seq_length = pos.shape
    out = torch.einsum("...n,d->...nd", pos, omega)
    cos_out = torch.cos(out)
    sin_out = torch.sin(out)

    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
    return out.to(return_device, dtype=torch.float32)

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
    if output_type == "np":
        return diffusers.models.embeddings.get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.outer(pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.concat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb

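The sin/cos embedding computed by `get_1d_sincos_pos_embed_from_grid` can be written as a minimal pure-Python sketch, with the float32 torch ops replaced by `math` (names here are illustrative):

```python
import math

def sincos_embed_sketch(embed_dim: int, positions: list) -> list:
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")
    half = embed_dim // 2
    # omega[i] = 1 / 10000^(i / half): the per-channel frequency
    omega = [1.0 / 10000 ** (i / half) for i in range(half)]
    rows = []
    for p in positions:
        out = [p * w for w in omega]  # one row of the outer product pos x omega
        # embedding row = [sin(out) | cos(out)], length embed_dim
        rows.append([math.sin(o) for o in out] + [math.cos(o) for o in out])
    return rows

emb = sincos_embed_sketch(4, [0.0, 1.0])
# position 0 embeds to [sin(0), sin(0), cos(0), cos(0)] = [0.0, 0.0, 1.0, 1.0]
```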
def apply_rotary_emb(x, freqs_cis, use_real: bool = True, use_real_unbind_dim: int = -1):
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
        return out
    else:
        # used for lumina
        # force cpu with Alchemist
        x_rotated = torch.view_as_complex(x.to("cpu").float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.to("cpu").unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
        return x_out.type_as(x).to(x.device)


@cache
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
    if len(query_shape) == 3:
        batch_size_attention, query_tokens, shape_three = query_shape
        shape_four = 1
    else:
        batch_size_attention, query_tokens, shape_three, shape_four = query_shape
    if slice_size is not None:
        batch_size_attention = slice_size

    slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
    block_size = batch_size_attention * slice_block_size

    split_slice_size = batch_size_attention
    split_2_slice_size = query_tokens
    split_3_slice_size = shape_three

    do_split = False
    do_split_2 = False
    do_split_3 = False

    if query_device_type != "xpu":
        return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size

    if block_size > attention_slice_rate:
        do_split = True
        split_slice_size = find_slice_size(split_slice_size, slice_block_size)
        if split_slice_size * slice_block_size > attention_slice_rate:
            slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
            do_split_2 = True
            split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
            if split_2_slice_size * slice_2_block_size > attention_slice_rate:
                slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
                do_split_3 = True
                split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)

    return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
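The cascaded split decision above can be sketched in pure Python for the common three-dimensional query case: each level's block size (in MiB of query data) is checked against a slice-rate cap and halved until it fits. The helper names and the default cap value here are illustrative:

```python
def plan_attention_slices(shape, element_size, slice_rate=4.0):
    batch, tokens, head_dim = shape

    def shrink(size, per_unit):
        # halve until size * per_unit fits under the cap, clamping at 1
        while size * per_unit > slice_rate and size > 1:
            size //= 2
        return max(size, 1)

    # MiB occupied by one batch-slice of the query
    slice_block = tokens * head_dim / 1024 / 1024 * element_size
    if batch * slice_block <= slice_rate:
        return (False, batch, tokens)  # whole block fits: no slicing needed
    split_1 = shrink(batch, slice_block)
    # if a single batch slice still exceeds the cap, also split along tokens
    slice_2_block = split_1 * head_dim / 1024 / 1024 * element_size
    split_2 = tokens if split_1 * slice_block <= slice_rate else shrink(tokens, slice_2_block)
    return (True, split_1, split_2)

print(plan_attention_slices((2, 64, 40), 2))        # tiny query: (False, 2, 64)
print(plan_attention_slices((1024, 4096, 512), 2))  # large query: must split
```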


class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
    r"""
    Processor for implementing sliced attention.

    Args:
        slice_size (`int`, *optional*):
            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
            `attention_head_dim` must be a multiple of the `slice_size`.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
                 encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches

        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        batch_size_attention, query_tokens, shape_three = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        ####################################################################
        # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
        _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)

        for i in range(batch_size_attention // split_slice_size):
            start_idx = i * split_slice_size
            end_idx = (i + 1) * split_slice_size
            if do_split_2:
                for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
                    start_idx_2 = i2 * split_2_slice_size
                    end_idx_2 = (i2 + 1) * split_2_slice_size
                    if do_split_3:
                        for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
                            start_idx_3 = i3 * split_3_slice_size
                            end_idx_3 = (i3 + 1) * split_3_slice_size

                            query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
                            key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
                            attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None

                            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
                            del query_slice
                            del key_slice
                            del attn_mask_slice
                            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])

                            hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
                            del attn_slice
                    else:
                        query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
                        key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
                        attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None

                        attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
                        del query_slice
                        del key_slice
                        del attn_mask_slice
                        attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])

                        hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
                        del attn_slice
                torch.xpu.synchronize(query.device)
            else:
                query_slice = query[start_idx:end_idx]
                key_slice = key[start_idx:end_idx]
                attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

                attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
                del query_slice
                del key_slice
                del attn_mask_slice
                attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

                hidden_states[start_idx:end_idx] = attn_slice
                del attn_slice
        ####################################################################

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

def ipex_diffusers(device_supports_fp64=False):
    diffusers.utils.torch_utils.fourier_filter = fourier_filter
    if not device_supports_fp64:
        # get around lazy imports
        from diffusers.models import embeddings as diffusers_embeddings # pylint: disable=import-error, unused-import # noqa: F401
        from diffusers.models import transformers as diffusers_transformers # pylint: disable=import-error, unused-import # noqa: F401
        from diffusers.models import controlnets as diffusers_controlnets # pylint: disable=import-error, unused-import # noqa: F401
        diffusers.models.embeddings.get_1d_sincos_pos_embed_from_grid = get_1d_sincos_pos_embed_from_grid
        diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed
        diffusers.models.embeddings.apply_rotary_emb = apply_rotary_emb
        diffusers.models.transformers.transformer_flux.FluxPosEmbed = FluxPosEmbed
        diffusers.models.transformers.transformer_lumina2.apply_rotary_emb = apply_rotary_emb
        diffusers.models.controlnets.controlnet_flux.FluxPosEmbed = FluxPosEmbed
        diffusers.models.transformers.transformer_hidream_image.rope = hidream_rope

class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
                 encoder_hidden_states=None, attention_mask=None,
                 temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches

        residual = hidden_states

        args = () if USE_PEFT_BACKEND else (scale,)

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states, *args)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        ####################################################################
        # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
        batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
        hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
        do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)

        if do_split:
            for i in range(batch_size_attention // split_slice_size):
                start_idx = i * split_slice_size
                end_idx = (i + 1) * split_slice_size
                if do_split_2:
                    for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
                        start_idx_2 = i2 * split_2_slice_size
                        end_idx_2 = (i2 + 1) * split_2_slice_size
                        if do_split_3:
                            for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
                                start_idx_3 = i3 * split_3_slice_size
                                end_idx_3 = (i3 + 1) * split_3_slice_size

                                query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
                                key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
                                attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None

                                attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
                                del query_slice
                                del key_slice
                                del attn_mask_slice
                                attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])

                                hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
                                del attn_slice
                        else:
                            query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
                            key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
                            attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None

                            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
                            del query_slice
                            del key_slice
                            del attn_mask_slice
                            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])

                            hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
                            del attn_slice
                else:
                    query_slice = query[start_idx:end_idx]
                    key_slice = key[start_idx:end_idx]
                    attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

                    attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
                    del query_slice
                    del key_slice
                    del attn_mask_slice
                    attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

                    hidden_states[start_idx:end_idx] = attn_slice
                    del attn_slice
            torch.xpu.synchronize(query.device)
        else:
            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            hidden_states = torch.bmm(attention_probs, value)
        ####################################################################
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


def ipex_diffusers():
    # ARC GPUs can't allocate more than 4GB to a single block:
    diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
    diffusers.models.attention_processor.AttnProcessor = AttnProcessor

183  library/ipex/gradscaler.py  Normal file
@@ -0,0 +1,183 @@
from collections import defaultdict
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import

# pylint: disable=protected-access, missing-function-docstring, line-too-long

device_supports_fp64 = torch.xpu.has_fp64_dtype()
OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state


def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
    per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
    per_device_found_inf = _MultiDeviceReplicator(found_inf)

    # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
    # There could be hundreds of grads, so we'd like to iterate through them just once.
    # However, we don't know their devices or dtypes in advance.

    # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
    # Google says mypy struggles with defaultdicts type annotations.
    per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
    # sync grad to master weight
    if hasattr(optimizer, "sync_grad"):
        optimizer.sync_grad()
    with torch.no_grad():
        for group in optimizer.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                if (not allow_fp16) and param.grad.dtype == torch.float16:
                    raise ValueError("Attempting to unscale FP16 gradients.")
                if param.grad.is_sparse:
                    # is_coalesced() == False means the sparse grad has values with duplicate indices.
                    # coalesce() deduplicates indices and adds all values that have the same index.
                    # For scaled fp16 values, there's a good chance coalescing will cause overflow,
                    # so we should check the coalesced _values().
                    if param.grad.dtype is torch.float16:
                        param.grad = param.grad.coalesce()
                    to_unscale = param.grad._values()
                else:
                    to_unscale = param.grad

                # TODO: is there a way to split by device and dtype without appending in the inner loop?
                to_unscale = to_unscale.to("cpu")
                per_device_and_dtype_grads[to_unscale.device][
                    to_unscale.dtype
                ].append(to_unscale)

    for _, per_dtype_grads in per_device_and_dtype_grads.items():
        for grads in per_dtype_grads.values():
            core._amp_foreach_non_finite_check_and_unscale_(
                grads,
                per_device_found_inf.get("cpu"),
                per_device_inv_scale.get("cpu"),
            )

    return per_device_found_inf._per_device_tensors

def unscale_(self, optimizer):
    """
    Divides ("unscales") the optimizer's gradient tensors by the scale factor.
    :meth:`unscale_` is optional, serving cases where you need to
    :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
    between the backward pass(es) and :meth:`step`.
    If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
    Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
        ...
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)
        scaler.update()
    Args:
        optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
    .. warning::
        :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
        and only after all gradients for that optimizer's assigned parameters have been accumulated.
        Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
    .. warning::
        :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
    """
    if not self._enabled:
        return

    self._check_scale_growth_tracker("unscale_")

    optimizer_state = self._per_optimizer_states[id(optimizer)]

    if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
        raise RuntimeError(
            "unscale_() has already been called on this optimizer since the last update()."
        )
    elif optimizer_state["stage"] is OptState.STEPPED:
        raise RuntimeError("unscale_() is being called after step().")

    # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
    assert self._scale is not None
    if device_supports_fp64:
        inv_scale = self._scale.double().reciprocal().float()
    else:
        inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
    found_inf = torch.full(
        (1,), 0.0, dtype=torch.float32, device=self._scale.device
    )

    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
        optimizer, inv_scale, found_inf, False
    )
    optimizer_state["stage"] = OptState.UNSCALED

def update(self, new_scale=None):
    """
    Updates the scale factor.
    If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
    to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
    the scale is multiplied by ``growth_factor`` to increase it.
    Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
    used directly, it's used to fill GradScaler's internal scale tensor. So if
    ``new_scale`` was a tensor, later in-place changes to that tensor will not further
    affect the scale GradScaler uses internally.)
    Args:
        new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
    .. warning::
        :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
        been invoked for all optimizers used this iteration.
    """
    if not self._enabled:
        return

    _scale, _growth_tracker = self._check_scale_growth_tracker("update")

    if new_scale is not None:
        # Accept a new user-defined scale.
        if isinstance(new_scale, float):
            self._scale.fill_(new_scale) # type: ignore[union-attr]
        else:
            reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
            assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
            assert new_scale.numel() == 1, reason
            assert new_scale.requires_grad is False, reason
            self._scale.copy_(new_scale) # type: ignore[union-attr]
    else:
        # Consume shared inf/nan data collected from optimizers to update the scale.
        # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
        found_infs = [
            found_inf.to(device="cpu", non_blocking=True)
            for state in self._per_optimizer_states.values()
            for found_inf in state["found_inf_per_device"].values()
        ]

        assert len(found_infs) > 0, "No inf checks were recorded prior to update."

        found_inf_combined = found_infs[0]
        if len(found_infs) > 1:
            for i in range(1, len(found_infs)):
                found_inf_combined += found_infs[i]

        to_device = _scale.device
        _scale = _scale.to("cpu")
        _growth_tracker = _growth_tracker.to("cpu")

        core._amp_update_scale_(
            _scale,
            _growth_tracker,
            found_inf_combined,
            self._growth_factor,
            self._backoff_factor,
            self._growth_interval,
        )

        _scale = _scale.to(to_device)
        _growth_tracker = _growth_tracker.to(to_device)
    # To prepare for next iteration, clear the data collected from optimizers this iteration.
    self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)


def gradscaler_init():
    torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
    torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
    torch.xpu.amp.GradScaler.unscale_ = unscale_
    torch.xpu.amp.GradScaler.update = update
    return torch.xpu.amp.GradScaler

@@ -2,25 +2,10 @@ import os
from functools import wraps
from contextlib import nullcontext
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import numpy as np

torch_version = float(torch.__version__[:3])
current_xpu_device = f"xpu:{torch.xpu.current_device()}"
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties(current_xpu_device).has_fp64

if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0':
    if (torch.xpu.get_device_properties(current_xpu_device).total_memory / 1024 / 1024 / 1024) > 4.1:
        try:
            x = torch.ones((33000,33000), dtype=torch.float32, device=current_xpu_device)
            del x
            torch.xpu.empty_cache()
            use_dynamic_attention = False
        except Exception:
            use_dynamic_attention = True
    else:
        use_dynamic_attention = True
else:
    use_dynamic_attention = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '1')
device_supports_fp64 = torch.xpu.has_fp64_dtype()

# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return

@@ -28,71 +13,36 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr
    def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
        if isinstance(device_ids, list) and len(device_ids) > 1:
            print("IPEX backend doesn't support DataParallel on multiple XPU devices")
        return module.to(f"xpu:{torch.xpu.current_device()}")
        return module.to("xpu")

def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
    return nullcontext()

@property
def is_cuda(self):
    return self.device.type == "xpu" or self.device.type == "cuda"

def check_device_type(device, device_type: str) -> bool:
    if device is None or type(device) not in {str, int, torch.device}:
        return False
    else:
        return bool(torch.device(device).type == device_type)

def check_device(device):
    return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))

def check_cuda(device) -> bool:
    return bool(isinstance(device, int) or check_device_type(device, "cuda"))

def return_xpu(device): # keep the device instance type, aka return string if the input is string
    return f"xpu:{torch.xpu.current_device()}" if device is None else f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu"

def return_xpu(device):
    return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"

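The device-remapping rule implemented by `return_xpu` above can be sketched without torch: a CUDA device spec keeps its index but moves to the "xpu" backend, and the input kind (string vs. int) is preserved for strings. The `to_xpu` name is illustrative:

```python
def to_xpu(device):
    # "cuda:1" -> "xpu:1": keep the index, swap the backend
    if isinstance(device, str) and ":" in device:
        return f"xpu:{device.split(':')[-1]}"
    # bare ordinal 2 -> "xpu:2"
    if isinstance(device, int):
        return f"xpu:{device}"
    # anything else (e.g. "cuda" with no index) falls back to the default device
    return "xpu"

print(to_xpu("cuda:1"))  # xpu:1
print(to_xpu(2))         # xpu:2
print(to_xpu("cuda"))    # xpu
```

The real hijack additionally returns a `torch.device` when given one, so callers never observe a type change.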
# Autocast
|
||||
original_autocast_init = torch.amp.autocast_mode.autocast.__init__
|
||||
@wraps(torch.amp.autocast_mode.autocast.__init__)
|
||||
def autocast_init(self, device_type=None, dtype=None, enabled=True, cache_enabled=None):
|
||||
if device_type is None or check_cuda(device_type):
|
||||
def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
|
||||
if device_type == "cuda":
|
||||
return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
||||
else:
|
||||
return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)


original_grad_scaler_init = torch.amp.grad_scaler.GradScaler.__init__
@wraps(torch.amp.grad_scaler.GradScaler.__init__)
def GradScaler_init(self, device: str = None, init_scale: float = 2.0**16, growth_factor: float = 2.0, backoff_factor: float = 0.5, growth_interval: int = 2000, enabled: bool = True):
    if device is None or check_cuda(device):
        return original_grad_scaler_init(self, device=return_xpu(device), init_scale=init_scale, growth_factor=growth_factor, backoff_factor=backoff_factor, growth_interval=growth_interval, enabled=enabled)
    else:
        return original_grad_scaler_init(self, device=device, init_scale=init_scale, growth_factor=growth_factor, backoff_factor=backoff_factor, growth_interval=growth_interval, enabled=enabled)


original_is_autocast_enabled = torch.is_autocast_enabled
@wraps(torch.is_autocast_enabled)
def torch_is_autocast_enabled(device_type=None):
    if device_type is None or check_cuda(device_type):
        return original_is_autocast_enabled(return_xpu(device_type))
    else:
        return original_is_autocast_enabled(device_type)


original_get_autocast_dtype = torch.get_autocast_dtype
@wraps(torch.get_autocast_dtype)
def torch_get_autocast_dtype(device_type=None):
    if device_type is None or check_cuda(device_type) or check_device_type(device_type, "xpu"):
        return torch.bfloat16
    else:
        return original_get_autocast_dtype(device_type)


# Latent Antialias CPU Offload:
# IPEX 2.5 and above has partial support but doesn't really work most of the time.
original_interpolate = torch.nn.functional.interpolate
@wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):  # pylint: disable=too-many-arguments
    if mode in {'bicubic', 'bilinear'}:
        if antialias or align_corners is not None or mode == 'bicubic':
            return_device = tensor.device
            return_dtype = tensor.dtype
            return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
@@ -107,61 +57,51 @@ original_from_numpy = torch.from_numpy
@wraps(torch.from_numpy)
def from_numpy(ndarray):
    if ndarray.dtype == float:
        return original_from_numpy(ndarray.astype("float32"))
        return original_from_numpy(ndarray.astype('float32'))
    else:
        return original_from_numpy(ndarray)

original_as_tensor = torch.as_tensor
@wraps(torch.as_tensor)
def as_tensor(data, dtype=None, device=None):
    if check_cuda(device):
    if check_device(device):
        device = return_xpu(device)
    if isinstance(data, np.ndarray) and data.dtype == float and not check_device_type(device, "cpu"):
    if isinstance(data, np.ndarray) and data.dtype == float and not (
        (isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
        return original_as_tensor(data, dtype=torch.float32, device=device)
    else:
        return original_as_tensor(data, dtype=dtype, device=device)


if not use_dynamic_attention:
    if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
        original_torch_bmm = torch.bmm
        original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
    else:
        # 32 bit attention workarounds for Alchemist:
        try:
            from .attention import dynamic_scaled_dot_product_attention as original_scaled_dot_product_attention
            from .attention import torch_bmm_32_bit as original_torch_bmm
            from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
        except Exception:  # pylint: disable=broad-exception-caught
            original_torch_bmm = torch.bmm
            original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention


# Data Type Errors:
@wraps(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
    if input.dtype != mat2.dtype:
        mat2 = mat2.to(input.dtype)
    return original_torch_bmm(input, mat2, out=out)

@wraps(torch.nn.functional.scaled_dot_product_attention)
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    if query.dtype != key.dtype:
        key = key.to(dtype=query.dtype)
    if query.dtype != value.dtype:
        value = value.to(dtype=query.dtype)
    if attn_mask is not None and query.dtype != attn_mask.dtype:
        attn_mask = attn_mask.to(dtype=query.dtype)
    return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)

# Data Type Errors:
original_torch_bmm = torch.bmm
@wraps(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
    if input.dtype != mat2.dtype:
        mat2 = mat2.to(dtype=input.dtype)
    return original_torch_bmm(input, mat2, out=out)

# Diffusers FreeU
original_fft_fftn = torch.fft.fftn
@wraps(torch.fft.fftn)
def fft_fftn(input, s=None, dim=None, norm=None, *, out=None):
    return_dtype = input.dtype
    return original_fft_fftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)

# Diffusers FreeU
original_fft_ifftn = torch.fft.ifftn
@wraps(torch.fft.ifftn)
def fft_ifftn(input, s=None, dim=None, norm=None, *, out=None):
    return_dtype = input.dtype
    return original_fft_ifftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)
    return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)

# A1111 FP16
original_functional_group_norm = torch.nn.functional.group_norm
@@ -193,15 +133,6 @@ def functional_linear(input, weight, bias=None):
        bias.data = bias.data.to(dtype=weight.data.dtype)
    return original_functional_linear(input, weight, bias=bias)

original_functional_conv1d = torch.nn.functional.conv1d
@wraps(torch.nn.functional.conv1d)
def functional_conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    if input.dtype != weight.data.dtype:
        input = input.to(dtype=weight.data.dtype)
    if bias is not None and bias.data.dtype != weight.data.dtype:
        bias.data = bias.data.to(dtype=weight.data.dtype)
    return original_functional_conv1d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

original_functional_conv2d = torch.nn.functional.conv2d
@wraps(torch.nn.functional.conv2d)
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
@@ -211,15 +142,14 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
        bias.data = bias.data.to(dtype=weight.data.dtype)
    return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

# LTX Video
original_functional_conv3d = torch.nn.functional.conv3d
@wraps(torch.nn.functional.conv3d)
def functional_conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    if input.dtype != weight.data.dtype:
        input = input.to(dtype=weight.data.dtype)
    if bias is not None and bias.data.dtype != weight.data.dtype:
        bias.data = bias.data.to(dtype=weight.data.dtype)
    return original_functional_conv3d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
# A1111 Embedding BF16
original_torch_cat = torch.cat
@wraps(torch.cat)
def torch_cat(tensor, *args, **kwargs):
    if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
        return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
    else:
        return original_torch_cat(tensor, *args, **kwargs)

# SwinIR BF16:
original_functional_pad = torch.nn.functional.pad
@@ -234,37 +164,38 @@ def functional_pad(input, pad, mode='constant', value=None):
original_torch_tensor = torch.tensor
@wraps(torch.tensor)
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
    global device_supports_fp64
    if check_cuda(device):
    if check_device(device):
        device = return_xpu(device)
    if not device_supports_fp64:
        if check_device_type(device, "xpu"):
        if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
            if dtype == torch.float64:
                dtype = torch.float32
            elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
                dtype = torch.float32
    return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)

torch.Tensor.original_Tensor_to = torch.Tensor.to
original_Tensor_to = torch.Tensor.to
@wraps(torch.Tensor.to)
def Tensor_to(self, device=None, *args, **kwargs):
    if check_cuda(device):
        return self.original_Tensor_to(return_xpu(device), *args, **kwargs)
    if check_device(device):
        return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
    else:
        return self.original_Tensor_to(device, *args, **kwargs)
        return original_Tensor_to(self, device, *args, **kwargs)

original_Tensor_cuda = torch.Tensor.cuda
@wraps(torch.Tensor.cuda)
def Tensor_cuda(self, device=None, *args, **kwargs):
    if device is None or check_cuda(device):
        return self.to(return_xpu(device), *args, **kwargs)
    if check_device(device):
        return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
    else:
        return original_Tensor_cuda(self, device, *args, **kwargs)

original_Tensor_pin_memory = torch.Tensor.pin_memory
@wraps(torch.Tensor.pin_memory)
def Tensor_pin_memory(self, device=None, *args, **kwargs):
    if device is None or check_cuda(device):
    if device is None:
        device = "xpu"
    if check_device(device):
        return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
    else:
        return original_Tensor_pin_memory(self, device, *args, **kwargs)
@@ -272,32 +203,23 @@ def Tensor_pin_memory(self, device=None, *args, **kwargs):
original_UntypedStorage_init = torch.UntypedStorage.__init__
@wraps(torch.UntypedStorage.__init__)
def UntypedStorage_init(*args, device=None, **kwargs):
    if check_cuda(device):
    if check_device(device):
        return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
    else:
        return original_UntypedStorage_init(*args, device=device, **kwargs)

if torch_version >= 2.4:
    original_UntypedStorage_to = torch.UntypedStorage.to
    @wraps(torch.UntypedStorage.to)
    def UntypedStorage_to(self, *args, device=None, **kwargs):
        if check_cuda(device):
            return original_UntypedStorage_to(self, *args, device=return_xpu(device), **kwargs)
        else:
            return original_UntypedStorage_to(self, *args, device=device, **kwargs)

    original_UntypedStorage_cuda = torch.UntypedStorage.cuda
    @wraps(torch.UntypedStorage.cuda)
    def UntypedStorage_cuda(self, device=None, non_blocking=False, **kwargs):
        if device is None or check_cuda(device):
            return self.to(device=return_xpu(device), non_blocking=non_blocking, **kwargs)
        else:
            return original_UntypedStorage_cuda(self, device=device, non_blocking=non_blocking, **kwargs)
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
@wraps(torch.UntypedStorage.cuda)
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
    if check_device(device):
        return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
    else:
        return original_UntypedStorage_cuda(self, device, *args, **kwargs)

original_torch_empty = torch.empty
@wraps(torch.empty)
def torch_empty(*args, device=None, **kwargs):
    if check_cuda(device):
    if check_device(device):
        return original_torch_empty(*args, device=return_xpu(device), **kwargs)
    else:
        return original_torch_empty(*args, device=device, **kwargs)
@@ -305,9 +227,9 @@ def torch_empty(*args, device=None, **kwargs):
original_torch_randn = torch.randn
@wraps(torch.randn)
def torch_randn(*args, device=None, dtype=None, **kwargs):
    if dtype is bytes:
    if dtype == bytes:
        dtype = None
    if check_cuda(device):
    if check_device(device):
        return original_torch_randn(*args, device=return_xpu(device), **kwargs)
    else:
        return original_torch_randn(*args, device=device, **kwargs)
@@ -315,7 +237,7 @@ def torch_randn(*args, device=None, dtype=None, **kwargs):
original_torch_ones = torch.ones
@wraps(torch.ones)
def torch_ones(*args, device=None, **kwargs):
    if check_cuda(device):
    if check_device(device):
        return original_torch_ones(*args, device=return_xpu(device), **kwargs)
    else:
        return original_torch_ones(*args, device=device, **kwargs)
@@ -323,144 +245,69 @@ def torch_ones(*args, device=None, **kwargs):
original_torch_zeros = torch.zeros
@wraps(torch.zeros)
def torch_zeros(*args, device=None, **kwargs):
    if check_cuda(device):
    if check_device(device):
        return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
    else:
        return original_torch_zeros(*args, device=device, **kwargs)

original_torch_full = torch.full
@wraps(torch.full)
def torch_full(*args, device=None, **kwargs):
    if check_cuda(device):
        return original_torch_full(*args, device=return_xpu(device), **kwargs)
    else:
        return original_torch_full(*args, device=device, **kwargs)

original_torch_linspace = torch.linspace
@wraps(torch.linspace)
def torch_linspace(*args, device=None, **kwargs):
    if check_cuda(device):
    if check_device(device):
        return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
    else:
        return original_torch_linspace(*args, device=device, **kwargs)

original_torch_eye = torch.eye
@wraps(torch.eye)
def torch_eye(*args, device=None, **kwargs):
    if check_cuda(device):
        return original_torch_eye(*args, device=return_xpu(device), **kwargs)
original_torch_Generator = torch.Generator
@wraps(torch.Generator)
def torch_Generator(device=None):
    if check_device(device):
        return original_torch_Generator(return_xpu(device))
    else:
        return original_torch_eye(*args, device=device, **kwargs)
        return original_torch_Generator(device)

original_torch_load = torch.load
@wraps(torch.load)
def torch_load(f, map_location=None, *args, **kwargs):
    if map_location is None or check_cuda(map_location):
    if map_location is None:
        map_location = "xpu"
    if check_device(map_location):
        return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
    else:
        return original_torch_load(f, *args, map_location=map_location, **kwargs)

@wraps(torch.cuda.synchronize)
def torch_cuda_synchronize(device=None):
    if check_cuda(device):
        return torch.xpu.synchronize(return_xpu(device))
    else:
        return torch.xpu.synchronize(device)

@wraps(torch.cuda.device)
def torch_cuda_device(device):
    if check_cuda(device):
        return torch.xpu.device(return_xpu(device))
    else:
        return torch.xpu.device(device)

@wraps(torch.cuda.set_device)
def torch_cuda_set_device(device):
    if check_cuda(device):
        torch.xpu.set_device(return_xpu(device))
    else:
        torch.xpu.set_device(device)

# torch.Generator has to be a class for isinstance checks
original_torch_Generator = torch.Generator
class torch_Generator(original_torch_Generator):
    def __new__(self, device=None):
        # can't hijack __init__ because of C override so use return super().__new__
        if check_cuda(device):
            return super().__new__(self, return_xpu(device))
        else:
            return super().__new__(self, device)


# Hijack Functions:
def ipex_hijacks():
    global device_supports_fp64
    if torch_version >= 2.4:
        torch.UntypedStorage.cuda = UntypedStorage_cuda
        torch.UntypedStorage.to = UntypedStorage_to
    torch.tensor = torch_tensor
    torch.Tensor.to = Tensor_to
    torch.Tensor.cuda = Tensor_cuda
    torch.Tensor.pin_memory = Tensor_pin_memory
    torch.UntypedStorage.__init__ = UntypedStorage_init
    torch.UntypedStorage.cuda = UntypedStorage_cuda
    torch.empty = torch_empty
    torch.randn = torch_randn
    torch.ones = torch_ones
    torch.zeros = torch_zeros
    torch.full = torch_full
    torch.linspace = torch_linspace
    torch.eye = torch_eye
    torch.load = torch_load
    torch.cuda.synchronize = torch_cuda_synchronize
    torch.cuda.device = torch_cuda_device
    torch.cuda.set_device = torch_cuda_set_device

    torch.Generator = torch_Generator
    torch._C.Generator = torch_Generator
    torch.load = torch_load

    torch.backends.cuda.sdp_kernel = return_null_context
    torch.nn.DataParallel = DummyDataParallel
    torch.UntypedStorage.is_cuda = is_cuda
    torch.amp.autocast_mode.autocast.__init__ = autocast_init

    torch.nn.functional.interpolate = interpolate
    torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
    torch.nn.functional.group_norm = functional_group_norm
    torch.nn.functional.layer_norm = functional_layer_norm
    torch.nn.functional.linear = functional_linear
    torch.nn.functional.conv1d = functional_conv1d
    torch.nn.functional.conv2d = functional_conv2d
    torch.nn.functional.conv3d = functional_conv3d
    torch.nn.functional.interpolate = interpolate
    torch.nn.functional.pad = functional_pad

    torch.bmm = torch_bmm
    torch.fft.fftn = fft_fftn
    torch.fft.ifftn = fft_ifftn
    torch.cat = torch_cat
    if not device_supports_fp64:
        torch.from_numpy = from_numpy
        torch.as_tensor = as_tensor

    # AMP:
    torch.amp.grad_scaler.GradScaler.__init__ = GradScaler_init
    torch.is_autocast_enabled = torch_is_autocast_enabled
    torch.get_autocast_gpu_dtype = torch_get_autocast_dtype
    torch.get_autocast_dtype = torch_get_autocast_dtype

    if hasattr(torch.xpu, "amp"):
        if not hasattr(torch.xpu.amp, "custom_fwd"):
            torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd
            torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd
        if not hasattr(torch.xpu.amp, "GradScaler"):
            torch.xpu.amp.GradScaler = torch.amp.grad_scaler.GradScaler
        torch.cuda.amp = torch.xpu.amp
    else:
        if not hasattr(torch.amp, "custom_fwd"):
            torch.amp.custom_fwd = torch.cuda.amp.custom_fwd
            torch.amp.custom_bwd = torch.cuda.amp.custom_bwd
        torch.cuda.amp = torch.amp

    if not hasattr(torch.cuda.amp, "common"):
        torch.cuda.amp.common = nullcontext()
        torch.cuda.amp.common.amp_definitely_not_available = lambda: False

    return device_supports_fp64

@@ -1,186 +0,0 @@
# Modified from https://github.com/Fraetor/jxl_decode Original license: MIT
# Added partial read support for up to 200x speedup

import os
from typing import List, Tuple

class JXLBitstream:
    """
    A stream of bits with methods for easy handling.
    """

    def __init__(self, file, offset: int = 0, offsets: List[List[int]] = None):
        self.shift = 0
        self.bitstream = bytearray()
        self.file = file
        self.offset = offset
        self.offsets = offsets
        if self.offsets:
            self.offset = self.offsets[0][1]
        self.previous_data_len = 0
        self.index = 0
        self.file.seek(self.offset)

    def get_bits(self, length: int = 1) -> int:
        if self.offsets and self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
            self.partial_to_read_length = length
            if self.shift < self.previous_data_len + self.offsets[self.index][2]:
                self.partial_read(0, length)
            self.bitstream.extend(self.file.read(self.partial_to_read_length))
        else:
            self.bitstream.extend(self.file.read(length))
        bitmask = 2**length - 1
        bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask
        self.shift += length
        return bits

    def partial_read(self, current_length: int, length: int) -> None:
        self.previous_data_len += self.offsets[self.index][2]
        to_read_length = self.previous_data_len - (self.shift + current_length)
        self.bitstream.extend(self.file.read(to_read_length))
        current_length += to_read_length
        self.partial_to_read_length -= to_read_length
        self.index += 1
        self.file.seek(self.offsets[self.index][1])
        if self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
            self.partial_read(current_length, length)
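Stripped of the partial-read machinery, the bit accumulation above is simple: bytes are appended to a growing buffer and bits are extracted little-endian at a moving shift. A simplified stand-alone sketch (the `BitReader` name is invented; the `offsets` table and seek logic are omitted):

```python
import io

class BitReader:
    """Minimal little-endian bit reader mirroring JXLBitstream.get_bits."""
    def __init__(self, stream):
        self.stream = stream
        self.buffer = bytearray()
        self.shift = 0  # bits consumed so far

    def get_bits(self, length: int = 1) -> int:
        # Read generously; extra buffered bytes are harmless for the mask below.
        self.buffer.extend(self.stream.read(length))
        mask = (1 << length) - 1
        bits = (int.from_bytes(self.buffer, "little") >> self.shift) & mask
        self.shift += length
        return bits

reader = BitReader(io.BytesIO(bytes.fromhex("FF0A")))
print(hex(reader.get_bits(16)))  # 0xaff (the two signature bytes, little-endian)
```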


def decode_codestream(file, offset: int = 0, offsets: List[List[int]] = None) -> Tuple[int, int]:
    """
    Decodes the actual codestream.
    JXL codestream specification: http://www-internal/2022/18181-1
    """

    # Convert codestream to int within an object to get some handy methods.
    codestream = JXLBitstream(file, offset=offset, offsets=offsets)

    # Skip signature
    codestream.get_bits(16)

    # SizeHeader
    div8 = codestream.get_bits(1)
    if div8:
        height = 8 * (1 + codestream.get_bits(5))
    else:
        distribution = codestream.get_bits(2)
        match distribution:
            case 0:
                height = 1 + codestream.get_bits(9)
            case 1:
                height = 1 + codestream.get_bits(13)
            case 2:
                height = 1 + codestream.get_bits(18)
            case 3:
                height = 1 + codestream.get_bits(30)
    ratio = codestream.get_bits(3)
    if div8 and not ratio:
        width = 8 * (1 + codestream.get_bits(5))
    elif not ratio:
        distribution = codestream.get_bits(2)
        match distribution:
            case 0:
                width = 1 + codestream.get_bits(9)
            case 1:
                width = 1 + codestream.get_bits(13)
            case 2:
                width = 1 + codestream.get_bits(18)
            case 3:
                width = 1 + codestream.get_bits(30)
    else:
        match ratio:
            case 1:
                width = height
            case 2:
                width = (height * 12) // 10
            case 3:
                width = (height * 4) // 3
            case 4:
                width = (height * 3) // 2
            case 5:
                width = (height * 16) // 9
            case 6:
                width = (height * 5) // 4
            case 7:
                width = (height * 2) // 1
    return width, height
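The aspect-ratio branch above maps the 3-bit ratio code to a width derived from the height. The same table can be expressed as a small lookup for a self-contained check (the `RATIOS` dict and function name are just a transcription of the match statement, not part of the decoder):

```python
# Ratio code -> (numerator, denominator), as in the SizeHeader branch above.
RATIOS = {1: (1, 1), 2: (12, 10), 3: (4, 3), 4: (3, 2), 5: (16, 9), 6: (5, 4), 7: (2, 1)}

def width_from_ratio(height: int, ratio: int) -> int:
    num, den = RATIOS[ratio]
    return (height * num) // den

print(width_from_ratio(1080, 5))  # 1920 (16:9)
```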


def decode_container(file) -> Tuple[int, int]:
    """
    Parses the ISOBMFF container, extracts the codestream, and decodes it.
    JXL container specification: http://www-internal/2022/18181-2
    """

    def parse_box(file, file_start: int) -> dict:
        file.seek(file_start)
        LBox = int.from_bytes(file.read(4), "big")
        XLBox = None
        if 1 < LBox <= 8:
            raise ValueError(f"Invalid LBox at byte {file_start}.")
        if LBox == 1:
            file.seek(file_start + 8)
            XLBox = int.from_bytes(file.read(8), "big")
            if XLBox <= 16:
                raise ValueError(f"Invalid XLBox at byte {file_start}.")
        if XLBox:
            header_length = 16
            box_length = XLBox
        else:
            header_length = 8
            if LBox == 0:
                box_length = os.fstat(file.fileno()).st_size - file_start
            else:
                box_length = LBox
        file.seek(file_start + 4)
        box_type = file.read(4)
        file.seek(file_start)
        return {
            "length": box_length,
            "type": box_type,
            "offset": header_length,
        }

    file.seek(0)
    # Reject files missing required boxes. These two boxes are required to be at
    # the start and contain no values, so we can manually check their presence.
    # Signature box. (Redundant, as it has already been checked.)
    if file.read(12) != bytes.fromhex("0000000C 4A584C20 0D0A870A"):
        raise ValueError("Invalid signature box.")
    # File Type box.
    if file.read(20) != bytes.fromhex(
        "00000014 66747970 6A786C20 00000000 6A786C20"
    ):
        raise ValueError("Invalid file type box.")

    offset = 0
    offsets = []
    data_offset_not_found = True
    container_pointer = 32
    file_size = os.fstat(file.fileno()).st_size
    while data_offset_not_found:
        box = parse_box(file, container_pointer)
        match box["type"]:
            case b"jxlc":
                offset = container_pointer + box["offset"]
                data_offset_not_found = False
            case b"jxlp":
                file.seek(container_pointer + box["offset"])
                index = int.from_bytes(file.read(4), "big")
                offsets.append([index, container_pointer + box["offset"] + 4, box["length"] - box["offset"] - 4])
        container_pointer += box["length"]
        if container_pointer >= file_size:
            data_offset_not_found = False

    if offsets:
        offsets.sort(key=lambda i: i[0])
    file.seek(0)

    return decode_codestream(file, offset=offset, offsets=offsets)


def get_jxl_size(path: str) -> Tuple[int, int]:
    with open(path, "rb") as file:
        if file.read(2) == bytes.fromhex("FF0A"):
            return decode_codestream(file)
        return decode_container(file)
@@ -1,233 +0,0 @@
import json
import os
from dataclasses import replace
from typing import List, Optional, Tuple, Union

import einops
import torch
from accelerate import init_empty_weights
from safetensors import safe_open
from safetensors.torch import load_file
from transformers import Gemma2Config, Gemma2Model

from library.utils import setup_logging
from library import lumina_models, flux_models
from library.utils import load_safetensors
import logging

setup_logging()
logger = logging.getLogger(__name__)

MODEL_VERSION_LUMINA_V2 = "lumina2"


def load_lumina_model(
    ckpt_path: str,
    dtype: Optional[torch.dtype],
    device: torch.device,
    disable_mmap: bool = False,
    use_flash_attn: bool = False,
    use_sage_attn: bool = False,
):
    """
    Load the Lumina model from the checkpoint path.

    Args:
        ckpt_path (str): Path to the checkpoint.
        dtype (torch.dtype): The data type for the model.
        device (torch.device): The device to load the model on.
        disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
        use_flash_attn (bool, optional): Whether to use flash attention. Defaults to False.
        use_sage_attn (bool, optional): Whether to use Sage attention. Defaults to False.

    Returns:
        model (lumina_models.NextDiT): The loaded model.
    """
    logger.info("Building Lumina")
    with torch.device("meta"):
        model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to(dtype)

    logger.info(f"Loading state dict from {ckpt_path}")
    state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
    info = model.load_state_dict(state_dict, strict=False, assign=True)
    logger.info(f"Loaded Lumina: {info}")
    return model


def load_ae(
    ckpt_path: str,
    dtype: torch.dtype,
    device: Union[str, torch.device],
    disable_mmap: bool = False,
) -> flux_models.AutoEncoder:
    """
    Load the AutoEncoder model from the checkpoint path.

    Args:
        ckpt_path (str): Path to the checkpoint.
        dtype (torch.dtype): The data type for the model.
        device (Union[str, torch.device]): The device to load the model on.
        disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.

    Returns:
        ae (flux_models.AutoEncoder): The loaded model.
    """
    logger.info("Building AutoEncoder")
    with torch.device("meta"):
        # dev and schnell have the same AE params
        ae = flux_models.AutoEncoder(flux_models.configs["schnell"].ae_params).to(dtype)

    logger.info(f"Loading state dict from {ckpt_path}")
    sd = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
    info = ae.load_state_dict(sd, strict=False, assign=True)
    logger.info(f"Loaded AE: {info}")
    return ae


def load_gemma2(
    ckpt_path: Optional[str],
    dtype: torch.dtype,
    device: Union[str, torch.device],
    disable_mmap: bool = False,
    state_dict: Optional[dict] = None,
) -> Gemma2Model:
    """
    Load the Gemma2 model from the checkpoint path.

    Args:
        ckpt_path (str): Path to the checkpoint.
        dtype (torch.dtype): The data type for the model.
        device (Union[str, torch.device]): The device to load the model on.
        disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
        state_dict (Optional[dict], optional): The state dict to load. Defaults to None.

    Returns:
        gemma2 (Gemma2Model): The loaded model
    """
    logger.info("Building Gemma2")
    GEMMA2_CONFIG = {
        "_name_or_path": "google/gemma-2-2b",
        "architectures": ["Gemma2Model"],
        "attention_bias": False,
        "attention_dropout": 0.0,
        "attn_logit_softcapping": 50.0,
        "bos_token_id": 2,
        "cache_implementation": "hybrid",
        "eos_token_id": 1,
        "final_logit_softcapping": 30.0,
        "head_dim": 256,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_activation": "gelu_pytorch_tanh",
        "hidden_size": 2304,
        "initializer_range": 0.02,
        "intermediate_size": 9216,
        "max_position_embeddings": 8192,
        "model_type": "gemma2",
        "num_attention_heads": 8,
        "num_hidden_layers": 26,
        "num_key_value_heads": 4,
        "pad_token_id": 0,
        "query_pre_attn_scalar": 256,
        "rms_norm_eps": 1e-06,
        "rope_theta": 10000.0,
        "sliding_window": 4096,
        "torch_dtype": "float32",
        "transformers_version": "4.44.2",
        "use_cache": True,
        "vocab_size": 256000,
    }

    config = Gemma2Config(**GEMMA2_CONFIG)
    with init_empty_weights():
        gemma2 = Gemma2Model._from_config(config)

    if state_dict is not None:
        sd = state_dict
    else:
        logger.info(f"Loading state dict from {ckpt_path}")
        sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)

    for key in list(sd.keys()):
        new_key = key.replace("model.", "")
        if new_key == key:
            break  # the model doesn't have annoying prefix
        sd[new_key] = sd.pop(key)

    info = gemma2.load_state_dict(sd, strict=False, assign=True)
    logger.info(f"Loaded Gemma2: {info}")
    return gemma2
|
||||
|
||||
|
||||
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
|
||||
"""
|
||||
x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
|
||||
"""
|
||||
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
|
||||
return x
|
||||
|
||||
|
||||
def pack_latents(x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2
|
||||
"""
|
||||
x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
DIFFUSERS_TO_ALPHA_VLLM_MAP: dict[str, str] = {
    # Embedding layers
    "time_caption_embed.caption_embedder.0.weight": "cap_embedder.0.weight",
    "time_caption_embed.caption_embedder.1.weight": "cap_embedder.1.weight",
    "text_embedder.1.bias": "cap_embedder.1.bias",
    "patch_embedder.proj.weight": "x_embedder.weight",
    "patch_embedder.proj.bias": "x_embedder.bias",
    # Attention modulation
    "transformer_blocks.().adaln_modulation.1.weight": "layers.().adaLN_modulation.1.weight",
    "transformer_blocks.().adaln_modulation.1.bias": "layers.().adaLN_modulation.1.bias",
    # Final layers
    "final_adaln_modulation.1.weight": "final_layer.adaLN_modulation.1.weight",
    "final_adaln_modulation.1.bias": "final_layer.adaLN_modulation.1.bias",
    "final_linear.weight": "final_layer.linear.weight",
    "final_linear.bias": "final_layer.linear.bias",
    # Noise refiner
    "single_transformer_blocks.().adaln_modulation.1.weight": "noise_refiner.().adaLN_modulation.1.weight",
    "single_transformer_blocks.().adaln_modulation.1.bias": "noise_refiner.().adaLN_modulation.1.bias",
    "single_transformer_blocks.().attn.to_qkv.weight": "noise_refiner.().attention.qkv.weight",
    "single_transformer_blocks.().attn.to_out.0.weight": "noise_refiner.().attention.out.weight",
    # Normalization
    "transformer_blocks.().norm1.weight": "layers.().attention_norm1.weight",
    "transformer_blocks.().norm2.weight": "layers.().attention_norm2.weight",
    # FFN
    "transformer_blocks.().ff.net.0.proj.weight": "layers.().feed_forward.w1.weight",
    "transformer_blocks.().ff.net.2.weight": "layers.().feed_forward.w2.weight",
    "transformer_blocks.().ff.net.4.weight": "layers.().feed_forward.w3.weight",
}


def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict:
    """Convert a Diffusers checkpoint to Alpha-VLLM format"""
    logger.info("Converting Diffusers checkpoint to Alpha-VLLM format")
    new_sd = sd.copy()  # preserve original keys

    for diff_key, alpha_key in DIFFUSERS_TO_ALPHA_VLLM_MAP.items():
        # handle block-specific patterns
        if "()." in diff_key:
            for block_idx in range(num_double_blocks):
                block_alpha_key = alpha_key.replace("().", f"{block_idx}.")
                block_diff_key = diff_key.replace("().", f"{block_idx}.")

                # search for and convert block-specific keys
                for input_key, value in list(sd.items()):
                    if input_key == block_diff_key:
                        new_sd[block_alpha_key] = value
        else:
            # handle static keys
            if diff_key in sd:
                logger.info(f"Replacing {diff_key} with {alpha_key}")
                new_sd[alpha_key] = sd[diff_key]
            else:
                logger.info(f"Not found: {diff_key}")

    logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format")
    return new_sd

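The `().` placeholder in the mapping table is expanded once per block index. A minimal standalone sketch of that substitution (a hypothetical two-entry map and dummy string "weights", no model needed):

```python
# Hypothetical two-entry map using the same "()." block-index placeholder convention.
KEY_MAP = {
    "transformer_blocks.().norm1.weight": "layers.().attention_norm1.weight",
    "final_linear.weight": "final_layer.linear.weight",
}


def convert_keys(sd: dict, num_blocks: int) -> dict:
    new_sd = dict(sd)  # preserve original keys, like convert_diffusers_sd_to_alpha_vllm
    for src, dst in KEY_MAP.items():
        if "()." in src:
            # expand the placeholder for every block index
            for i in range(num_blocks):
                src_i = src.replace("().", f"{i}.")
                dst_i = dst.replace("().", f"{i}.")
                if src_i in sd:
                    new_sd[dst_i] = sd[src_i]
        elif src in sd:
            new_sd[dst] = sd[src]
    return new_sd


sd = {"transformer_blocks.1.norm1.weight": "W1", "final_linear.weight": "Wf"}
converted = convert_keys(sd, num_blocks=2)
```

Because the target dict starts as a copy, the converted state dict contains both the original and the renamed keys, matching the behavior above.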
@@ -643,15 +643,16 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
        new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
        new_sd[key_pfx + "v_proj" + key_suffix] = values[2]

    # remove position_ids for newer transformer, which causes error :(
    # rename or add position_ids
    ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
    if ANOTHER_POSITION_IDS_KEY in new_sd:
        # waifu diffusion v1.4
        position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
        del new_sd[ANOTHER_POSITION_IDS_KEY]
    else:
        position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)

    if "text_model.embeddings.position_ids" in new_sd:
        del new_sd["text_model.embeddings.position_ids"]

    new_sd["text_model.embeddings.position_ids"] = position_ids
    return new_sd

@@ -61,8 +61,6 @@ ARCH_SD3_M = "stable-diffusion-3"  # may be followed by "-m" or "-5-large" etc.
# ARCH_SD3_UNKNOWN = "stable-diffusion-3"
ARCH_FLUX_1_DEV = "flux-1-dev"
ARCH_FLUX_1_UNKNOWN = "flux-1"
ARCH_LUMINA_2 = "lumina-2"
ARCH_LUMINA_UNKNOWN = "lumina"

ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
@@ -71,7 +69,6 @@ IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
IMPL_DIFFUSERS = "diffusers"
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0"

PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@@ -126,7 +123,6 @@ def build_metadata(
    clip_skip: Optional[int] = None,
    sd3: Optional[str] = None,
    flux: Optional[str] = None,
    lumina: Optional[str] = None,
):
    """
    sd3: only supports "m", flux: only supports "dev"
@@ -150,11 +146,6 @@ def build_metadata(
            arch = ARCH_FLUX_1_DEV
        else:
            arch = ARCH_FLUX_1_UNKNOWN
    elif lumina is not None:
        if lumina == "lumina2":
            arch = ARCH_LUMINA_2
        else:
            arch = ARCH_LUMINA_UNKNOWN
    elif v2:
        if v_parameterization:
            arch = ARCH_SD_V2_768_V
@@ -176,9 +167,6 @@ def build_metadata(
    if flux is not None:
        # Flux
        impl = IMPL_FLUX
    elif lumina is not None:
        # Lumina
        impl = IMPL_LUMINA
    elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
        # Stable Diffusion ckpt, TI, SDXL LoRA
        impl = IMPL_STABILITY_AI
@@ -237,7 +225,7 @@ def build_metadata(
        reso = (reso[0], reso[0])
    else:
        # resolution is defined in dataset, so use default
        if sdxl or sd3 is not None or flux is not None or lumina is not None:
        if sdxl or sd3 is not None or flux is not None:
            reso = 1024
        elif v2 and v_parameterization:
            reso = 768

@@ -1080,7 +1080,7 @@ class MMDiT(nn.Module):
        ), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."

        self.offloader = custom_offloading_utils.ModelOffloader(
            self.joint_blocks, self.blocks_to_swap, device  # , debug=True
            self.joint_blocks, self.num_blocks, self.blocks_to_swap, device  # , debug=True
        )
        print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")

@@ -1088,7 +1088,7 @@ class MMDiT(nn.Module):
        # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
        if self.blocks_to_swap:
            save_blocks = self.joint_blocks
            self.joint_blocks = nn.ModuleList()
            self.joint_blocks = None

        self.to(device)

|
||||
@@ -344,6 +344,8 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_en
|
||||
|
||||
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
||||
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
||||
if args.v_parameterization:
|
||||
logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
|
||||
|
||||
if args.clip_skip is not None:
|
||||
logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
|
||||
|
||||
@@ -2,16 +2,14 @@

import os
import re
from typing import Any, List, Optional, Tuple, Union, Callable
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from safetensors.torch import safe_open, save_file
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection


# TODO remove circular import by moving ImageInfo to a separate file
# from library.train_util import ImageInfo

from library.utils import setup_logging

setup_logging()
@@ -19,6 +17,81 @@ import logging

logger = logging.getLogger(__name__)

from library import dataset_metadata_utils, utils


def get_compatible_dtypes(dtype: Optional[Union[str, torch.dtype]]) -> List[torch.dtype]:
    if dtype is None:
        # all dtypes are acceptable
        return get_available_dtypes()

    dtype = utils.str_to_dtype(dtype) if isinstance(dtype, str) else dtype
    compatible_dtypes = [torch.float32]
    if dtype.itemsize == 1:  # fp8
        compatible_dtypes.append(torch.bfloat16)
        compatible_dtypes.append(torch.float16)
    compatible_dtypes.append(dtype)  # add the specified dtype: bf16, fp16, or one of the fp8 types
    return compatible_dtypes


def get_available_dtypes() -> List[torch.dtype]:
    """
    Returns the list of available dtypes for latents caching. Higher precision is preferred.
    """
    return [torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn, torch.float8_e5m2]


def remove_lower_precision_values(tensor_dict: Dict[str, torch.Tensor], keys_without_dtype: list[str]) -> None:
    """
    Removes lower precision values from tensor_dict.
    """
    available_dtypes = get_available_dtypes()
    available_dtype_suffixes = [f"_{utils.dtype_to_normalized_str(dtype)}" for dtype in available_dtypes]

    for key_without_dtype in keys_without_dtype:
        available_itemsize = None
        for dtype, dtype_suffix in zip(available_dtypes, available_dtype_suffixes):
            key = key_without_dtype + dtype_suffix

            if key in tensor_dict:
                if available_itemsize is None:
                    available_itemsize = dtype.itemsize
                elif available_itemsize > dtype.itemsize:
                    # if higher precision latents are already cached, remove lower precision latents
                    del tensor_dict[key]


def get_compatible_dtype_keys(
    dict_keys: set[str], keys_without_dtype: list[str], dtype: Optional[Union[str, torch.dtype]]
) -> list[Optional[str]]:
    """
    Returns the list of keys with the specified dtype or a higher precision dtype. If the specified dtype is None, any dtype is acceptable.
    If a key is not found, None is returned for it.
    If a key in dict_keys has no dtype suffix, it is acceptable, because it is a long tensor.

    :param dict_keys: set of keys in the dictionary
    :param keys_without_dtype: list of keys without dtype suffix to check
    :param dtype: dtype to check, or None for any dtype
    :return: list of keys with the specified dtype or a higher precision dtype. If a key is not found, None is returned for it.
    """
    compatible_dtypes = get_compatible_dtypes(dtype)
    dtype_suffixes = [f"_{utils.dtype_to_normalized_str(dt)}" for dt in compatible_dtypes]

    available_keys = []
    for key_without_dtype in keys_without_dtype:
        available_key = None
        if key_without_dtype in dict_keys:
            available_key = key_without_dtype
        else:
            for dtype_suffix in dtype_suffixes:
                key = key_without_dtype + dtype_suffix
                if key in dict_keys:
                    available_key = key
                    break
        available_keys.append(available_key)

    return available_keys

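The dtype-suffixed key lookup above can be exercised without `torch` by treating dtypes as name strings. This is a simplified sketch: the precision order and the fp8 special case mirror `get_available_dtypes`/`get_compatible_dtypes`, while the function names and example keys are hypothetical.

```python
# Higher precision first, mirroring get_available_dtypes().
AVAILABLE_DTYPES = ["float32", "bfloat16", "float16", "float8_e4m3fn", "float8_e5m2"]


def compatible_dtype_names(preferred):
    if preferred is None:
        return AVAILABLE_DTYPES  # any dtype is acceptable
    names = ["float32"]  # full precision is always acceptable
    if preferred.startswith("float8"):
        names += ["bfloat16", "float16"]  # fp8 caches also accept bf16/fp16
    names.append(preferred)
    return names


def find_compatible_key(dict_keys, key_without_dtype, preferred=None):
    if key_without_dtype in dict_keys:
        return key_without_dtype  # no suffix, e.g. a long tensor: always acceptable
    for name in compatible_dtype_names(preferred):
        key = f"{key_without_dtype}_{name}"
        if key in dict_keys:
            return key
    return None


keys = {"latents_32x64_bfloat16"}
```

With this cache, a caller preferring fp8 (or accepting any dtype) finds the bf16 entry, while a caller preferring fp16 gets `None`, because bf16 is not in fp16's compatible set.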
class TokenizeStrategy:
    _strategy = None  # strategy instance: actual strategy class
@@ -324,17 +397,26 @@ class TextEncoderOutputsCachingStrategy:

    def __init__(
        self,
        architecture: str,
        cache_to_disk: bool,
        batch_size: Optional[int],
        skip_disk_cache_validity_check: bool,
        max_token_length: int,
        masked: bool = False,
        is_partial: bool = False,
        is_weighted: bool = False,
    ) -> None:
        """
        max_token_length: maximum token length for the model. Including/excluding starting and ending tokens depends on the model.
        """
        self._architecture = architecture
        self._cache_to_disk = cache_to_disk
        self._batch_size = batch_size
        self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
        self._max_token_length = max_token_length
        self._masked = masked
        self._is_partial = is_partial
        self._is_weighted = is_weighted
        self._is_weighted = is_weighted  # enable weighting by `()` or `[]` in the prompt

    @classmethod
    def set_strategy(cls, strategy):
@@ -346,6 +428,18 @@ class TextEncoderOutputsCachingStrategy:
    def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
        return cls._strategy

    @property
    def architecture(self):
        return self._architecture

    @property
    def max_token_length(self):
        return self._max_token_length

    @property
    def masked(self):
        return self._masked

    @property
    def cache_to_disk(self):
        return self._cache_to_disk
@@ -354,6 +448,11 @@ class TextEncoderOutputsCachingStrategy:
    def batch_size(self):
        return self._batch_size

    @property
    def cache_suffix(self):
        suffix_masked = "_m" if self.masked else ""
        return f"_{self.architecture.lower()}_{self.max_token_length}{suffix_masked}_te.safetensors"

    @property
    def is_partial(self):
        return self._is_partial
@@ -362,31 +461,159 @@ class TextEncoderOutputsCachingStrategy:
    def is_weighted(self):
        return self._is_weighted

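Putting `cache_suffix` together with the `os.path.splitext`-based path derivation gives cache file names like the following (a sketch with hypothetical architecture, token length, and path values):

```python
import os


def te_cache_path(image_path: str, architecture: str, max_token_length: int, masked: bool) -> str:
    # mirrors cache_suffix + get_cache_path: "<stem>_<arch>_<len>[_m]_te.safetensors"
    suffix_masked = "_m" if masked else ""
    cache_suffix = f"_{architecture.lower()}_{max_token_length}{suffix_masked}_te.safetensors"
    return os.path.splitext(image_path)[0] + cache_suffix


path = te_cache_path("/data/img001.png", "Lumina2", 256, masked=True)
```

The architecture, token length, and masked flag are all baked into the file name, so changing any of them naturally produces a separate cache file next to the image.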
    def get_outputs_npz_path(self, image_abs_path: str) -> str:
    def get_cache_path(self, absolute_path: str) -> str:
        return os.path.splitext(absolute_path)[0] + self.cache_suffix

    def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
        raise NotImplementedError

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
    def load_from_disk_for_keys(self, cache_path: str, caption_index: int, base_keys: list[str]) -> list[Optional[torch.Tensor]]:
        """
        Get tensors for base_keys, without dtype suffix. If a key is not found, None is returned for it.
        Tensors of all dtypes are returned, because cache validation is done in advance.
        """
        with safe_open(cache_path, framework="pt") as f:
            metadata = f.metadata()
            version = metadata.get("format_version", "0.0.0")
            major, minor, patch = map(int, version.split("."))
            if major > 1:  # or (major == 1 and minor > 0):
                if not self.load_version_warning_printed:
                    self.load_version_warning_printed = True
                    logger.warning(
                        f"Existing latents cache file has a higher version {version} for {cache_path}. This may cause issues."
                    )

            dict_keys = f.keys()
            results = []
            compatible_keys = self.get_compatible_output_keys(dict_keys, caption_index, base_keys, None)
            for key in compatible_keys:
                results.append(f.get_tensor(key) if key is not None else None)

        return results

    def is_disk_cached_outputs_expected(
        self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
    ) -> bool:
        raise NotImplementedError

    def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
        raise NotImplementedError
    def get_key_suffix(self, prompt_id: int, dtype: Optional[Union[str, torch.dtype]] = None) -> str:
        """
        masked: may be False even if self.masked is True. It is False for some outputs.
        """
        key_suffix = f"_{prompt_id}"
        if dtype is not None and dtype.is_floating_point:  # float tensors only
            key_suffix += "_" + utils.dtype_to_normalized_str(dtype)
        return key_suffix

    def get_compatible_output_keys(
        self, dict_keys: set[str], caption_index: int, base_keys: list[str], dtype: Optional[Union[str, torch.dtype]]
    ) -> list[Optional[str]]:
        """
        Returns the list of keys with the specified dtype or a higher precision dtype. If the specified dtype is None, any dtype is acceptable.
        """
        key_suffix = self.get_key_suffix(caption_index, None)
        keys_without_dtype = [k + key_suffix for k in base_keys]
        return get_compatible_dtype_keys(dict_keys, keys_without_dtype, dtype)

    def _default_is_disk_cached_outputs_expected(
        self,
        cache_path: str,
        captions: list[str],
        base_keys: list[tuple[str, bool]],
        preferred_dtype: Optional[Union[str, torch.dtype]],
    ):
        if not self.cache_to_disk:
            return False
        if not os.path.exists(cache_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        try:
            with utils.MemoryEfficientSafeOpen(cache_path) as f:
                keys = f.keys()
                metadata = f.metadata()

                # check captions in metadata
                for i, caption in enumerate(captions):
                    if metadata.get(f"caption{i+1}") != caption:
                        return False

                    compatible_keys = self.get_compatible_output_keys(keys, i, base_keys, preferred_dtype)
                    if any(key is None for key in compatible_keys):
                        return False
        except Exception as e:
            logger.error(f"Error loading file: {cache_path}")
            raise e

        return True

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List
        self,
        tokenize_strategy: TokenizeStrategy,
        models: list[Any],
        text_encoding_strategy: TextEncodingStrategy,
        batch: list[tuple[utils.ImageInfo, int, str]],
    ):
        raise NotImplementedError

    def save_outputs_to_disk(self, cache_path: str, caption_index: int, caption: str, keys: list[str], outputs: list[torch.Tensor]):
        tensor_dict = {}

        overwrite = False
        if os.path.exists(cache_path):
            # load the existing safetensors file and update it
            overwrite = True

            with utils.MemoryEfficientSafeOpen(cache_path) as f:
                metadata = f.metadata()
                existing_keys = f.keys()  # do not shadow the `keys` argument
                for key in existing_keys:
                    tensor_dict[key] = f.get_tensor(key)
                assert metadata["architecture"] == self.architecture

            file_version = metadata.get("format_version", "0.0.0")
            major, minor, patch = map(int, file_version.split("."))
            if major > 1 or (major == 1 and minor > 0):
                self.save_version_warning_printed = True
                logger.warning(
                    f"Existing latents cache file has a higher version {file_version} for {cache_path}. This may cause issues."
                )
        else:
            metadata = {}
            metadata["architecture"] = self.architecture
            metadata["format_version"] = "1.0.0"

        metadata[f"caption{caption_index+1}"] = caption

        for key, output in zip(keys, outputs):
            dtype = output.dtype  # long or one of the float dtypes
            key_suffix = self.get_key_suffix(caption_index, dtype)
            tensor_dict[key + key_suffix] = output

        # remove lower precision values if higher precision values are already cached
        if overwrite:
            suffix_without_dtype = self.get_key_suffix(caption_index, None)
            remove_lower_precision_values(tensor_dict, [k + suffix_without_dtype for k in keys])

        save_file(tensor_dict, cache_path, metadata=metadata)

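Both the load and save paths gate their warnings on the same semver-style comparison of `format_version`; extracted as a standalone helper (a sketch, the function name is hypothetical):

```python
def is_newer_format(version: str, supported_major: int = 1, supported_minor: int = 0) -> bool:
    # Same check as the save path: warn when the file's format is newer than
    # the supported major.minor; patch differences are ignored.
    major, minor, _patch = map(int, version.split("."))
    return major > supported_major or (major == supported_major and minor > supported_minor)
```

Note the asymmetry in the code above: the load path keeps only the `major > 1` half of this condition (the minor check is commented out), so a `1.1.0` file loads silently but triggers the warning when saved over.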
class LatentsCachingStrategy:
    # TODO commonize utility functions to this class, such as npz handling etc.

    _strategy = None  # strategy instance: actual strategy class

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
    def __init__(
        self, architecture: str, latents_stride: int, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
    ) -> None:
        self._architecture = architecture
        self._latents_stride = latents_stride
        self._cache_to_disk = cache_to_disk
        self._batch_size = batch_size
        self.skip_disk_cache_validity_check = skip_disk_cache_validity_check

        self.load_version_warning_printed = False
        self.save_version_warning_printed = False

    @classmethod
    def set_strategy(cls, strategy):
        if cls._strategy is not None:
@@ -397,6 +624,14 @@ class LatentsCachingStrategy:
    def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
        return cls._strategy

    @property
    def architecture(self):
        return self._architecture

    @property
    def latents_stride(self):
        return self._latents_stride

    @property
    def cache_to_disk(self):
        return self._cache_to_disk
@@ -407,66 +642,126 @@ class LatentsCachingStrategy:

    @property
    def cache_suffix(self):
        raise NotImplementedError
        return f"_{self.architecture.lower()}.safetensors"

    def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]:
        w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x")
    def get_image_size_from_disk_cache_path(self, absolute_path: str, cache_path: str) -> Tuple[Optional[int], Optional[int]]:
        w, h = os.path.splitext(cache_path)[0].rsplit("_", 2)[-2].split("x")
        return int(w), int(h)

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        raise NotImplementedError
    def get_latents_cache_path_from_info(self, info: utils.ImageInfo) -> str:
        return self.get_latents_cache_path(info.absolute_path, info.image_size, info.latents_cache_dir)

    def get_latents_cache_path(
        self, absolute_path_or_archive_img_path: str, image_size: Tuple[int, int], cache_dir: Optional[str] = None
    ) -> str:
        if cache_dir is not None:
            if dataset_metadata_utils.is_archive_path(absolute_path_or_archive_img_path):
                inner_path = dataset_metadata_utils.get_inner_path(absolute_path_or_archive_img_path)
                archive_digest = dataset_metadata_utils.get_archive_digest(absolute_path_or_archive_img_path)
                cache_file_base = os.path.join(cache_dir, f"{archive_digest}_{inner_path}")
            else:
                cache_file_base = os.path.join(cache_dir, os.path.basename(absolute_path_or_archive_img_path))
        else:
            cache_file_base = absolute_path_or_archive_img_path

        return os.path.splitext(cache_file_base)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix

    def is_disk_cached_latents_expected(
        self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
        self,
        bucket_reso: Tuple[int, int],
        cache_path: str,
        flip_aug: bool,
        alpha_mask: bool,
        preferred_dtype: Optional[Union[str, torch.dtype]],
    ) -> bool:
        raise NotImplementedError

    def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
        raise NotImplementedError

    def get_key_suffix(
        self,
        bucket_reso: Optional[Tuple[int, int]] = None,
        latents_size: Optional[Tuple[int, int]] = None,
        dtype: Optional[Union[str, torch.dtype]] = None,
    ) -> str:
        """
        If dtype is None, it returns "_32x64" for example.
        """
        if latents_size is not None:
            expected_latents_size = latents_size  # H, W
        else:
            # bucket_reso is (W, H)
            expected_latents_size = (bucket_reso[1] // self.latents_stride, bucket_reso[0] // self.latents_stride)  # H, W

        if dtype is None:
            dtype_suffix = ""
        else:
            dtype_suffix = "_" + utils.dtype_to_normalized_str(dtype)

        # e.g. "_32x64_float16", HxW, dtype
        key_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}{dtype_suffix}"

        return key_suffix

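The resolution arithmetic in `get_key_suffix` can be checked in isolation (pure Python; the dtype name is passed as a plain string here rather than a `torch.dtype`, and the function name is hypothetical):

```python
def latents_key_suffix(bucket_reso, latents_stride, dtype_name=None):
    # bucket_reso is (W, H); the suffix is HxW in latent space, e.g. "_32x64_float16"
    h = bucket_reso[1] // latents_stride
    w = bucket_reso[0] // latents_stride
    suffix = f"_{h}x{w}"
    if dtype_name is not None:
        suffix += f"_{dtype_name}"
    return suffix


# A 512x256 (WxH) bucket with a VAE stride of 8 gives 64x32 (WxH) latents.
suffix = latents_key_suffix((512, 256), 8, "float16")
```

Note the deliberate axis swap: buckets are stored as (W, H) but the suffix encodes HxW, which is why `bucket_reso[1]` comes first.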
    def get_compatible_latents_keys(
        self,
        keys: set[str],
        dtype: Optional[Union[str, torch.dtype]],
        flip_aug: bool,
        bucket_reso: Optional[Tuple[int, int]] = None,
        latents_size: Optional[Tuple[int, int]] = None,
    ) -> list[Optional[str]]:
        """
        bucket_reso is (W, H), latents_size is (H, W)
        """

        key_suffix = self.get_key_suffix(bucket_reso, latents_size, None)
        keys_without_dtype = ["latents" + key_suffix]
        if flip_aug:
            keys_without_dtype.append("latents_flipped" + key_suffix)

        compatible_keys = get_compatible_dtype_keys(keys, keys_without_dtype, dtype)
        return compatible_keys if flip_aug else compatible_keys + [None]

    def _default_is_disk_cached_latents_expected(
        self,
        latents_stride: int,
        bucket_reso: Tuple[int, int],
        npz_path: str,
        latents_cache_path: str,
        flip_aug: bool,
        apply_alpha_mask: bool,
        multi_resolution: bool = False,
    ) -> bool:
        """
        Args:
            latents_stride: stride of latents
            bucket_reso: resolution of the bucket
            npz_path: path to the npz file
            flip_aug: whether to flip images
            apply_alpha_mask: whether to apply alpha mask
            multi_resolution: whether to use multi-resolution latents

        Returns:
            bool
        """
        alpha_mask: bool,
        preferred_dtype: Optional[Union[str, torch.dtype]],
    ):
        # multi_resolution is always enabled for any strategy
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
        if not os.path.exists(latents_cache_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)  # bucket_reso is (W, H)

        # e.g. "_32x64", HxW
        key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
        key_suffix_without_dtype = self.get_key_suffix(bucket_reso=bucket_reso, dtype=None)

        try:
            npz = np.load(npz_path)
            if "latents" + key_reso_suffix not in npz:
            # safe_open locks the file, so we cannot use it for checking keys
            # with safe_open(latents_cache_path, framework="pt") as f:
            #     keys = f.keys()
            with utils.MemoryEfficientSafeOpen(latents_cache_path) as f:
                keys = f.keys()

                if alpha_mask and "alpha_mask" + key_suffix_without_dtype not in keys:
                    # print(f"alpha_mask not found: {latents_cache_path}")
                    return False
            if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
                return False
            if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz:

                # preferred_dtype is None if any dtype is acceptable
                latents_key, flipped_latents_key = self.get_compatible_latents_keys(
                    keys, preferred_dtype, flip_aug, bucket_reso=bucket_reso
                )
                if latents_key is None or (flip_aug and flipped_latents_key is None):
                    # print(f"Precise dtype not found: {latents_cache_path}")
                    return False
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            logger.error(f"Error loading file: {latents_cache_path}")
            raise e

        return True
@@ -474,35 +769,21 @@ class LatentsCachingStrategy:
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def _default_cache_batch_latents(
|
||||
self,
|
||||
encode_by_vae: Callable,
|
||||
vae_device: torch.device,
|
||||
vae_dtype: torch.dtype,
|
||||
image_infos: List,
|
||||
encode_by_vae,
|
||||
vae_device,
|
||||
vae_dtype,
|
||||
image_infos: List[utils.ImageInfo],
|
||||
flip_aug: bool,
|
||||
apply_alpha_mask: bool,
|
||||
alpha_mask: bool,
|
||||
random_crop: bool,
|
||||
multi_resolution: bool = False,
|
||||
):
|
||||
"""
|
||||
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
|
||||
|
||||
Args:
|
||||
encode_by_vae: function to encode images by VAE
|
||||
vae_device: device to use for VAE
|
||||
vae_dtype: dtype to use for VAE
|
||||
image_infos: list of ImageInfo
|
||||
flip_aug: whether to flip images
|
||||
apply_alpha_mask: whether to apply alpha mask
|
||||
random_crop: whether to random crop images
|
||||
multi_resolution: whether to use multi-resolution latents
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
from library import train_util # import here to avoid circular import
|
||||
|
||||
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
|
||||
image_infos, apply_alpha_mask, random_crop
|
||||
image_infos, alpha_mask, random_crop
|
||||
)
|
||||
img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)
|
||||
|
||||
@@ -524,13 +805,8 @@ class LatentsCachingStrategy:
|
||||
original_size = original_sizes[i]
|
||||
crop_ltrb = crop_ltrbs[i]
|
||||
|
||||
latents_size = latents.shape[1:3] # H, W
|
||||
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW
|
||||
|
||||
if self.cache_to_disk:
|
||||
self.save_latents_to_disk(
|
||||
info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix
|
||||
)
|
||||
self.save_latents_to_disk(info.latents_cache_path, latents, original_size, crop_ltrb, flipped_latent, alpha_mask)
|
||||
else:
|
||||
info.latents_original_size = original_size
|
||||
info.latents_crop_ltrb = crop_ltrb
|
||||
@@ -540,97 +816,96 @@ class LatentsCachingStrategy:
                info.alpha_mask = alpha_mask

    def load_latents_from_disk(
        self, npz_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        """
        for SD/SDXL

        Args:
            npz_path (str): Path to the npz file.
            bucket_reso (Tuple[int, int]): The resolution of the bucket.

        Returns:
            Tuple[
                Optional[np.ndarray],
                Optional[List[int]],
                Optional[List[int]],
                Optional[np.ndarray],
                Optional[np.ndarray]
            ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
        """
        return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
        self, cache_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
        raise NotImplementedError

    def _default_load_latents_from_disk(
        self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Args:
            latents_stride (Optional[int]): Stride for latents. If None, load all latents.
            npz_path (str): Path to the npz file.
            bucket_reso (Tuple[int, int]): The resolution of the bucket.

        Returns:
            Tuple[
                Optional[np.ndarray],
                Optional[List[int]],
                Optional[List[int]],
                Optional[np.ndarray],
                Optional[np.ndarray]
            ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
        """
        if latents_stride is None:
            key_reso_suffix = ""
        else:
            latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)  # bucket_reso is (W, H)
            key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}"  # e.g. "_32x64", HxW
        self, cache_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
        with safe_open(cache_path, framework="pt") as f:
            metadata = f.metadata()
            version = metadata.get("format_version", "0.0.0")
            major, minor, patch = map(int, version.split("."))
            if major > 1:  # or (major == 1 and minor > 0):
                if not self.load_version_warning_printed:
                    self.load_version_warning_printed = True
                    logger.warning(
                        f"Existing latents cache file has a higher version {version} for {cache_path}. This may cause issues."
                    )

        npz = np.load(npz_path)
        if "latents" + key_reso_suffix not in npz:
            raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
            keys = f.keys()

            latents_key, flipped_latents_key = self.get_compatible_latents_keys(keys, None, flip_aug=True, bucket_reso=bucket_reso)

            key_suffix_without_dtype = self.get_key_suffix(bucket_reso=bucket_reso, dtype=None)
            alpha_mask_key = "alpha_mask" + key_suffix_without_dtype

            latents = f.get_tensor(latents_key)
            flipped_latents = f.get_tensor(flipped_latents_key) if flipped_latents_key is not None else None
            alpha_mask = f.get_tensor(alpha_mask_key) if alpha_mask_key in keys else None

            original_size = [int(metadata["width"]), int(metadata["height"])]
            crop_ltrb = metadata[f"crop_ltrb" + key_suffix_without_dtype]
            crop_ltrb = list(map(int, crop_ltrb.split(",")))

        latents = npz["latents" + key_reso_suffix]
        original_size = npz["original_size" + key_reso_suffix].tolist()
        crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
        flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
        alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
        return latents, original_size, crop_ltrb, flipped_latents, alpha_mask

    def save_latents_to_disk(
        self,
        npz_path,
        latents_tensor,
        original_size,
        crop_ltrb,
        flipped_latents_tensor=None,
        alpha_mask=None,
        key_reso_suffix="",
        cache_path: str,
        latents_tensor: torch.Tensor,
        original_size: Tuple[int, int],
        crop_ltrb: List[int],
        flipped_latents_tensor: Optional[torch.Tensor] = None,
        alpha_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            npz_path (str): Path to the npz file.
            latents_tensor (torch.Tensor): Latent tensor
            original_size (List[int]): Original size of the image
            crop_ltrb (List[int]): Crop left top right bottom
            flipped_latents_tensor (Optional[torch.Tensor]): Flipped latent tensor
            alpha_mask (Optional[torch.Tensor]): Alpha mask
            key_reso_suffix (str): Key resolution suffix
        dtype = latents_tensor.dtype
        latents_size = latents_tensor.shape[1:3]  # H, W
        tensor_dict = {}

        Returns:
            None
        """
        kwargs = {}
        overwrite = False
        if os.path.exists(cache_path):
            # load existing safetensors and update it
            overwrite = True

        if os.path.exists(npz_path):
            # load existing npz and update it
            npz = np.load(npz_path)
            for key in npz.files:
                kwargs[key] = npz[key]
            # we cannot use safe_open here because it locks the file
            # with safe_open(cache_path, framework="pt") as f:
            with utils.MemoryEfficientSafeOpen(cache_path) as f:
                metadata = f.metadata()
                keys = f.keys()
                for key in keys:
                    tensor_dict[key] = f.get_tensor(key)
                assert metadata["architecture"] == self.architecture

        kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
        kwargs["original_size" + key_reso_suffix] = np.array(original_size)
        kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
            file_version = metadata.get("format_version", "0.0.0")
            major, minor, patch = map(int, file_version.split("."))
            if major > 1 or (major == 1 and minor > 0):
                self.save_version_warning_printed = True
                logger.warning(
                    f"Existing latents cache file has a higher version {file_version} for {cache_path}. This may cause issues."
                )
        else:
            metadata = {}
            metadata["architecture"] = self.architecture
        metadata["width"] = f"{original_size[0]}"
        metadata["height"] = f"{original_size[1]}"
        metadata["format_version"] = "1.0.0"

        metadata[f"crop_ltrb_{latents_size[0]}x{latents_size[1]}"] = ",".join(map(str, crop_ltrb))

        key_suffix = self.get_key_suffix(latents_size=latents_size, dtype=dtype)
        if latents_tensor is not None:
            tensor_dict["latents" + key_suffix] = latents_tensor
        if flipped_latents_tensor is not None:
            kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy()
            tensor_dict["latents_flipped" + key_suffix] = flipped_latents_tensor
        if alpha_mask is not None:
            kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
        np.savez(npz_path, **kwargs)
            key_suffix_without_dtype = self.get_key_suffix(latents_size=latents_size, dtype=None)
            tensor_dict["alpha_mask" + key_suffix_without_dtype] = alpha_mask

        # remove lower precision latents if higher precision latents are already cached
        if overwrite:
            suffix_without_dtype = self.get_key_suffix(latents_size=latents_size, dtype=None)
            remove_lower_precision_values(tensor_dict, ["latents" + suffix_without_dtype, "latents_flipped" + suffix_without_dtype])

        save_file(tensor_dict, cache_path, metadata=metadata)
@@ -5,9 +5,6 @@ import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast

from library import flux_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy

from library.utils import setup_logging

setup_logging()

@@ -15,6 +12,8 @@ import logging

logger = logging.getLogger(__name__)

from library import flux_utils, train_util, utils
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy

CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
@@ -86,64 +85,56 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):


class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"
    KEYS = ["l_pooled", "t5_out", "txt_ids"]
    KEYS_MASKED = ["t5_attn_mask", "apply_t5_attn_mask"]

    def __init__(
        self,
        cache_to_disk: bool,
        batch_size: int,
        skip_disk_cache_validity_check: bool,
        max_token_length: int,
        masked: bool,
        is_partial: bool = False,
        apply_t5_attn_mask: bool = False,
    ) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
        self.apply_t5_attn_mask = apply_t5_attn_mask
        super().__init__(
            FluxLatentsCachingStrategy.ARCHITECTURE,
            cache_to_disk,
            batch_size,
            skip_disk_cache_validity_check,
            max_token_length,
            masked,
            is_partial,
        )

        self.warn_fp8_weights = False

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
    def is_disk_cached_outputs_expected(
        self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
    ):
        keys = FluxTextEncoderOutputsCachingStrategy.KEYS
        if self.masked:
            keys += FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
        return self._default_is_disk_cached_outputs_expected(cache_path, prompts, keys, preferred_dtype)

    def is_disk_cached_outputs_expected(self, npz_path: str):
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        try:
            npz = np.load(npz_path)
            if "l_pooled" not in npz:
                return False
            if "t5_out" not in npz:
                return False
            if "txt_ids" not in npz:
                return False
            if "t5_attn_mask" not in npz:
                return False
            if "apply_t5_attn_mask" not in npz:
                return False
            npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
            if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
                return False
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            raise e

        return True

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        data = np.load(npz_path)
        l_pooled = data["l_pooled"]
        t5_out = data["t5_out"]
        txt_ids = data["txt_ids"]
        t5_attn_mask = data["t5_attn_mask"]
        # apply_t5_attn_mask should be same as self.apply_t5_attn_mask
    def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
        l_pooled, t5_out, txt_ids = self.load_from_disk_for_keys(
            cache_path, caption_index, FluxTextEncoderOutputsCachingStrategy.KEYS
        )
        if self.masked:
            t5_attn_mask = self.load_from_disk_for_keys(
                cache_path, caption_index, FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
            )[0]
        else:
            t5_attn_mask = None
        return [l_pooled, t5_out, txt_ids, t5_attn_mask]

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        text_encoding_strategy: TextEncodingStrategy,
        batch: list[tuple[utils.ImageInfo, int, str]],
    ):
        if not self.warn_fp8_weights:
            if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
@@ -154,80 +145,67 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
                self.warn_fp8_weights = True

        flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
        captions = [info.caption for info in infos]
        captions = [caption for _, _, caption in batch]

        tokens_and_masks = tokenize_strategy.tokenize(captions)
        with torch.no_grad():
            # attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
            l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)

        if l_pooled.dtype == torch.bfloat16:
            l_pooled = l_pooled.float()
        if t5_out.dtype == torch.bfloat16:
            t5_out = t5_out.float()
        if txt_ids.dtype == torch.bfloat16:
            txt_ids = txt_ids.float()
        l_pooled = l_pooled.cpu()
        t5_out = t5_out.cpu()
        txt_ids = txt_ids.cpu()
        t5_attn_mask = tokens_and_masks[2].cpu()

        l_pooled = l_pooled.cpu().numpy()
        t5_out = t5_out.cpu().numpy()
        txt_ids = txt_ids.cpu().numpy()
        t5_attn_mask = tokens_and_masks[2].cpu().numpy()
        keys = FluxTextEncoderOutputsCachingStrategy.KEYS
        if self.masked:
            keys += FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED

        for i, info in enumerate(infos):
        for i, (info, caption_index, caption) in enumerate(batch):
            l_pooled_i = l_pooled[i]
            t5_out_i = t5_out[i]
            txt_ids_i = txt_ids[i]
            t5_attn_mask_i = t5_attn_mask[i]
            apply_t5_attn_mask_i = self.apply_t5_attn_mask

            if self.cache_to_disk:
                np.savez(
                    info.text_encoder_outputs_npz,
                    l_pooled=l_pooled_i,
                    t5_out=t5_out_i,
                    txt_ids=txt_ids_i,
                    t5_attn_mask=t5_attn_mask_i,
                    apply_t5_attn_mask=apply_t5_attn_mask_i,
                )
                outputs = [l_pooled_i, t5_out_i, txt_ids_i]
                if self.masked:
                    outputs += [t5_attn_mask_i]
                self.save_outputs_to_disk(info.text_encoder_outputs_cache_path, caption_index, caption, keys, outputs)
            else:
                # it's fine that attn mask is not None. it's overwritten before calling the model if necessary
                info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
                while len(info.text_encoder_outputs) <= caption_index:
                    info.text_encoder_outputs.append(None)
                info.text_encoder_outputs[caption_index] = [l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i]


class FluxLatentsCachingStrategy(LatentsCachingStrategy):
    FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"
    ARCHITECTURE = "flux"

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
        super().__init__(FluxLatentsCachingStrategy.ARCHITECTURE, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)

    @property
    def cache_suffix(self) -> str:
        return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        return (
            os.path.splitext(absolute_path)[0]
            + f"_{image_size[0]:04d}x{image_size[1]:04d}"
            + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
        )

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
    def is_disk_cached_latents_expected(
        self,
        bucket_reso: Tuple[int, int],
        cache_path: str,
        flip_aug: bool,
        alpha_mask: bool,
        preferred_dtype: Optional[torch.dtype] = None,
    ):
        return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)

    def load_latents_from_disk(
        self, npz_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        return self._default_load_latents_from_disk(8, npz_path, bucket_reso)  # support multi-resolution
        self, cache_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
        return self._default_load_latents_from_disk(cache_path, bucket_reso)

    # TODO remove circular dependency for ImageInfo
    def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
    def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
        encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
        vae_device = vae.device
        vae_dtype = vae.dtype

        self._default_cache_batch_latents(
            encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
        )
        self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)

        if not train_util.HIGH_VRAM:
            train_util.clean_memory_on_device(vae.device)
@@ -1,375 +0,0 @@
import glob
import os
from typing import Any, List, Optional, Tuple, Union

import torch
from transformers import AutoTokenizer, AutoModel, Gemma2Model, GemmaTokenizerFast
from library import train_util
from library.strategy_base import (
    LatentsCachingStrategy,
    TokenizeStrategy,
    TextEncodingStrategy,
    TextEncoderOutputsCachingStrategy,
)
import numpy as np
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


GEMMA_ID = "google/gemma-2-2b"


class LuminaTokenizeStrategy(TokenizeStrategy):
    def __init__(
        self, system_prompt: str, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None
    ) -> None:
        self.tokenizer: GemmaTokenizerFast = AutoTokenizer.from_pretrained(
            GEMMA_ID, cache_dir=tokenizer_cache_dir
        )
        self.tokenizer.padding_side = "right"

        if system_prompt is None:
            system_prompt = ""
        system_prompt_special_token = "<Prompt Start>"
        system_prompt = f"{system_prompt} {system_prompt_special_token} " if system_prompt else ""
        self.system_prompt = system_prompt

        if max_length is None:
            self.max_length = 256
        else:
            self.max_length = max_length

    def tokenize(
        self, text: Union[str, List[str]], is_negative: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            text (Union[str, List[str]]): Text to tokenize

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                token input ids, attention_masks
        """
        text = [text] if isinstance(text, str) else text

        # In training, we always add system prompt (is_negative=False)
        if not is_negative:
            # Add system prompt to the beginning of each text
            text = [self.system_prompt + t for t in text]

        encodings = self.tokenizer(
            text,
            max_length=self.max_length,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            pad_to_multiple_of=8,
        )
        return (encodings.input_ids, encodings.attention_mask)

    def tokenize_with_weights(
        self, text: str | List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """
        Args:
            text (Union[str, List[str]]): Text to tokenize

        Returns:
            Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
                token input ids, attention_masks, weights
        """
        # Gemma doesn't support weighted prompts, return uniform weights
        tokens, attention_masks = self.tokenize(text)
        weights = [torch.ones_like(t) for t in tokens]
        return tokens, attention_masks, weights


class LuminaTextEncodingStrategy(TextEncodingStrategy):
    def __init__(self) -> None:
        super().__init__()

    def encode_tokens(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens: Tuple[torch.Tensor, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
            models (List[Any]): Text encoders
            tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                hidden_states, input_ids, attention_masks
        """
        text_encoder = models[0]
        # Check model or torch dynamo OptimizedModule
        assert isinstance(text_encoder, Gemma2Model) or isinstance(text_encoder._orig_mod, Gemma2Model), f"text encoder is not Gemma2Model {text_encoder.__class__.__name__}"
        input_ids, attention_masks = tokens

        outputs = text_encoder(
            input_ids=input_ids.to(text_encoder.device),
            attention_mask=attention_masks.to(text_encoder.device),
            output_hidden_states=True,
            return_dict=True,
        )

        return outputs.hidden_states[-2], input_ids, attention_masks

    def encode_tokens_with_weights(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens: Tuple[torch.Tensor, torch.Tensor],
        weights: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
            models (List[Any]): Text encoders
            tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks
            weights_list (List[torch.Tensor]): Currently unused

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                hidden_states, input_ids, attention_masks
        """
        # For simplicity, use uniform weighting
        return self.encode_tokens(tokenize_strategy, models, tokens)


class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_lumina_te.npz"

    def __init__(
        self,
        cache_to_disk: bool,
        batch_size: int,
        skip_disk_cache_validity_check: bool,
        is_partial: bool = False,
    ) -> None:
        super().__init__(
            cache_to_disk,
            batch_size,
            skip_disk_cache_validity_check,
            is_partial,
        )

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        return (
            os.path.splitext(image_abs_path)[0]
            + LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
        )

    def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
        """
        Args:
            npz_path (str): Path to the npz file.

        Returns:
            bool: True if the npz file is expected to be cached.
        """
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        try:
            npz = np.load(npz_path)
            if "hidden_state" not in npz:
                return False
            if "attention_mask" not in npz:
                return False
            if "input_ids" not in npz:
                return False
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            raise e

        return True

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        """
        Load outputs from a npz file

        Returns:
            List[np.ndarray]: hidden_state, input_ids, attention_mask
        """
        data = np.load(npz_path)
        hidden_state = data["hidden_state"]
        attention_mask = data["attention_mask"]
        input_ids = data["input_ids"]
        return [hidden_state, input_ids, attention_mask]

    @torch.no_grad()
    def cache_batch_outputs(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        text_encoding_strategy: TextEncodingStrategy,
        batch: List[train_util.ImageInfo],
    ) -> None:
        """
        Args:
            tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
            models (List[Any]): Text encoders
            text_encoding_strategy (LuminaTextEncodingStrategy): Text encoding strategy
            batch (List[train_util.ImageInfo]): List of ImageInfo

        Returns:
            None
        """
        assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy)
        assert isinstance(tokenize_strategy, LuminaTokenizeStrategy)

        captions = [info.caption for info in batch]

        if self.is_weighted:
            tokens, attention_masks, weights_list = (
                tokenize_strategy.tokenize_with_weights(captions)
            )
            hidden_state, input_ids, attention_masks = (
                text_encoding_strategy.encode_tokens_with_weights(
                    tokenize_strategy,
                    models,
                    (tokens, attention_masks),
                    weights_list,
                )
            )
        else:
            tokens = tokenize_strategy.tokenize(captions)
            hidden_state, input_ids, attention_masks = (
                text_encoding_strategy.encode_tokens(
                    tokenize_strategy, models, tokens
                )
            )

        if hidden_state.dtype != torch.float32:
            hidden_state = hidden_state.float()

        hidden_state = hidden_state.cpu().numpy()
        attention_mask = attention_masks.cpu().numpy()  # (B, S)
        input_ids = input_ids.cpu().numpy()  # (B, S)

        for i, info in enumerate(batch):
            hidden_state_i = hidden_state[i]
            attention_mask_i = attention_mask[i]
            input_ids_i = input_ids[i]

            if self.cache_to_disk:
                assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
                np.savez(
                    info.text_encoder_outputs_npz,
                    hidden_state=hidden_state_i,
                    attention_mask=attention_mask_i,
                    input_ids=input_ids_i,
                )
            else:
                info.text_encoder_outputs = [
                    hidden_state_i,
                    input_ids_i,
                    attention_mask_i,
                ]


class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
    LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz"

    def __init__(
        self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
    ) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)

    @property
    def cache_suffix(self) -> str:
        return LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX

    def get_latents_npz_path(
        self, absolute_path: str, image_size: Tuple[int, int]
    ) -> str:
        return (
            os.path.splitext(absolute_path)[0]
            + f"_{image_size[0]:04d}x{image_size[1]:04d}"
            + LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX
        )

    def is_disk_cached_latents_expected(
        self,
        bucket_reso: Tuple[int, int],
        npz_path: str,
        flip_aug: bool,
        alpha_mask: bool,
    ) -> bool:
        """
        Args:
            bucket_reso (Tuple[int, int]): The resolution of the bucket.
            npz_path (str): Path to the npz file.
            flip_aug (bool): Whether to flip the image.
            alpha_mask (bool): Whether to apply the alpha mask.
        """
        return self._default_is_disk_cached_latents_expected(
            8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
        )

    def load_latents_from_disk(
        self, npz_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[
        Optional[np.ndarray],
        Optional[List[int]],
        Optional[List[int]],
        Optional[np.ndarray],
        Optional[np.ndarray],
    ]:
        """
        Args:
            npz_path (str): Path to the npz file.
            bucket_reso (Tuple[int, int]): The resolution of the bucket.

        Returns:
            Tuple[
                Optional[np.ndarray],
                Optional[List[int]],
                Optional[List[int]],
                Optional[np.ndarray],
                Optional[np.ndarray],
            ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
        """
        return self._default_load_latents_from_disk(
            8, npz_path, bucket_reso
        )  # support multi-resolution

    # TODO remove circular dependency for ImageInfo
    def cache_batch_latents(
        self,
        model,
        batch: List,
        flip_aug: bool,
        alpha_mask: bool,
        random_crop: bool,
    ):
        encode_by_vae = lambda img_tensor: model.encode(img_tensor).to("cpu")
        vae_device = model.device
        vae_dtype = model.dtype

        self._default_cache_batch_latents(
            encode_by_vae,
            vae_device,
            vae_dtype,
            batch,
            flip_aug,
            alpha_mask,
            random_crop,
            multi_resolution=True,
        )

        if not train_util.HIGH_VRAM:
            train_util.clean_memory_on_device(model.device)
@@ -4,8 +4,6 @@ from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTokenizer
|
||||
from library import train_util
|
||||
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -13,6 +11,8 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import train_util, utils
|
||||
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
|
||||
|
||||
TOKENIZER_ID = "openai/clip-vit-large-patch14"
|
||||
V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
|
||||
@@ -134,33 +134,30 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
    # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
    # and we keep the old npz for the backward compatibility.

    SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
    SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
    SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"
    ARCHITECTURE_SD = "sd"
    ARCHITECTURE_SDXL = "sdxl"

    def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
        arch = SdSdxlLatentsCachingStrategy.ARCHITECTURE_SD if sd else SdSdxlLatentsCachingStrategy.ARCHITECTURE_SDXL
        super().__init__(arch, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
        self.sd = sd
        self.suffix = (
            SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
        )

    @property
    def cache_suffix(self) -> str:
        return self.suffix

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        # support old .npz
        old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
        if os.path.exists(old_npz_file):
            return old_npz_file
        return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix

    def is_disk_cached_latents_expected(
        self,
        bucket_reso: Tuple[int, int],
        cache_path: str,
        flip_aug: bool,
        alpha_mask: bool,
        preferred_dtype: Optional[torch.dtype] = None,
    ) -> bool:
        return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)

    def load_latents_from_disk(
        self, cache_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
        return self._default_load_latents_from_disk(cache_path, bucket_reso)

    # TODO remove circular dependency for ImageInfo
    def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
    def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
        encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample()
        vae_device = vae.device
        vae_dtype = vae.dtype
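The cache-file naming above combines a per-architecture suffix with the image size embedded in the filename, falling back to the legacy `.npz` when one already exists. A minimal standalone sketch of that logic (a re-implementation for illustration, not the library function):

```python
import os

SD_OLD_SUFFIX = ".npz"      # legacy cache name, kept for backward compatibility
SD_SUFFIX = "_sd.npz"
SDXL_SUFFIX = "_sdxl.npz"

def get_latents_npz_path(absolute_path: str, image_size: tuple, sd: bool) -> str:
    # prefer an existing legacy cache file if one is present on disk
    old = os.path.splitext(absolute_path)[0] + SD_OLD_SUFFIX
    if os.path.exists(old):
        return old
    # otherwise embed the zero-padded image size plus the architecture suffix
    suffix = SD_SUFFIX if sd else SDXL_SUFFIX
    w, h = image_size
    return os.path.splitext(absolute_path)[0] + f"_{w:04d}x{h:04d}" + suffix

print(get_latents_npz_path("/path/to/img001.png", (1024, 768), sd=False))
# /path/to/img001_1024x0768_sdxl.npz
```

Embedding the resolution in the name lets caches for different bucket resolutions of the same image coexist without collisions.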
@@ -6,10 +6,6 @@ import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel

from library import sd3_utils, train_util
from library import sd3_models
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy

from library.utils import setup_logging

setup_logging()

@@ -17,6 +13,9 @@ import logging

logger = logging.getLogger(__name__)

from library import train_util, utils
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy

CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
@@ -254,7 +253,8 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):

class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
    KEYS = ["lg_out", "t5_out", "lg_pooled"]
    KEYS_MASKED = ["clip_l_attn_mask", "clip_g_attn_mask", "t5_attn_mask"]

    def __init__(
        self,
@@ -262,70 +262,51 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
        batch_size: int,
        skip_disk_cache_validity_check: bool,
        is_partial: bool = False,
        apply_lg_attn_mask: bool = False,
        apply_t5_attn_mask: bool = False,
        max_token_length: int = 256,
        masked: bool = False,
    ) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
        self.apply_lg_attn_mask = apply_lg_attn_mask
        self.apply_t5_attn_mask = apply_t5_attn_mask
        """
        apply_lg_attn_mask and apply_t5_attn_mask must be same
        """
        super().__init__(
            Sd3LatentsCachingStrategy.ARCHITECTURE_SD3,
            cache_to_disk,
            batch_size,
            skip_disk_cache_validity_check,
            max_token_length,
            masked=masked,
            is_partial=is_partial,
        )

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX

    def is_disk_cached_outputs_expected(
        self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
    ) -> bool:
        keys = Sd3TextEncoderOutputsCachingStrategy.KEYS
        if self.masked:
            keys += Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
        return self._default_is_disk_cached_outputs_expected(cache_path, prompts, keys, preferred_dtype)

    def is_disk_cached_outputs_expected(self, npz_path: str):
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        try:
            npz = np.load(npz_path)
            if "lg_out" not in npz:
                return False
            if "lg_pooled" not in npz:
                return False
            if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz:  # necessary even if not used
                return False
            if "apply_lg_attn_mask" not in npz:
                return False
            if "t5_out" not in npz:
                return False
            if "t5_attn_mask" not in npz:
                return False
            npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
            if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
                return False
            if "apply_t5_attn_mask" not in npz:
                return False
            npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
            if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
                return False
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            raise e

        return True

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        data = np.load(npz_path)
        lg_out = data["lg_out"]
        lg_pooled = data["lg_pooled"]
        t5_out = data["t5_out"]

        l_attn_mask = data["clip_l_attn_mask"]
        g_attn_mask = data["clip_g_attn_mask"]
        t5_attn_mask = data["t5_attn_mask"]

        # apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask

    def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
        lg_out, lg_pooled, t5_out = self.load_from_disk_for_keys(
            cache_path, caption_index, Sd3TextEncoderOutputsCachingStrategy.KEYS
        )
        if self.masked:
            l_attn_mask, g_attn_mask, t5_attn_mask = self.load_from_disk_for_keys(
                cache_path, caption_index, Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
            )
        else:
            l_attn_mask = g_attn_mask = t5_attn_mask = None
        return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        text_encoding_strategy: TextEncodingStrategy,
        batch: list[tuple[utils.ImageInfo, int, str]],
    ):
        sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy
        captions = [info.caption for info in infos]
        captions = [caption for _, _, caption in batch]

        tokens_and_masks = tokenize_strategy.tokenize(captions)
        with torch.no_grad():
@@ -334,87 +315,76 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
            tokenize_strategy,
            models,
            tokens_and_masks,
            apply_lg_attn_mask=self.apply_lg_attn_mask,
            apply_t5_attn_mask=self.apply_t5_attn_mask,
            apply_lg_attn_mask=self.masked,
            apply_t5_attn_mask=self.masked,
            enable_dropout=False,
        )

        if lg_out.dtype == torch.bfloat16:
            lg_out = lg_out.float()
        if lg_pooled.dtype == torch.bfloat16:
            lg_pooled = lg_pooled.float()
        if t5_out.dtype == torch.bfloat16:
            t5_out = t5_out.float()
        lg_out = lg_out.cpu()
        lg_pooled = lg_pooled.cpu()
        t5_out = t5_out.cpu()

        lg_out = lg_out.cpu().numpy()
        lg_pooled = lg_pooled.cpu().numpy()
        t5_out = t5_out.cpu().numpy()
        l_attn_mask = tokens_and_masks[3].cpu()
        g_attn_mask = tokens_and_masks[4].cpu()
        t5_attn_mask = tokens_and_masks[5].cpu()

        l_attn_mask = tokens_and_masks[3].cpu().numpy()
        g_attn_mask = tokens_and_masks[4].cpu().numpy()
        t5_attn_mask = tokens_and_masks[5].cpu().numpy()

        for i, info in enumerate(infos):
        keys = Sd3TextEncoderOutputsCachingStrategy.KEYS
        if self.masked:
            keys += Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
        for i, (info, caption_index, caption) in enumerate(batch):
            lg_out_i = lg_out[i]
            t5_out_i = t5_out[i]
            lg_pooled_i = lg_pooled[i]
            l_attn_mask_i = l_attn_mask[i]
            g_attn_mask_i = g_attn_mask[i]
            t5_attn_mask_i = t5_attn_mask[i]
            apply_lg_attn_mask = self.apply_lg_attn_mask
            apply_t5_attn_mask = self.apply_t5_attn_mask

            if self.cache_to_disk:
                np.savez(
                    info.text_encoder_outputs_npz,
                    lg_out=lg_out_i,
                    lg_pooled=lg_pooled_i,
                    t5_out=t5_out_i,
                    clip_l_attn_mask=l_attn_mask_i,
                    clip_g_attn_mask=g_attn_mask_i,
                    t5_attn_mask=t5_attn_mask_i,
                    apply_lg_attn_mask=apply_lg_attn_mask,
                    apply_t5_attn_mask=apply_t5_attn_mask,
                )
                outputs = [lg_out_i, t5_out_i, lg_pooled_i]
                if self.masked:
                    outputs += [l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i]
                self.save_outputs_to_disk(info.text_encoder_outputs_cache_path, caption_index, caption, keys, outputs)
            else:
                # it's fine that attn mask is not None. it's overwritten before calling the model if necessary
                info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
                while len(info.text_encoder_outputs) <= caption_index:
                    info.text_encoder_outputs.append(None)
                info.text_encoder_outputs[caption_index] = [
                    lg_out_i,
                    t5_out_i,
                    lg_pooled_i,
                    l_attn_mask_i,
                    g_attn_mask_i,
                    t5_attn_mask_i,
                ]


class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
    SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
    ARCHITECTURE_SD3 = "sd3"

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
        super().__init__(Sd3LatentsCachingStrategy.ARCHITECTURE_SD3, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)

    @property
    def cache_suffix(self) -> str:
        return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        return (
            os.path.splitext(absolute_path)[0]
            + f"_{image_size[0]:04d}x{image_size[1]:04d}"
            + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
        )

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

    def is_disk_cached_latents_expected(
        self,
        bucket_reso: Tuple[int, int],
        cache_path: str,
        flip_aug: bool,
        alpha_mask: bool,
        preferred_dtype: Optional[torch.dtype] = None,
    ):
        return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)

    def load_latents_from_disk(
        self, npz_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        return self._default_load_latents_from_disk(8, npz_path, bucket_reso)  # support multi-resolution

        self, cache_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
        return self._default_load_latents_from_disk(cache_path, bucket_reso)

    # TODO remove circular dependency for ImageInfo
    def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
    def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
        encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
        vae_device = vae.device
        vae_dtype = vae.dtype

        self._default_cache_batch_latents(
            encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
        )
        self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)

        if not train_util.HIGH_VRAM:
            train_util.clean_memory_on_device(vae.device)
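The old SD3 validity check above reduces to one rule: the disk cache is usable only if every required array is present and the stored mask flags match the current settings. A minimal standalone sketch of that rule (a hypothetical `cache_is_valid` helper, not the library API):

```python
import numpy as np

# the arrays the SD3 cache must contain (masks are stored even when unused)
REQUIRED = ["lg_out", "t5_out", "lg_pooled",
            "clip_l_attn_mask", "clip_g_attn_mask", "t5_attn_mask"]

def cache_is_valid(npz_path: str, apply_lg_attn_mask: bool, apply_t5_attn_mask: bool) -> bool:
    try:
        npz = np.load(npz_path)
    except Exception:
        return False  # missing or unreadable file means "re-encode"
    if any(k not in npz for k in REQUIRED + ["apply_lg_attn_mask", "apply_t5_attn_mask"]):
        return False
    # the stored flags must match the current settings, otherwise the cached
    # outputs were encoded with different masking and cannot be reused
    return (bool(npz["apply_lg_attn_mask"]) == apply_lg_attn_mask
            and bool(npz["apply_t5_attn_mask"]) == apply_t5_attn_mask)
```

Returning `False` on any mismatch is cheap: the caller simply re-encodes and overwrites the stale cache file.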
@@ -4,8 +4,6 @@ from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy

from library.utils import setup_logging

@@ -14,6 +12,8 @@ import logging

logger = logging.getLogger(__name__)

from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
from library import utils

TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

@@ -21,6 +21,9 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

class SdxlTokenizeStrategy(TokenizeStrategy):
    def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
        """
        max_length: maximum length of the input text, **excluding** the special tokens. None or 150 or 225
        """
        self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
        self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
        self.tokenizer2.pad_token_id = 0  # use 0 as pad token for tokenizer2
@@ -220,51 +223,51 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):

class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
    ARCHITECTURE_SDXL = "sdxl"
    KEYS = ["hidden_state1", "hidden_state2", "pool2"]

    def __init__(
        self,
        cache_to_disk: bool,
        batch_size: int,
        batch_size: Optional[int],
        skip_disk_cache_validity_check: bool,
        max_token_length: Optional[int] = None,
        is_partial: bool = False,
        is_weighted: bool = False,
    ) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted)
        """
        max_token_length: maximum length of the input text, **excluding** the special tokens. None or 150 or 225
        """
        max_token_length = max_token_length or 75
        super().__init__(
            SdxlTextEncoderOutputsCachingStrategy.ARCHITECTURE_SDXL,
            cache_to_disk,
            batch_size,
            skip_disk_cache_validity_check,
            is_partial,
            is_weighted,
            max_token_length=max_token_length,
        )

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX

    def is_disk_cached_outputs_expected(
        self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
    ) -> bool:
        # SDXL does not support attn mask
        base_keys = SdxlTextEncoderOutputsCachingStrategy.KEYS
        return self._default_is_disk_cached_outputs_expected(cache_path, prompts, base_keys, preferred_dtype)

    def is_disk_cached_outputs_expected(self, npz_path: str):
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        try:
            npz = np.load(npz_path)
            if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
                return False
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            raise e

        return True

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        data = np.load(npz_path)
        hidden_state1 = data["hidden_state1"]
        hidden_state2 = data["hidden_state2"]
        pool2 = data["pool2"]
        return [hidden_state1, hidden_state2, pool2]

    def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
        return self.load_from_disk_for_keys(cache_path, caption_index, SdxlTextEncoderOutputsCachingStrategy.KEYS)

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        text_encoding_strategy: TextEncodingStrategy,
        batch: list[tuple[utils.ImageInfo, int, str]],
    ):
        sdxl_text_encoding_strategy = text_encoding_strategy  # type: SdxlTextEncodingStrategy
        captions = [info.caption for info in infos]
        captions = [caption for _, _, caption in batch]

        if self.is_weighted:
            tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions)

@@ -279,28 +282,24 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
            tokenize_strategy, models, [tokens1, tokens2]
        )

        if hidden_state1.dtype == torch.bfloat16:
            hidden_state1 = hidden_state1.float()
        if hidden_state2.dtype == torch.bfloat16:
            hidden_state2 = hidden_state2.float()
        if pool2.dtype == torch.bfloat16:
            pool2 = pool2.float()
        hidden_state1 = hidden_state1.cpu()
        hidden_state2 = hidden_state2.cpu()
        pool2 = pool2.cpu()

        hidden_state1 = hidden_state1.cpu().numpy()
        hidden_state2 = hidden_state2.cpu().numpy()
        pool2 = pool2.cpu().numpy()

        for i, info in enumerate(infos):
        for i, (info, caption_index, caption) in enumerate(batch):
            hidden_state1_i = hidden_state1[i]
            hidden_state2_i = hidden_state2[i]
            pool2_i = pool2[i]

            if self.cache_to_disk:
                np.savez(
                    info.text_encoder_outputs_npz,
                    hidden_state1=hidden_state1_i,
                    hidden_state2=hidden_state2_i,
                    pool2=pool2_i,
                self.save_outputs_to_disk(
                    info.text_encoder_outputs_cache_path,
                    caption_index,
                    caption,
                    SdxlTextEncoderOutputsCachingStrategy.KEYS,
                    [hidden_state1_i, hidden_state2_i, pool2_i],
                )
            else:
                info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]
                while len(info.text_encoder_outputs) <= caption_index:
                    info.text_encoder_outputs.append(None)
                info.text_encoder_outputs[caption_index] = [hidden_state1_i, hidden_state2_i, pool2_i]
File diff suppressed because it is too large (Load Diff)

library/utils.py (184 changes)
@@ -16,10 +16,67 @@ from PIL import Image
import numpy as np
from safetensors.torch import load_file


def fire_in_thread(f, *args, **kwargs):
    threading.Thread(target=f, args=args, kwargs=kwargs).start()


class ImageInfo:
    def __init__(self, image_key: str, num_repeats: int, is_reg: bool, absolute_path: str) -> None:
        self.image_key: str = image_key
        self.num_repeats: int = num_repeats
        self.captions: Optional[list[str]] = None
        self.caption_weights: Optional[list[float]] = None  # weights for each caption in sampling
        self.list_of_tags: Optional[list[str]] = None
        self.tags_weights: Optional[list[float]] = None
        self.is_reg: bool = is_reg
        self.absolute_path: str = absolute_path
        self.latents_cache_dir: Optional[str] = None
        self.image_size: Tuple[int, int] = None
        self.resized_size: Tuple[int, int] = None
        self.bucket_reso: Tuple[int, int] = None
        self.latents: Optional[torch.Tensor] = None
        self.latents_flipped: Optional[torch.Tensor] = None
        self.latents_cache_path: Optional[str] = None  # set in cache_latents
        self.latents_original_size: Optional[Tuple[int, int]] = None  # original image size, not latents size
        # crop left top right bottom in original pixel size, not latents size
        self.latents_crop_ltrb: Optional[Tuple[int, int]] = None
        self.cond_img_path: Optional[str] = None
        self.image: Optional[Image.Image] = None  # optional, original PIL Image. None if the latents are not cached
        self.text_encoder_outputs_cache_path: Optional[str] = None  # set in cache_text_encoder_outputs

        # new
        self.text_encoder_outputs: Optional[list[list[torch.Tensor]]] = None
        # old
        self.text_encoder_outputs1: Optional[torch.Tensor] = None
        self.text_encoder_outputs2: Optional[torch.Tensor] = None
        self.text_encoder_pool2: Optional[torch.Tensor] = None

        self.alpha_mask: Optional[torch.Tensor] = None  # alpha mask can be flipped at runtime

    def __str__(self) -> str:
        return f"ImageInfo(image_key={self.image_key}, num_repeats={self.num_repeats}, captions={self.captions}, is_reg={self.is_reg}, absolute_path={self.absolute_path})"

    def set_dreambooth_info(self, list_of_tags: list[str]) -> None:
        self.list_of_tags = list_of_tags

    def set_fine_tuning_info(
        self,
        captions: Optional[list[str]],
        caption_weights: Optional[list[float]],
        list_of_tags: Optional[list[str]],
        tags_weights: Optional[list[float]],
        image_size: Tuple[int, int],
        latents_cache_dir: Optional[str],
    ):
        self.captions = captions
        self.caption_weights = caption_weights
        self.list_of_tags = list_of_tags
        self.tags_weights = tags_weights
        self.image_size = image_size
        self.latents_cache_dir = latents_cache_dir


# region Logging
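`ImageInfo.text_encoder_outputs` now holds one cached entry per caption, and elsewhere in this change it is grown with `None` placeholders until `caption_index` fits before assignment. That sparse-list idiom, sketched standalone (hypothetical `set_at_index` helper, not the library API):

```python
def set_at_index(outputs: list, caption_index: int, value) -> list:
    # pad with None so captions can be cached out of order at any index
    while len(outputs) <= caption_index:
        outputs.append(None)
    outputs[caption_index] = value
    return outputs

cache = []
set_at_index(cache, 2, "te-outputs-for-caption-2")
print(cache)  # [None, None, 'te-outputs-for-caption-2']
```

The `None` slots mark captions whose encoder outputs have not been computed yet, so partial caching and later backfilling both work with plain list indexing.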
@@ -88,8 +145,6 @@ def setup_logging(args=None, log_level=None, reset=False):
        logger = logging.getLogger(__name__)
        logger.info(msg_init)

setup_logging()
logger = logging.getLogger(__name__)

# endregion
@@ -190,6 +245,15 @@ def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None)
    raise ValueError(f"Unsupported dtype: {s}")


def dtype_to_normalized_str(dtype: Union[str, torch.dtype]) -> str:
    dtype = str_to_dtype(dtype) if isinstance(dtype, str) else dtype

    # get name of the dtype
    dtype_name = str(dtype).split(".")[-1]

    return dtype_name
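The helper above normalizes a dtype to its bare name by splitting the repr on `.`; `str(torch.bfloat16)` is `"torch.bfloat16"`, so the last segment is `"bfloat16"`. A runnable sketch (the `FakeDtype` class stands in for `torch.dtype` so the example runs without torch installed):

```python
class FakeDtype:
    """Mimics torch dtypes, whose str() looks like 'torch.bfloat16'."""
    def __init__(self, name: str) -> None:
        self._name = name
    def __str__(self) -> str:
        return f"torch.{self._name}"

def dtype_to_normalized_str(dtype) -> str:
    # take the segment after the last dot: "torch.bfloat16" -> "bfloat16"
    return str(dtype).split(".")[-1]

print(dtype_to_normalized_str(FakeDtype("bfloat16")))  # bfloat16
print(dtype_to_normalized_str(FakeDtype("float32")))   # float32
```

A normalized name like this is convenient as a stable key, for example when recording the dtype a cache file was written with.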
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
    """
    memory efficient save file
@@ -262,6 +326,7 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:


class MemoryEfficientSafeOpen:
    # does not support metadata loading
    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")

@@ -379,7 +444,7 @@ def load_safetensors(
# region Image utils


def pil_resize(image, size, interpolation):
def pil_resize(image, size, interpolation=Image.LANCZOS):
    has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False

    if has_alpha:

@@ -387,7 +452,7 @@ def pil_resize(image, size, interpolation):
    else:
        pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    resized_pil = pil_image.resize(size, resample=interpolation)
    resized_pil = pil_image.resize(size, interpolation)

    # Convert back to cv2 format
    if has_alpha:
@@ -398,117 +463,6 @@ def pil_resize(image, size, interpolation):
    return resized_cv2


def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None):
    """
    Resize an image with the given interpolation. Interpolation defaults to AREA when the image is being downscaled, else LANCZOS.

    Args:
        image: numpy.ndarray
        width: int Original image width
        height: int Original image height
        resized_width: int Resized image width
        resized_height: int Resized image height
        resize_interpolation: Optional[str] Resize interpolation method: "lanczos", "area", "bilinear", "bicubic", "nearest", "box"

    Returns:
        image
    """

    # Ensure all size parameters are actual integers
    width = int(width)
    height = int(height)
    resized_width = int(resized_width)
    resized_height = int(resized_height)

    if resize_interpolation is None:
        if width >= resized_width and height >= resized_height:
            resize_interpolation = "area"
        else:
            resize_interpolation = "lanczos"

    # we use PIL for lanczos (for backward compatibility) and box, cv2 for others
    use_pil = resize_interpolation in ["lanczos", "lanczos4", "box"]

    resized_size = (resized_width, resized_height)
    if use_pil:
        interpolation = get_pil_interpolation(resize_interpolation)
        image = pil_resize(image, resized_size, interpolation=interpolation)
        logger.debug(f"resize image using {resize_interpolation} (PIL)")
    else:
        interpolation = get_cv2_interpolation(resize_interpolation)
        image = cv2.resize(image, resized_size, interpolation=interpolation)
        logger.debug(f"resize image using {resize_interpolation} (cv2)")

    return image


def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
    """
    Convert an interpolation name to the cv2 interpolation constant

    https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
    """
    if interpolation is None:
        return None

    if interpolation == "lanczos" or interpolation == "lanczos4":
        # Lanczos interpolation over an 8x8 neighborhood
        return cv2.INTER_LANCZOS4
    elif interpolation == "nearest":
        # Bit-exact nearest neighbor interpolation. This produces the same results as the nearest neighbor method in PIL, scikit-image, or Matlab.
        return cv2.INTER_NEAREST_EXACT
    elif interpolation == "bilinear" or interpolation == "linear":
        # bilinear interpolation
        return cv2.INTER_LINEAR
    elif interpolation == "bicubic" or interpolation == "cubic":
        # bicubic interpolation
        return cv2.INTER_CUBIC
    elif interpolation == "area":
        # Resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire-free results. When the image is zoomed, it behaves like INTER_NEAREST.
        return cv2.INTER_AREA
    elif interpolation == "box":
        # Same as "area": resampling using pixel area relation.
        return cv2.INTER_AREA
    else:
        return None


def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
    """
    Convert an interpolation name to the PIL resampling filter

    https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters
    """
    if interpolation is None:
        return None

    if interpolation == "lanczos":
        return Image.Resampling.LANCZOS
    elif interpolation == "nearest":
        # Pick the nearest pixel from the input image; ignore all other input pixels.
        return Image.Resampling.NEAREST
    elif interpolation == "bilinear" or interpolation == "linear":
        # For resize, calculate the output pixel value using linear interpolation on all pixels that may contribute to the output value.
        return Image.Resampling.BILINEAR
    elif interpolation == "bicubic" or interpolation == "cubic":
        # For resize, calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value.
        return Image.Resampling.BICUBIC
    elif interpolation == "area":
        # Image.Resampling.BOX may be more appropriate if upscaling.
        # Area interpolation is related to cv2.INTER_AREA.
        # Produces a sharper image than Resampling.BILINEAR and does not have local dislocations like Resampling.BOX.
        return Image.Resampling.HAMMING
    elif interpolation == "box":
        # Each pixel of the source image contributes to one pixel of the destination image with identical weights. For upscaling it is equivalent to Resampling.NEAREST.
        return Image.Resampling.BOX
    else:
        return None


def validate_interpolation_fn(interpolation_str: str) -> bool:
    """
    Check if an interpolation function is supported
    """
    return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]


# endregion

# TODO make inf_utils.py
@@ -1,415 +0,0 @@
|
||||
# Minimum Inference Code for Lumina
|
||||
# Based on flux_minimal_inference.py
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from tqdm import tqdm
|
||||
from transformers import Gemma2Model
|
||||
from library.flux_models import AutoEncoder
|
||||
|
||||
from library import (
|
||||
device_utils,
|
||||
lumina_models,
|
||||
lumina_train_util,
|
||||
lumina_util,
|
||||
sd3_train_utils,
|
||||
strategy_lumina,
|
||||
)
|
||||
import networks.lora_lumina as lora_lumina
|
||||
from library.device_utils import get_preferred_device, init_ipex
|
||||
from library.utils import setup_logging, str_to_dtype
|
||||
|
||||
init_ipex()
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_image(
|
||||
model: lumina_models.NextDiT,
|
||||
gemma2: Gemma2Model,
|
||||
ae: AutoEncoder,
|
||||
prompt: str,
|
||||
system_prompt: str,
|
||||
seed: Optional[int],
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
negative_prompt: Optional[str],
|
||||
args,
|
||||
cfg_trunc_ratio: float = 0.25,
|
||||
renorm_cfg: float = 1.0,
|
||||
):
|
||||
    #
    # 0. Prepare arguments
    #
    device = get_preferred_device()
    if args.device:
        device = torch.device(args.device)

    dtype = str_to_dtype(args.dtype)
    ae_dtype = str_to_dtype(args.ae_dtype)
    gemma2_dtype = str_to_dtype(args.gemma2_dtype)

    #
    # 1. Prepare models
    #
    # model.to(device, dtype=dtype)
    model.to(dtype)
    model.eval()

    gemma2.to(device, dtype=gemma2_dtype)
    gemma2.eval()

    ae.to(ae_dtype)
    ae.eval()

    #
    # 2. Encode prompts
    #
    logger.info("Encoding prompts...")

    tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy(system_prompt, args.gemma2_max_token_length)
    encoding_strategy = strategy_lumina.LuminaTextEncodingStrategy()

    tokens_and_masks = tokenize_strategy.tokenize(prompt)
    with torch.no_grad():
        gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks)

    tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True)
    with torch.no_grad():
        neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks)

    # Unpack Gemma2 outputs
    prompt_hidden_states, _, prompt_attention_mask = gemma2_conds
    uncond_hidden_states, _, uncond_attention_mask = neg_gemma2_conds

    if args.offload:
        logger.info("Offloading models to CPU to save VRAM...")
        gemma2.to("cpu")
        device_utils.clean_memory()

    model.to(device)

    #
    # 3. Prepare latents
    #
    seed = seed if seed is not None else random.randint(0, 2**32 - 1)
    logger.info(f"Seed: {seed}")
    torch.manual_seed(seed)

    latent_height = image_height // 8
    latent_width = image_width // 8
    latent_channels = 16

    latents = torch.randn(
        (1, latent_channels, latent_height, latent_width),
        device=device,
        dtype=dtype,
        generator=torch.Generator(device=device).manual_seed(seed),
    )

    #
    # 4. Denoise
    #
    logger.info("Denoising...")
    scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
    scheduler.set_timesteps(steps, device=device)
    timesteps = scheduler.timesteps

    # # compare with lumina_train_util.retrieve_timesteps
    # lumina_timestep = lumina_train_util.retrieve_timesteps(scheduler, num_inference_steps=steps)
    # print(f"Using timesteps: {timesteps}")
    # print(f"vs Lumina timesteps: {lumina_timestep}")  # should be the same

    with torch.autocast(device_type=device.type, dtype=dtype), torch.no_grad():
        latents = lumina_train_util.denoise(
            scheduler,
            model,
            latents.to(device),
            prompt_hidden_states.to(device),
            prompt_attention_mask.to(device),
            uncond_hidden_states.to(device),
            uncond_attention_mask.to(device),
            timesteps,
            guidance_scale,
            cfg_trunc_ratio,
            renorm_cfg,
        )

    if args.offload:
        model.to("cpu")
        device_utils.clean_memory()
    ae.to(device)

    #
    # 5. Decode latents
    #
    logger.info("Decoding image...")
    latents = latents / ae.scale_factor + ae.shift_factor
    with torch.no_grad():
        image = ae.decode(latents.to(ae_dtype))
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    image = (image * 255).round().astype("uint8")

    #
    # 6. Save image
    #
    pil_image = Image.fromarray(image[0])
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
    seed_suffix = f"_{seed}"
    output_path = os.path.join(output_dir, f"image_{ts_str}{seed_suffix}.png")
    pil_image.save(output_path)
    logger.info(f"Image saved to {output_path}")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Lumina DiT model path / Lumina DiTモデルのパス",
    )
    parser.add_argument(
        "--gemma2_path",
        type=str,
        default=None,
        required=True,
        help="Gemma2 model path / Gemma2モデルのパス",
    )
    parser.add_argument(
        "--ae_path",
        type=str,
        default=None,
        required=True,
        help="Autoencoder model path / Autoencoderモデルのパス",
    )
    parser.add_argument("--prompt", type=str, default="A beautiful sunset over the mountains", help="Prompt for image generation")
    parser.add_argument("--negative_prompt", type=str, default="", help="Negative prompt for image generation, default is empty")
    parser.add_argument("--output_dir", type=str, default="outputs", help="Output directory for generated images")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    parser.add_argument("--steps", type=int, default=36, help="Number of inference steps")
    parser.add_argument("--guidance_scale", type=float, default=3.5, help="Guidance scale for classifier-free guidance")
    parser.add_argument("--image_width", type=int, default=1024, help="Image width")
    parser.add_argument("--image_height", type=int, default=1024, help="Image height")
    parser.add_argument("--dtype", type=str, default="bf16", help="Data type for model (bf16, fp16, float)")
    parser.add_argument("--gemma2_dtype", type=str, default="bf16", help="Data type for Gemma2 (bf16, fp16, float)")
    parser.add_argument("--ae_dtype", type=str, default="bf16", help="Data type for Autoencoder (bf16, fp16, float)")
    parser.add_argument("--device", type=str, default=None, help="Device to use (e.g., 'cuda:0')")
    parser.add_argument("--offload", action="store_true", help="Offload models to CPU to save VRAM")
    parser.add_argument("--system_prompt", type=str, default="", help="System prompt for Gemma2 model")
    parser.add_argument(
        "--gemma2_max_token_length",
        type=int,
        default=256,
        help="Max token length for Gemma2 tokenizer",
    )
    parser.add_argument(
        "--discrete_flow_shift",
        type=float,
        default=6.0,
        help="Shift value for FlowMatchEulerDiscreteScheduler",
    )
    parser.add_argument(
        "--cfg_trunc_ratio",
        type=float,
        default=0.25,
        help="Ratio of denoising steps for which CFG is applied; guidance is truncated afterwards",
    )
    parser.add_argument(
        "--renorm_cfg",
        type=float,
        default=1.0,
        help="Renormalize the CFG output when its norm exceeds this multiple of the conditional prediction norm",
    )
    parser.add_argument(
        "--use_flash_attn",
        action="store_true",
        help="Use flash attention for Lumina model",
    )
    parser.add_argument(
        "--use_sage_attn",
        action="store_true",
        help="Use sage attention for Lumina model",
    )
    parser.add_argument(
        "--lora_weights",
        type=str,
        nargs="*",
        default=[],
        help="LoRA weights, each argument is a `path;multiplier` (semi-colon separated)",
    )
    parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Enable interactive mode for generating multiple images / 対話モードで複数の画像を生成する",
    )
    return parser


if __name__ == "__main__":
    parser = setup_parser()
    args = parser.parse_args()

    logger.info("Loading models...")
    device = get_preferred_device()
    if args.device:
        device = torch.device(args.device)

    # Load Lumina DiT model
    model = lumina_util.load_lumina_model(
        args.pretrained_model_name_or_path,
        dtype=None,  # Load in fp32 and then convert
        device="cpu",
        use_flash_attn=args.use_flash_attn,
        use_sage_attn=args.use_sage_attn,
    )

    # Load Gemma2
    gemma2 = lumina_util.load_gemma2(args.gemma2_path, dtype=None, device="cpu")

    # Load Autoencoder
    ae = lumina_util.load_ae(args.ae_path, dtype=None, device="cpu")

    # LoRA
    lora_models = []
    for weights_file in args.lora_weights:
        if ";" in weights_file:
            weights_file, multiplier = weights_file.split(";")
            multiplier = float(multiplier)
        else:
            multiplier = 1.0

        weights_sd = load_file(weights_file)
        lora_model, _ = lora_lumina.create_network_from_weights(multiplier, None, ae, [gemma2], model, weights_sd, True)

        if args.merge_lora_weights:
            lora_model.merge_to([gemma2], model, weights_sd)
        else:
            lora_model.apply_to([gemma2], model)
            info = lora_model.load_state_dict(weights_sd, strict=True)
            logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
            lora_model.to(device)
            lora_model.set_multiplier(multiplier)
            lora_model.eval()

        lora_models.append(lora_model)

    if not args.interactive:
        generate_image(
            model,
            gemma2,
            ae,
            args.prompt,
            args.system_prompt,
            args.seed,
            args.image_width,
            args.image_height,
            args.steps,
            args.guidance_scale,
            args.negative_prompt,
            args,
            args.cfg_trunc_ratio,
            args.renorm_cfg,
        )
    else:
        # Interactive mode loop
        image_width = args.image_width
        image_height = args.image_height
        steps = args.steps
        guidance_scale = args.guidance_scale
        cfg_trunc_ratio = args.cfg_trunc_ratio
        renorm_cfg = args.renorm_cfg

        print("Entering interactive mode.")
        while True:
            print(
                "\nEnter prompt (or 'exit'). Options: --w <int> --h <int> --s <int> --d <int> --g <float> --n <str> --ctr <float> --rcfg <float> --m <m1,m2...>"
            )
            user_input = input()
            if user_input.lower() == "exit":
                break
            if not user_input:
                continue

            # Parse options
            options = user_input.split("--")
            prompt = options[0].strip()

            # Set defaults for each generation
            seed = None  # New random seed each time unless specified
            negative_prompt = args.negative_prompt  # Reset to default

            for opt in options[1:]:
                try:
                    opt = opt.strip()
                    if not opt:
                        continue

                    key, value = (opt.split(None, 1) + [""])[:2]

                    if key == "w":
                        image_width = int(value)
                    elif key == "h":
                        image_height = int(value)
                    elif key == "s":
                        steps = int(value)
                    elif key == "d":
                        seed = int(value)
                    elif key == "g":
                        guidance_scale = float(value)
                    elif key == "n":
                        negative_prompt = value if value != "-" else ""
                    elif key == "ctr":
                        cfg_trunc_ratio = float(value)
                    elif key == "rcfg":
                        renorm_cfg = float(value)
                    elif key == "m":
                        multipliers = value.split(",")
                        if len(multipliers) != len(lora_models):
                            logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
                            continue
                        for i, lora_model in enumerate(lora_models):
                            lora_model.set_multiplier(float(multipliers[i].strip()))
                    else:
                        logger.warning(f"Unknown option: --{key}")

                except (ValueError, IndexError) as e:
                    logger.error(f"Invalid value for option --{key}: '{value}'. Error: {e}")

            generate_image(
                model,
                gemma2,
                ae,
                prompt,
                args.system_prompt,
                seed,
                image_width,
                image_height,
                steps,
                guidance_scale,
                negative_prompt,
                args,
                cfg_trunc_ratio,
                renorm_cfg,
            )

    logger.info("Done.")
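The interactive loop above splits a prompt line such as `a cat --w 512 --s 20` on `--` and dispatches on single-letter keys. A minimal standalone sketch of that parsing rule for a subset of the options (the helper name `parse_interactive_input` is hypothetical, not part of the repo):

```python
def parse_interactive_input(user_input, defaults):
    """Split a line like 'a cat --w 512 --s 20' into the prompt text and an
    options dict, falling back to the given defaults for unspecified keys."""
    parts = user_input.split("--")
    prompt = parts[0].strip()
    opts = dict(defaults)
    for opt in parts[1:]:
        opt = opt.strip()
        if not opt:
            continue
        # first whitespace-separated token is the key, the rest is the value
        key, value = (opt.split(None, 1) + [""])[:2]
        if key == "w":
            opts["width"] = int(value)
        elif key == "h":
            opts["height"] = int(value)
        elif key == "s":
            opts["steps"] = int(value)
        elif key == "d":
            opts["seed"] = int(value)
        elif key == "g":
            opts["guidance_scale"] = float(value)
        elif key == "n":
            # '-' clears the negative prompt
            opts["negative_prompt"] = "" if value == "-" else value
    return prompt, opts
```

This mirrors the `(opt.split(None, 1) + [""])[:2]` idiom used in the script, which tolerates value-less keys by padding with an empty string.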
lumina_train.py
@@ -1,953 +0,0 @@
# training with captions

# Swap blocks between CPU and GPU:
# This implementation is inspired by and based on the work of 2kpr.
# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading.
# The original idea has been adapted and extended to fit the current project's needs.

# Key features:
# - CPU offloading during forward and backward passes
# - Use of fused optimizer and grad_hook for efficient gradient processing
# - Per-block fused optimizer instances

import argparse
import copy
import math
import os
from multiprocessing import Value
import toml

from tqdm import tqdm

import torch
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from accelerate.utils import set_seed
from library import (
    deepspeed_utils,
    lumina_train_util,
    lumina_util,
    strategy_base,
    strategy_lumina,
)
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler

import library.train_util as train_util

from library.utils import setup_logging, add_logging_arguments

setup_logging()
import logging

logger = logging.getLogger(__name__)

import library.config_util as config_util

# import library.sdxl_train_util as sdxl_train_util
from library.config_util import (
    ConfigSanitizer,
    BlueprintGenerator,
)
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments


def train(args):
    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, True)
    # sdxl_train_util.verify_sdxl_training_args(args)
    deepspeed_utils.prepare_deepspeed_args(args)
    setup_logging(args, reset=True)

    # temporary: backward compatibility for deprecated options. remove in the future
    if not args.skip_cache_check:
        args.skip_cache_check = args.skip_latents_validity_check

    # assert (
    #     not args.weighted_captions
    # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
    if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
        logger.warning(
            "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
        )
        args.cache_text_encoder_outputs = True

    if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
        logger.warning(
            "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります"
        )
        args.gradient_checkpointing = True

    # assert (
    #     args.blocks_to_swap is None or args.blocks_to_swap == 0
    # ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"

    cache_latents = args.cache_latents
    use_dreambooth_method = args.in_json is None

    if args.seed is not None:
        set_seed(args.seed)  # initialize random seed

    # prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy for initialization
    if args.cache_latents:
        latents_caching_strategy = strategy_lumina.LuminaLatentsCachingStrategy(
            args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
        )
        strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

    # prepare dataset
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(
            ConfigSanitizer(True, True, args.masked_loss, True)
        )
        if args.dataset_config is not None:
            logger.info(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
                logger.warning(
                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                        ", ".join(ignored)
                    )
                )
        else:
            if use_dreambooth_method:
                logger.info("Using DreamBooth method.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
                                args.train_data_dir, args.reg_data_dir
                            )
                        }
                    ]
                }
            else:
                logger.info("Training with captions.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": [
                                {
                                    "image_dir": args.train_data_dir,
                                    "metadata_file": args.in_json,
                                }
                            ]
                        }
                    ]
                }

        blueprint = blueprint_generator.generate(user_config, args)
        train_dataset_group, val_dataset_group = (
            config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
        )
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args)
        val_dataset_group = None

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collator = (
        train_dataset_group if args.max_data_loader_n_workers == 0 else None
    )
    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

    train_dataset_group.verify_bucket_reso_steps(16)  # TODO: confirm this is correct

    if args.debug_dataset:
        if args.cache_text_encoder_outputs:
            strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
                strategy_lumina.LuminaTextEncoderOutputsCachingStrategy(
                    args.cache_text_encoder_outputs_to_disk,
                    args.text_encoder_batch_size,
                    args.skip_cache_check,
                    False,
                )
            )
        strategy_base.TokenizeStrategy.set_strategy(
            strategy_lumina.LuminaTokenizeStrategy(args.system_prompt)
        )

        train_dataset_group.set_current_strategies()
        train_util.debug_dataset(train_dataset_group, True)
        return
    if len(train_dataset_group) == 0:
        logger.error(
            "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
        )
        return

    if cache_latents:
        assert (
            train_dataset_group.is_latent_cacheable()
        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

    if args.cache_text_encoder_outputs:
        assert (
            train_dataset_group.is_text_encoder_output_cacheable()
        ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"

    # prepare accelerator
    logger.info("prepare accelerator")
    accelerator = train_util.prepare_accelerator(args)

    # prepare dtypes for mixed precision and cast as needed
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    # load models

    # load VAE for caching latents
    ae = None
    if cache_latents:
        ae = lumina_util.load_ae(
            args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors
        )
        ae.to(accelerator.device, dtype=weight_dtype)
        ae.requires_grad_(False)
        ae.eval()

        train_dataset_group.new_cache_latents(ae, accelerator)

        ae.to("cpu")  # if no sampling, vae can be deleted
        clean_memory_on_device(accelerator.device)

        accelerator.wait_for_everyone()

    # prepare tokenize strategy
    if args.gemma2_max_token_length is None:
        gemma2_max_token_length = 256
    else:
        gemma2_max_token_length = args.gemma2_max_token_length

    lumina_tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy(
        args.system_prompt, gemma2_max_token_length
    )
    strategy_base.TokenizeStrategy.set_strategy(lumina_tokenize_strategy)

    # load gemma2 for caching text encoder outputs
    gemma2 = lumina_util.load_gemma2(
        args.gemma2, weight_dtype, "cpu", args.disable_mmap_load_safetensors
    )
    gemma2.eval()
    gemma2.requires_grad_(False)

    text_encoding_strategy = strategy_lumina.LuminaTextEncodingStrategy()
    strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)

    # cache text encoder outputs
    sample_prompts_te_outputs = None
    if args.cache_text_encoder_outputs:
        # text encoders are in eval mode with no grad here
        gemma2.to(accelerator.device)

        text_encoder_caching_strategy = (
            strategy_lumina.LuminaTextEncoderOutputsCachingStrategy(
                args.cache_text_encoder_outputs_to_disk,
                args.text_encoder_batch_size,
                False,
                False,
            )
        )
        strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
            text_encoder_caching_strategy
        )

        with accelerator.autocast():
            train_dataset_group.new_cache_text_encoder_outputs([gemma2], accelerator)

        # cache sample prompts' embeddings to free text encoder memory
        if args.sample_prompts is not None:
            logger.info(
                f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}"
            )

            text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = (
                strategy_base.TextEncodingStrategy.get_strategy()
            )

            prompts = train_util.load_prompts(args.sample_prompts)
            sample_prompts_te_outputs = {}  # key: prompt, value: text encoder outputs
            with accelerator.autocast(), torch.no_grad():
                for prompt_dict in prompts:
                    for i, p in enumerate([
                        prompt_dict.get("prompt", ""),
                        prompt_dict.get("negative_prompt", ""),
                    ]):
                        if p not in sample_prompts_te_outputs:
                            logger.info(f"cache Text Encoder outputs for prompt: {p}")
                            tokens_and_masks = lumina_tokenize_strategy.tokenize(p, i == 1)  # i == 1 means negative prompt
                            sample_prompts_te_outputs[p] = (
                                text_encoding_strategy.encode_tokens(
                                    lumina_tokenize_strategy,
                                    [gemma2],
                                    tokens_and_masks,
                                )
                            )

        accelerator.wait_for_everyone()

        # now we can delete the text encoder to free memory
        gemma2 = None
        clean_memory_on_device(accelerator.device)

    # load lumina
    loading_dtype = weight_dtype  # assumed: load model weights directly in the training dtype
    nextdit = lumina_util.load_lumina_model(
        args.pretrained_model_name_or_path,
        loading_dtype,
        torch.device("cpu"),
        disable_mmap=args.disable_mmap_load_safetensors,
        use_flash_attn=args.use_flash_attn,
    )

    if args.gradient_checkpointing:
        nextdit.enable_gradient_checkpointing(
            cpu_offload=args.cpu_offload_checkpointing
        )

    nextdit.requires_grad_(True)

    # block swap
    is_swapping_blocks = False  # block swap is not supported yet; see the commented-out code below

    # backward compatibility
    # if args.blocks_to_swap is None:
    #     blocks_to_swap = args.double_blocks_to_swap or 0
    #     if args.single_blocks_to_swap is not None:
    #         blocks_to_swap += args.single_blocks_to_swap // 2
    #     if blocks_to_swap > 0:
    #         logger.warning(
    #             "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead."
    #             " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。"
    #         )
    #         logger.info(
    #             f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}."
    #         )
    #         args.blocks_to_swap = blocks_to_swap
    #     del blocks_to_swap

    # is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
    # if is_swapping_blocks:
    #     # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
    #     # This idea is based on 2kpr's great work. Thank you!
    #     logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
    #     flux.enable_block_swap(args.blocks_to_swap, accelerator.device)

    if not cache_latents:
        # load VAE here if not cached
        ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu")
        ae.requires_grad_(False)
        ae.eval()
        ae.to(accelerator.device, dtype=weight_dtype)

    training_models = []
    params_to_optimize = []
    training_models.append(nextdit)
    name_and_params = list(nextdit.named_parameters())
    # single param group for now
    params_to_optimize.append(
        {"params": [p for _, p in name_and_params], "lr": args.learning_rate}
    )
    param_names = [[n for n, _ in name_and_params]]

    # calculate number of trainable parameters
    n_params = 0
    for group in params_to_optimize:
        for p in group["params"]:
            n_params += p.numel()

    accelerator.print(f"number of trainable parameters: {n_params}")

    # prepare classes required for training
    accelerator.print("prepare optimizer, data loader etc.")

    if args.blockwise_fused_optimizers:
        # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
        # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
        # This balances memory usage and management complexity.

        # split params into groups. currently different learning rates are not supported
        grouped_params = []
        param_group = {}
        for group in params_to_optimize:
            named_parameters = list(nextdit.named_parameters())
            assert len(named_parameters) == len(
                group["params"]
            ), "number of parameters does not match"
            for p, np in zip(group["params"], named_parameters):
                # determine target layer and block index for each parameter
                block_type = "other"  # double, single or other
                if np[0].startswith("double_blocks"):
                    block_index = int(np[0].split(".")[1])
                    block_type = "double"
                elif np[0].startswith("single_blocks"):
                    block_index = int(np[0].split(".")[1])
                    block_type = "single"
                else:
                    block_index = -1

                param_group_key = (block_type, block_index)
                if param_group_key not in param_group:
                    param_group[param_group_key] = []
                param_group[param_group_key].append(p)

        block_types_and_indices = []
        for param_group_key, param_group in param_group.items():
            block_types_and_indices.append(param_group_key)
            grouped_params.append({"params": param_group, "lr": args.learning_rate})

            num_params = 0
            for p in param_group:
                num_params += p.numel()
            accelerator.print(f"block {param_group_key}: {num_params} parameters")

        # prepare optimizers for each group
        optimizers = []
        for group in grouped_params:
            _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
            optimizers.append(optimizer)
        optimizer = optimizers[0]  # avoid error in the following code

        logger.info(
            f"using {len(optimizers)} optimizers for blockwise fused optimizers"
        )

        if train_util.is_schedulefree_optimizer(optimizers[0], args):
            raise ValueError(
                "Schedule-free optimizer is not supported with blockwise fused optimizers"
            )
        optimizer_train_fn = lambda: None  # dummy function
        optimizer_eval_fn = lambda: None  # dummy function
    else:
        _, _, optimizer = train_util.get_optimizer(
            args, trainable_params=params_to_optimize
        )
        optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
            optimizer, args
        )

    # prepare dataloader
    # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
    # some strategies can be None
    train_dataset_group.set_current_strategies()

    # number of DataLoader workers: note that persistent_workers cannot be used when this is 0
    n_workers = min(
        args.max_data_loader_n_workers, os.cpu_count()
    )  # cpu_count or max_data_loader_n_workers
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collator,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # calculate the number of training steps
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * math.ceil(
            len(train_dataloader)
            / accelerator.num_processes
            / args.gradient_accumulation_steps
        )
        accelerator.print(
            f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
        )

    # send the number of training steps to the dataset side as well
    train_dataset_group.set_max_train_steps(args.max_train_steps)

    # prepare lr scheduler
    if args.blockwise_fused_optimizers:
        # prepare lr schedulers for each optimizer
        lr_schedulers = [
            train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
            for optimizer in optimizers
        ]
        lr_scheduler = lr_schedulers[0]  # avoid error in the following code
    else:
        lr_scheduler = train_util.get_scheduler_fix(
            args, optimizer, accelerator.num_processes
        )

    # experimental feature: perform fp16/bf16 training including gradients; cast the entire model to fp16/bf16
    if args.full_fp16:
        assert (
            args.mixed_precision == "fp16"
        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
        accelerator.print("enable full fp16 training.")
        nextdit.to(weight_dtype)
        if gemma2 is not None:
            gemma2.to(weight_dtype)
    elif args.full_bf16:
        assert (
            args.mixed_precision == "bf16"
        ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
        accelerator.print("enable full bf16 training.")
        nextdit.to(weight_dtype)
        if gemma2 is not None:
            gemma2.to(weight_dtype)

    # if we don't cache text encoder outputs, move the text encoder to the device
    if not args.cache_text_encoder_outputs:
        gemma2.to(accelerator.device)

    clean_memory_on_device(accelerator.device)

    if args.deepspeed:
        ds_model = deepspeed_utils.prepare_deepspeed_model(args, nextdit=nextdit)
        # most ZeRO stages use optimizer partitioning, so we have to prepare the optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            ds_model, optimizer, train_dataloader, lr_scheduler
        )
        training_models = [ds_model]

    else:
        # accelerator does some magic
        # if we don't swap blocks, we can move the model to the device
        nextdit = accelerator.prepare(
            nextdit, device_placement=[not is_swapping_blocks]
        )
        if is_swapping_blocks:
            accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(
                accelerator.device
            )  # reduce peak memory usage
        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            optimizer, train_dataloader, lr_scheduler
        )

    # experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
    if args.full_fp16:
        # During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine does.
        # -> But we think it's ok to patch the accelerator even if deepspeed is enabled.
        train_util.patch_accelerator_for_fp16_training(accelerator)

    # resume training if specified
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

if args.fused_backward_pass:
|
||||
# use fused optimizer for backward pass: other optimizers will be supported in the future
|
||||
import library.adafactor_fused
|
||||
|
||||
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
||||
|
||||
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
|
||||
for parameter, param_name in zip(param_group["params"], param_name_group):
|
||||
if parameter.requires_grad:
|
||||
|
||||
def create_grad_hook(p_name, p_group):
|
||||
def grad_hook(tensor: torch.Tensor):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
||||
optimizer.step_param(tensor, p_group)
|
||||
tensor.grad = None
|
||||
|
||||
return grad_hook
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(
|
||||
create_grad_hook(param_name, param_group)
|
||||
)
|
||||
|
||||
    elif args.blockwise_fused_optimizers:
        # prepare for additional optimizers and lr schedulers
        for i in range(1, len(optimizers)):
            optimizers[i] = accelerator.prepare(optimizers[i])
            lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])

        # counters are used to determine when to step the optimizer
        global optimizer_hooked_count
        global num_parameters_per_group
        global parameter_optimizer_map

        optimizer_hooked_count = {}
        num_parameters_per_group = [0] * len(optimizers)
        parameter_optimizer_map = {}

        for opt_idx, optimizer in enumerate(optimizers):
            for param_group in optimizer.param_groups:
                for parameter in param_group["params"]:
                    if parameter.requires_grad:

                        def grad_hook(parameter: torch.Tensor):
                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                                accelerator.clip_grad_norm_(parameter, args.max_grad_norm)

                            i = parameter_optimizer_map[parameter]
                            optimizer_hooked_count[i] += 1
                            if optimizer_hooked_count[i] == num_parameters_per_group[i]:
                                optimizers[i].step()
                                optimizers[i].zero_grad(set_to_none=True)

                        parameter.register_post_accumulate_grad_hook(grad_hook)
                        parameter_optimizer_map[parameter] = opt_idx
                        num_parameters_per_group[opt_idx] += 1

    # calculate the number of epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

    # start training
    # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    accelerator.print("running training / 学習開始")
    accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
    accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
    accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
    accelerator.print(
        f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
    )
    # accelerator.print(
    #     f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
    # )
    accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
    accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")

    progress_bar = tqdm(
        range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
    )
    global_step = 0

    noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)

    if accelerator.is_main_process:
        init_kwargs = {}
        if args.wandb_run_name:
            init_kwargs["wandb"] = {"name": args.wandb_run_name}
        if args.log_tracker_config is not None:
            init_kwargs = toml.load(args.log_tracker_config)
        accelerator.init_trackers(
            "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
            config=train_util.get_sanitized_config_or_none(args),
            init_kwargs=init_kwargs,
        )

    if is_swapping_blocks:
        accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward()

    # For --sample_at_first
    optimizer_eval_fn()
    lumina_train_util.sample_images(
        accelerator, args, 0, global_step, nextdit, ae, gemma2, sample_prompts_te_outputs
    )
    optimizer_train_fn()
    if len(accelerator.trackers) > 0:
        # log empty object to commit the sample images to wandb
        accelerator.log({}, step=0)

    loss_recorder = train_util.LossRecorder()
    epoch = 0  # avoid error when max_train_steps is 0
    for epoch in range(num_train_epochs):
        accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
        current_epoch.value = epoch + 1

        for m in training_models:
            m.train()

        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step

            if args.blockwise_fused_optimizers:
                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}  # reset counter for each step

            with accelerator.accumulate(*training_models):
                if "latents" in batch and batch["latents"] is not None:
                    latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
                else:
                    with torch.no_grad():
                        # encode images to latents. images are [-1, 1]
                        latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype)

                # if NaN is found in the latents, show a warning and replace it with zeros
                if torch.any(torch.isnan(latents)):
                    accelerator.print("NaN found in latents, replacing with zeros")
                    latents = torch.nan_to_num(latents, 0, out=latents)

                text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
                if text_encoder_outputs_list is not None:
                    text_encoder_conds = text_encoder_outputs_list
                else:
                    # outputs are not cached, so get them from the text encoders
                    tokens_and_masks = batch["input_ids_list"]
                    with torch.no_grad():
                        input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
                        text_encoder_conds = text_encoding_strategy.encode_tokens(
                            lumina_tokenize_strategy, [gemma2], input_ids
                        )
                        if args.full_fp16:
                            text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]

                # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)

                # get noisy model input and timesteps
                noisy_model_input, timesteps, sigmas = lumina_train_util.get_noisy_model_input_and_timesteps(
                    args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
                )

                # call model
                gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds

                with accelerator.autocast():
                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                    model_pred = nextdit(
                        x=noisy_model_input,  # image latents (B, C, H, W)
                        t=timesteps / 1000,  # timesteps must be divided by 1000 to match the model's expectation
                        cap_feats=gemma2_hidden_states,  # Gemma2 hidden states as caption features
                        cap_mask=gemma2_attn_mask.to(dtype=torch.int32),  # Gemma2 attention mask
                    )

                # apply model prediction type
                model_pred, weighting = lumina_train_util.apply_model_prediction_type(
                    args, model_pred, noisy_model_input, sigmas
                )

                # flow matching loss: this is different from SD3
                target = noise - latents

                # calculate loss
                huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
                loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
                if weighting is not None:
                    loss = loss * weighting
                if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                    loss = apply_masked_loss(loss, batch)
                loss = loss.mean([1, 2, 3])

                loss_weights = batch["loss_weights"]  # weight for each sample
                loss = loss * loss_weights
                loss = loss.mean()

                # backward
                accelerator.backward(loss)

                if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
                    if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                        params_to_clip = []
                        for m in training_models:
                            params_to_clip.extend(m.parameters())
                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
                else:
                    # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
                    lr_scheduler.step()
                    if args.blockwise_fused_optimizers:
                        for i in range(1, len(optimizers)):
                            lr_schedulers[i].step()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                optimizer_eval_fn()
                lumina_train_util.sample_images(
                    accelerator, args, None, global_step, nextdit, ae, gemma2, sample_prompts_te_outputs
                )

                # save the model at the specified step interval
                if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        lumina_train_util.save_lumina_model_on_epoch_end_or_stepwise(
                            args,
                            False,
                            accelerator,
                            save_dtype,
                            epoch,
                            num_train_epochs,
                            global_step,
                            accelerator.unwrap_model(nextdit),
                        )
                optimizer_train_fn()

            current_loss = loss.detach().item()  # this is a mean, so batch size should not matter
            if len(accelerator.trackers) > 0:
                logs = {"loss": current_loss}
                train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)

                accelerator.log(logs, step=global_step)

            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
            avr_loss: float = loss_recorder.moving_average
            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if len(accelerator.trackers) > 0:
            logs = {"loss/epoch": loss_recorder.moving_average}
            accelerator.log(logs, step=epoch + 1)

        accelerator.wait_for_everyone()

        optimizer_eval_fn()
        if args.save_every_n_epochs is not None:
            if accelerator.is_main_process:
                lumina_train_util.save_lumina_model_on_epoch_end_or_stepwise(
                    args,
                    True,
                    accelerator,
                    save_dtype,
                    epoch,
                    num_train_epochs,
                    global_step,
                    accelerator.unwrap_model(nextdit),
                )

        lumina_train_util.sample_images(
            accelerator, args, epoch + 1, global_step, nextdit, ae, gemma2, sample_prompts_te_outputs
        )
        optimizer_train_fn()

    is_main_process = accelerator.is_main_process
    # if is_main_process:
    nextdit = accelerator.unwrap_model(nextdit)

    accelerator.end_training()
    optimizer_eval_fn()

    if args.save_state or args.save_state_on_train_end:
        train_util.save_state_on_train_end(args, accelerator)

    del accelerator  # delete this since memory is used afterwards

    if is_main_process:
        lumina_train_util.save_lumina_model_on_train_end(args, save_dtype, epoch, global_step, nextdit)
        logger.info("model saved.")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    add_logging_arguments(parser)
    train_util.add_sd_models_arguments(parser)  # TODO split this
    train_util.add_dataset_arguments(parser, True, True, True)
    train_util.add_training_arguments(parser, False)
    train_util.add_masked_loss_arguments(parser)
    deepspeed_utils.add_deepspeed_arguments(parser)
    train_util.add_sd_saving_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    add_custom_train_arguments(parser)  # TODO remove this from here
    train_util.add_dit_training_arguments(parser)
    lumina_train_util.add_lumina_train_arguments(parser)

    parser.add_argument(
        "--mem_eff_save",
        action="store_true",
        help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う",
    )

    parser.add_argument(
        "--fused_optimizer_groups",
        type=int,
        default=None,
        help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます",
    )
    parser.add_argument(
        "--blockwise_fused_optimizers",
        action="store_true",
        help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
    )
    parser.add_argument(
        "--skip_latents_validity_check",
        action="store_true",
        help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
    )
    parser.add_argument(
        "--cpu_offload_checkpointing",
        action="store_true",
        help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
    )
    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    train(args)
@@ -1,383 +0,0 @@
import argparse
import copy
from typing import Any, Tuple

import torch

from library.device_utils import clean_memory_on_device, init_ipex

init_ipex()

from torch import Tensor
from accelerate import Accelerator

import train_network
from library import (
    lumina_models,
    lumina_util,
    lumina_train_util,
    sd3_train_utils,
    strategy_base,
    strategy_lumina,
    train_util,
)
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


class LuminaNetworkTrainer(train_network.NetworkTrainer):
    def __init__(self):
        super().__init__()
        self.sample_prompts_te_outputs = None
        self.is_swapping_blocks: bool = False

    def assert_extra_args(self, args, train_dataset_group, val_dataset_group):
        super().assert_extra_args(args, train_dataset_group, val_dataset_group)

        if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
            logger.warning("Enabling cache_text_encoder_outputs due to disk caching")
            args.cache_text_encoder_outputs = True

        train_dataset_group.verify_bucket_reso_steps(32)
        if val_dataset_group is not None:
            val_dataset_group.verify_bucket_reso_steps(32)

        self.train_gemma2 = not args.network_train_unet_only

    def load_target_model(self, args, weight_dtype, accelerator):
        loading_dtype = None if args.fp8_base else weight_dtype

        model = lumina_util.load_lumina_model(
            args.pretrained_model_name_or_path,
            loading_dtype,
            torch.device("cpu"),
            disable_mmap=args.disable_mmap_load_safetensors,
            use_flash_attn=args.use_flash_attn,
            use_sage_attn=args.use_sage_attn,
        )

        if args.fp8_base:
            # check dtype of model
            if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
                raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
            elif model.dtype == torch.float8_e4m3fn:
                logger.info("Loaded fp8 Lumina 2 model")
            else:
                logger.info(
                    "Cast Lumina 2 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
                    " / Lumina 2モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
                )
                model.to(torch.float8_e4m3fn)

        if args.blocks_to_swap:
            logger.info(f"Lumina 2: Enabling block swap: {args.blocks_to_swap}")
            model.enable_block_swap(args.blocks_to_swap, accelerator.device)
            self.is_swapping_blocks = True

        gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu")
        gemma2.eval()
        ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu")

        return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model

    def get_tokenize_strategy(self, args):
        return strategy_lumina.LuminaTokenizeStrategy(args.system_prompt, args.gemma2_max_token_length, args.tokenizer_cache_dir)

    def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy):
        return [tokenize_strategy.tokenizer]

    def get_latents_caching_strategy(self, args):
        return strategy_lumina.LuminaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)

    def get_text_encoding_strategy(self, args):
        return strategy_lumina.LuminaTextEncodingStrategy()

    def get_text_encoders_train_flags(self, args, text_encoders):
        return [self.train_gemma2]

    def get_text_encoder_outputs_caching_strategy(self, args):
        if args.cache_text_encoder_outputs:
            # if the text encoder is trained, tokenization is still needed, so is_partial is True
            return strategy_lumina.LuminaTextEncoderOutputsCachingStrategy(
                args.cache_text_encoder_outputs_to_disk,
                args.text_encoder_batch_size,
                args.skip_cache_check,
                is_partial=self.train_gemma2,
            )
        else:
            return None

    def cache_text_encoder_outputs_if_needed(
        self,
        args,
        accelerator: Accelerator,
        unet,
        vae,
        text_encoders,
        dataset,
        weight_dtype,
    ):
        if args.cache_text_encoder_outputs:
            if not args.lowram:
                # reduce memory consumption
                logger.info("move vae and unet to cpu to save memory")
                org_vae_device = vae.device
                org_unet_device = unet.device
                vae.to("cpu")
                unet.to("cpu")
                clean_memory_on_device(accelerator.device)

            # When the TE is not trained, it will not be prepared, so we need to use explicit autocast
            logger.info("move text encoders to gpu")
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)  # always not fp8

            if text_encoders[0].dtype == torch.float8_e4m3fn:
                # if we load fp8 weights, the model is already fp8, so we use it as is
                self.prepare_text_encoder_fp8(0, text_encoders[0], text_encoders[0].dtype, weight_dtype)
            else:
                # otherwise, we need to convert it to the target dtype
                text_encoders[0].to(weight_dtype)

            with accelerator.autocast():
                dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)

            # cache sample prompts
            if args.sample_prompts is not None:
                logger.info(f"cache Text Encoder outputs for sample prompts: {args.sample_prompts}")

                tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
                text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()

                assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
                assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)

                sample_prompts = train_util.load_prompts(args.sample_prompts)
                sample_prompts_te_outputs = {}  # key: prompt, value: text encoder outputs
                with accelerator.autocast(), torch.no_grad():
                    for prompt_dict in sample_prompts:
                        prompts = [
                            prompt_dict.get("prompt", ""),
                            prompt_dict.get("negative_prompt", ""),
                        ]
                        for i, prompt in enumerate(prompts):
                            if prompt in sample_prompts_te_outputs:
                                continue

                            logger.info(f"cache Text Encoder outputs for prompt: {prompt}")
                            tokens_and_masks = tokenize_strategy.tokenize(prompt, i == 1)  # i == 1 means negative prompt
                            sample_prompts_te_outputs[prompt] = text_encoding_strategy.encode_tokens(
                                tokenize_strategy,
                                text_encoders,
                                tokens_and_masks,
                            )

                self.sample_prompts_te_outputs = sample_prompts_te_outputs

            accelerator.wait_for_everyone()

            # move back to cpu
            if not self.is_train_text_encoder(args):
                logger.info("move Gemma 2 back to cpu")
                text_encoders[0].to("cpu")
            clean_memory_on_device(accelerator.device)

            if not args.lowram:
                logger.info("move vae and unet back to original device")
                vae.to(org_vae_device)
                unet.to(org_unet_device)
        else:
            # outputs are fetched from the Text Encoder on every step, so keep it on the GPU
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)

    def sample_images(
        self,
        accelerator,
        args,
        epoch,
        global_step,
        device,
        vae,
        tokenizer,
        text_encoder,
        lumina,
    ):
        lumina_train_util.sample_images(
            accelerator,
            args,
            epoch,
            global_step,
            lumina,
            vae,
            self.get_models_for_text_encoding(args, accelerator, text_encoder),
            self.sample_prompts_te_outputs,
        )

    # Remaining methods maintain a similar structure to the flux implementation,
    # with Lumina-specific model calls and strategies

    def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
        noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
        self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
        return noise_scheduler

    def encode_images_to_latents(self, args, vae, images):
        return vae.encode(images)

    # not sure; Lumina uses the same FLUX VAE, so no shift/scale is applied
    def shift_scale_latents(self, args, latents):
        return latents

    def get_noise_pred_and_target(
        self,
        args,
        accelerator: Accelerator,
        noise_scheduler,
        latents,
        batch,
        text_encoder_conds: Tuple[Tensor, Tensor, Tensor],  # (hidden_states, input_ids, attention_masks)
        dit: lumina_models.NextDiT,
        network,
        weight_dtype,
        train_unet,
        is_train=True,
    ):
        assert isinstance(noise_scheduler, sd3_train_utils.FlowMatchEulerDiscreteScheduler)
        noise = torch.randn_like(latents)
        # get noisy model input and timesteps
        noisy_model_input, timesteps, sigmas = lumina_train_util.get_noisy_model_input_and_timesteps(
            args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
        )

        # ensure the hidden state will require grad
        if args.gradient_checkpointing:
            noisy_model_input.requires_grad_(True)
            for t in text_encoder_conds:
                if t is not None and t.dtype.is_floating_point:
                    t.requires_grad_(True)

        # Unpack Gemma2 outputs
        gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds

        def call_dit(img, gemma2_hidden_states, gemma2_attn_mask, timesteps):
            with torch.set_grad_enabled(is_train), accelerator.autocast():
                # NextDiT forward expects (x, t, cap_feats, cap_mask)
                model_pred = dit(
                    x=img,  # image latents (B, C, H, W)
                    t=timesteps / 1000,  # timesteps must be divided by 1000 to match the model's expectation
                    cap_feats=gemma2_hidden_states,  # Gemma2 hidden states as caption features
                    cap_mask=gemma2_attn_mask.to(dtype=torch.int32),  # Gemma2 attention mask
                )
            return model_pred

        model_pred = call_dit(
            img=noisy_model_input,
            gemma2_hidden_states=gemma2_hidden_states,
            gemma2_attn_mask=gemma2_attn_mask,
            timesteps=timesteps,
        )

        # apply model prediction type
        model_pred, weighting = lumina_train_util.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)

        # flow matching loss
        target = latents - noise

        # differential output preservation
        if "custom_attributes" in batch:
            diff_output_pr_indices = []
            for i, custom_attributes in enumerate(batch["custom_attributes"]):
                if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
                    diff_output_pr_indices.append(i)

            if len(diff_output_pr_indices) > 0:
                network.set_multiplier(0.0)
                with torch.no_grad():
                    model_pred_prior = call_dit(
                        img=noisy_model_input[diff_output_pr_indices],
                        gemma2_hidden_states=gemma2_hidden_states[diff_output_pr_indices],
                        timesteps=timesteps[diff_output_pr_indices],
                        gemma2_attn_mask=gemma2_attn_mask[diff_output_pr_indices],
                    )
                network.set_multiplier(1.0)

                # model_pred_prior = lumina_util.unpack_latents(
                #     model_pred_prior, packed_latent_height, packed_latent_width
                # )
                model_pred_prior, _ = lumina_train_util.apply_model_prediction_type(
                    args,
                    model_pred_prior,
                    noisy_model_input[diff_output_pr_indices],
                    sigmas[diff_output_pr_indices] if sigmas is not None else None,
                )
                target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)

        return model_pred, target, timesteps, weighting

    def post_process_loss(self, loss, args, timesteps, noise_scheduler):
        return loss

    def get_sai_model_spec(self, args):
        return train_util.get_sai_model_spec(None, args, False, True, False, lumina="lumina2")

    def update_metadata(self, metadata, args):
        metadata["ss_weighting_scheme"] = args.weighting_scheme
        metadata["ss_logit_mean"] = args.logit_mean
        metadata["ss_logit_std"] = args.logit_std
        metadata["ss_mode_scale"] = args.mode_scale
        metadata["ss_timestep_sampling"] = args.timestep_sampling
        metadata["ss_sigmoid_scale"] = args.sigmoid_scale
        metadata["ss_model_prediction_type"] = args.model_prediction_type
        metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift

    def is_text_encoder_not_needed_for_training(self, args):
        return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)

    def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
        text_encoder.embed_tokens.requires_grad_(True)

    def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
        logger.info(f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
        text_encoder.to(te_weight_dtype)  # fp8
        text_encoder.embed_tokens.to(dtype=weight_dtype)

    def prepare_unet_with_accelerator(
        self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
    ) -> torch.nn.Module:
        if not self.is_swapping_blocks:
            return super().prepare_unet_with_accelerator(args, accelerator, unet)

        # when swapping blocks, keep the swapped blocks on CPU and move only the rest to the device
        nextdit = unet
        assert isinstance(nextdit, lumina_models.NextDiT)
        nextdit = accelerator.prepare(nextdit, device_placement=[not self.is_swapping_blocks])
        accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
        accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward()

        return nextdit

    def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
        if self.is_swapping_blocks:
            # prepare for the next forward: since the backward pass is not called, we need to prepare it here
            accelerator.unwrap_model(unet).prepare_block_swap_before_forward()


def setup_parser() -> argparse.ArgumentParser:
    parser = train_network.setup_parser()
    train_util.add_dit_training_arguments(parser)
    lumina_train_util.add_lumina_train_arguments(parser)
    return parser


if __name__ == "__main__":
    parser = setup_parser()
    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    trainer = LuminaNetworkTrainer()
    trainer.train(args)
@@ -268,7 +268,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
 class DyLoRANetwork(torch.nn.Module):
     UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
     UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
-    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = "lora_unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"


@@ -866,7 +866,7 @@ class LoRANetwork(torch.nn.Module):

     UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
     UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
-    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = "lora_unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"


@@ -278,7 +278,7 @@ def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
 class LoRANetwork(torch.nn.Module):
     UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
     UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
-    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = "lora_unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"


@@ -755,7 +755,7 @@ class LoRANetwork(torch.nn.Module):

     UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
     UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
-    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = "lora_unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"

@@ -9,13 +9,11 @@
import math
import os
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import numpy as np
import torch
from torch import Tensor
import re
from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel
@@ -46,8 +44,6 @@ class LoRAModule(torch.nn.Module):
        rank_dropout=None,
        module_dropout=None,
        split_dims: Optional[List[int]] = None,
-        ggpo_beta: Optional[float] = None,
-        ggpo_sigma: Optional[float] = None,
    ):
        """
        if alpha == 0 or None, alpha is rank (no scaling).
@@ -107,20 +103,9 @@ class LoRAModule(torch.nn.Module):
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

-        self.ggpo_sigma = ggpo_sigma
-        self.ggpo_beta = ggpo_beta
-
-        if self.ggpo_beta is not None and self.ggpo_sigma is not None:
-            self.combined_weight_norms = None
-            self.grad_norms = None
-            self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
-            self.initialize_norm_cache(org_module.weight)
-            self.org_module_shape: tuple[int] = org_module.weight.shape

    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x):
@@ -155,17 +140,7 @@ class LoRAModule(torch.nn.Module):
            lx = self.lora_up(lx)

-            # LoRA Gradient-Guided Perturbation Optimization
-            if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
-                with torch.no_grad():
-                    perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
-                    perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
-                    perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
-                    perturbation.mul_(perturbation_scale_factor)
-                    perturbation_output = x @ perturbation.T  # Result: (batch × n)
-                return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
-            else:
-                return org_forwarded + lx * self.multiplier * scale
+            return org_forwarded + lx * self.multiplier * scale
        else:
            lxs = [lora_down(x) for lora_down in self.lora_down]
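The GGPO branch removed in this hunk perturbs the layer output with Gaussian noise whose magnitude tracks a combination of weight and gradient norms, normalized by the output dimension. A framework-free sketch of that scale rule, with illustrative names and defaults (not the repository's API):

```python
import math
import random

def ggpo_perturbation_scale(combined_weight_norm, grad_norm, sigma, beta, n_out):
    # scale = sigma * sqrt(||W||^2) + beta * ||g||^2, normalized by sqrt(n_out)
    scale = sigma * math.sqrt(combined_weight_norm ** 2) + beta * grad_norm ** 2
    return scale / math.sqrt(n_out)

def perturb_row(row, scale, rng=random):
    # add standard Gaussian noise, scaled per the GGPO rule, to one weight row
    return [w + rng.gauss(0.0, 1.0) * scale for w in row]
```

With `grad_norm = 0`, the scale reduces to `sigma * ||W|| / sqrt(n_out)`, so the perturbation shrinks for wider layers.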
@@ -192,116 +167,6 @@ class LoRAModule(torch.nn.Module):

            return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale

    @torch.no_grad()
    def initialize_norm_cache(self, org_module_weight: Tensor):
        # Choose a reasonable sample size
        n_rows = org_module_weight.shape[0]
        sample_size = min(1000, n_rows)  # Cap at 1000 samples or use all if smaller

        # Sample random indices across all rows
        indices = torch.randperm(n_rows)[:sample_size]

        # Convert to a supported data type first, then index
        # Use float32 for indexing operations
        weights_float32 = org_module_weight.to(dtype=torch.float32)
        sampled_weights = weights_float32[indices].to(device=self.device)

        # Calculate sampled norms
        sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True)

        # Store the mean norm as our estimate
        self.org_weight_norm_estimate = sampled_norms.mean()

        # Optional: store standard deviation for confidence intervals
        self.org_weight_norm_std = sampled_norms.std()

        # Free memory
        del sampled_weights, weights_float32

    @torch.no_grad()
    def validate_norm_approximation(self, org_module_weight: Tensor, verbose=True):
        # Calculate the true norm (this will be slow but it's just for validation)
        true_norms = []
        chunk_size = 1024  # Process in chunks to avoid OOM

        for i in range(0, org_module_weight.shape[0], chunk_size):
            end_idx = min(i + chunk_size, org_module_weight.shape[0])
            chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype)
            chunk_norms = torch.norm(chunk, dim=1, keepdim=True)
            true_norms.append(chunk_norms.cpu())
            del chunk

        true_norms = torch.cat(true_norms, dim=0)
        true_mean_norm = true_norms.mean().item()

        # Compare with our estimate
        estimated_norm = self.org_weight_norm_estimate.item()

        # Calculate error metrics
        absolute_error = abs(true_mean_norm - estimated_norm)
        relative_error = absolute_error / true_mean_norm * 100  # as percentage

        if verbose:
            logger.info(f"True mean norm: {true_mean_norm:.6f}")
            logger.info(f"Estimated norm: {estimated_norm:.6f}")
            logger.info(f"Absolute error: {absolute_error:.6f}")
            logger.info(f"Relative error: {relative_error:.2f}%")

        return {
            'true_mean_norm': true_mean_norm,
            'estimated_norm': estimated_norm,
            'absolute_error': absolute_error,
            'relative_error': relative_error
        }
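initialize_norm_cache above avoids computing every row norm of a large weight matrix by averaging the norms of a random sample of rows. The estimator itself is framework-independent; a minimal list-based sketch (the helper name and seeding are illustrative):

```python
import math
import random

def estimate_mean_row_norm(weight, sample_size=1000, seed=0):
    # weight: matrix as a list of rows; estimate the mean L2 row norm from a sample
    rng = random.Random(seed)
    n_rows = len(weight)
    k = min(sample_size, n_rows)  # cap at sample_size, or use all rows if fewer
    indices = rng.sample(range(n_rows), k)
    norms = [math.sqrt(sum(v * v for v in weight[i])) for i in indices]
    return sum(norms) / k
```

When all rows are identical the estimate is exact, which is what validate_norm_approximation checks in the general case by comparing against chunked full-matrix norms.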
    @torch.no_grad()
    def update_norms(self):
        # Not running GGPO so not currently running update norms
        if self.ggpo_beta is None or self.ggpo_sigma is None:
            return

        # only update norms when we are training
        if self.training is False:
            return

        module_weights = self.lora_up.weight @ self.lora_down.weight
        module_weights.mul(self.scale)

        self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
        self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) + torch.sum(module_weights**2, dim=1, keepdim=True))

    @torch.no_grad()
    def update_grad_norms(self):
        if self.training is False:
            print(f"skipping update_grad_norms for {self.lora_name}")
            return

        lora_down_grad = None
        lora_up_grad = None

        for name, param in self.named_parameters():
            if name == "lora_down.weight":
                lora_down_grad = param.grad
            elif name == "lora_up.weight":
                lora_up_grad = param.grad

        # Calculate gradient norms if we have both gradients
        if lora_down_grad is not None and lora_up_grad is not None:
            with torch.autocast(self.device.type):
                approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
                self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
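update_grad_norms above approximates the gradient of the merged LoRA weight `W = scale * (up @ down)` via the product rule: `dW = scale * (up @ d_down + d_up @ down)`. A small pure-Python sketch of that identity (the naive `matmul` helper is illustrative):

```python
def matmul(a, b):
    # naive matrix multiply for small lists-of-lists
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def approx_merged_grad(up, down, up_grad, down_grad, scale=1.0):
    # product rule for W = scale * up @ down:
    #   dW = scale * (up @ d_down + d_up @ down)
    term1 = matmul(up, down_grad)
    term2 = matmul(up_grad, down)
    return [[scale * (a + b) for a, b in zip(r1, r2)]
            for r1, r2 in zip(term1, term2)]
```

This avoids materializing gradients for the full merged matrix; only the two low-rank factors and their gradients are needed.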
    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype


class LoRAInfModule(LoRAModule):
    def __init__(

@@ -555,16 +420,6 @@ def create_network(
    if split_qkv is not None:
        split_qkv = True if split_qkv == "True" else False

-    ggpo_beta = kwargs.get("ggpo_beta", None)
-    ggpo_sigma = kwargs.get("ggpo_sigma", None)
-
-    if ggpo_beta is not None:
-        ggpo_beta = float(ggpo_beta)
-
-    if ggpo_sigma is not None:
-        ggpo_sigma = float(ggpo_sigma)

    # train T5XXL
    train_t5xxl = kwargs.get("train_t5xxl", False)
    if train_t5xxl is not None:
@@ -594,8 +449,6 @@ def create_network(
        in_dims=in_dims,
        train_double_block_indices=train_double_block_indices,
        train_single_block_indices=train_single_block_indices,
-        ggpo_beta=ggpo_beta,
-        ggpo_sigma=ggpo_sigma,
        verbose=verbose,
    )

@@ -708,8 +561,6 @@ class LoRANetwork(torch.nn.Module):
        in_dims: Optional[List[int]] = None,
        train_double_block_indices: Optional[List[bool]] = None,
        train_single_block_indices: Optional[List[bool]] = None,
-        ggpo_beta: Optional[float] = None,
-        ggpo_sigma: Optional[float] = None,
        verbose: Optional[bool] = False,
    ) -> None:
        super().__init__()

@@ -748,16 +599,10 @@ class LoRANetwork(torch.nn.Module):
            # logger.info(
            #     f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
            # )

-        if ggpo_beta is not None and ggpo_sigma is not None:
-            logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")

        if self.split_qkv:
            logger.info(f"split qkv for LoRA")
        if self.train_blocks is not None:
            logger.info(f"train {self.train_blocks} blocks only")

        if train_t5xxl:
            logger.info(f"train T5XXL as well")

@@ -877,8 +722,6 @@ class LoRANetwork(torch.nn.Module):
                    rank_dropout=rank_dropout,
                    module_dropout=module_dropout,
                    split_dims=split_dims,
-                    ggpo_beta=ggpo_beta,
-                    ggpo_sigma=ggpo_sigma,
                )
                loras.append(lora)
@@ -947,36 +790,6 @@ class LoRANetwork(torch.nn.Module):
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.enabled = is_enabled

-    def update_norms(self):
-        for lora in self.text_encoder_loras + self.unet_loras:
-            lora.update_norms()
-
-    def update_grad_norms(self):
-        for lora in self.text_encoder_loras + self.unet_loras:
-            lora.update_grad_norms()
-
-    def grad_norms(self) -> Tensor | None:
-        grad_norms = []
-        for lora in self.text_encoder_loras + self.unet_loras:
-            if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
-                grad_norms.append(lora.grad_norms.mean(dim=0))
-        return torch.stack(grad_norms) if len(grad_norms) > 0 else None
-
-    def weight_norms(self) -> Tensor | None:
-        weight_norms = []
-        for lora in self.text_encoder_loras + self.unet_loras:
-            if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
-                weight_norms.append(lora.weight_norms.mean(dim=0))
-        return torch.stack(weight_norms) if len(weight_norms) > 0 else None
-
-    def combined_weight_norms(self) -> Tensor | None:
-        combined_weight_norms = []
-        for lora in self.text_encoder_loras + self.unet_loras:
-            if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
-                combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
-        return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None

    def load_weights(self, file):
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

File diff suppressed because it is too large
@@ -6,4 +6,3 @@ filterwarnings =
    ignore::DeprecationWarning
    ignore::UserWarning
    ignore::FutureWarning
pythonpath = .
@@ -7,11 +7,9 @@ opencv-python==4.8.1.78
einops==0.7.0
pytorch-lightning==1.9.0
bitsandbytes==0.44.0
prodigyopt==1.0
lion-pytorch==0.0.6
schedulefree==1.4
pytorch-optimizer==3.5.0
prodigy-plus-schedule-free==1.9.0
prodigyopt==1.1.2
tensorboard
safetensors==0.4.4
# gradio==3.16.2
sd3_train.py (14 changed lines)
@@ -75,6 +75,12 @@ def train(args):
        )
        args.cache_text_encoder_outputs = True

    if args.cache_text_encoder_outputs:
        assert args.apply_lg_attn_mask == args.apply_t5_attn_mask, (
            "apply_lg_attn_mask and apply_t5_attn_mask must be the same when caching text encoder outputs"
            " / text encoderの出力をキャッシュするときにはapply_lg_attn_maskとapply_t5_attn_maskは同じである必要があります"
        )

    assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), (
        "when training text encoder, text encoder outputs must not be cached (except for T5XXL)"
        + " / text encoderの学習時はtext encoderの出力はキャッシュできません(t5xxlのみキャッシュすることは可能です)"

@@ -169,8 +175,8 @@ def train(args):
                args.text_encoder_batch_size,
                False,
                False,
                False,
                False,
                args.t5xxl_max_token_length,
                args.apply_lg_attn_mask,
            )
        )
        train_dataset_group.set_current_strategies()

@@ -279,8 +285,8 @@ def train(args):
            args.text_encoder_batch_size,
            args.skip_cache_check,
            train_clip or args.use_t5xxl_cache_only,  # if clip is trained or t5xxl is cached, caching is partial
            args.t5xxl_max_token_length,
            args.apply_lg_attn_mask,
            args.apply_t5_attn_mask,
        )
        strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)

@@ -331,7 +337,7 @@ def train(args):
        vae.requires_grad_(False)
        vae.eval()

-        train_dataset_group.new_cache_latents(vae, accelerator)
+        train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)

        vae.to("cpu")  # if no sampling, vae can be deleted
        clean_memory_on_device(accelerator.device)
@@ -26,12 +26,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
        super().__init__()
        self.sample_prompts_te_outputs = None

-    def assert_extra_args(
-        self,
-        args,
-        train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
-        val_dataset_group: Optional[train_util.DatasetGroup],
-    ):
+    def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
        # super().assert_extra_args(args, train_dataset_group)
        # sdxl_train_util.verify_sdxl_training_args(args)

@@ -48,6 +43,10 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
            assert (
                train_dataset_group.is_text_encoder_output_cacheable()
            ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
            assert args.apply_lg_attn_mask == args.apply_t5_attn_mask, (
                "apply_lg_attn_mask and apply_t5_attn_mask must be the same when caching text encoder outputs"
                " / text encoderの出力をキャッシュするときにはapply_lg_attn_maskとapply_t5_attn_maskは同じである必要があります"
            )

        # prepare CLIP-L/CLIP-G/T5XXL training flags
        self.train_clip = not args.network_train_unet_only

@@ -193,8 +192,8 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
                args.text_encoder_batch_size,
                args.skip_cache_check,
                is_partial=self.train_clip or self.train_t5xxl,
                max_token_length=args.t5xxl_max_token_length,
                apply_lg_attn_mask=args.apply_lg_attn_mask,
                apply_t5_attn_mask=args.apply_t5_attn_mask,
            )
        else:
            return None

@@ -304,7 +303,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
        noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift)
        return noise_scheduler

-    def encode_images_to_latents(self, args, vae, images):
+    def encode_images_to_latents(self, args, accelerator, vae, images):
        return vae.encode(images)

    def shift_scale_latents(self, args, latents):

@@ -322,7 +321,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
        network,
        weight_dtype,
        train_unet,
        is_train=True,
        is_train=True
    ):
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)

@@ -450,19 +449,14 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
                text_encoder.to(te_weight_dtype)  # fp8
                prepare_fp8(text_encoder, weight_dtype)

-    def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
-        # drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed
+    def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
+        # drop cached text encoder outputs
        text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
        if text_encoder_outputs_list is not None:
            text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
            text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
            batch["text_encoder_outputs_list"] = text_encoder_outputs_list

    def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
        if self.is_swapping_blocks:
            # prepare for next forward: because backward pass is not called, we need to prepare it here
            accelerator.unwrap_model(unet).prepare_block_swap_before_forward()

    def prepare_unet_with_accelerator(
        self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
    ) -> torch.nn.Module:
@@ -273,7 +273,7 @@ def train(args):
    vae.requires_grad_(False)
    vae.eval()

-    train_dataset_group.new_cache_latents(vae, accelerator)
+    train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)

    vae.to("cpu")
    clean_memory_on_device(accelerator.device)

@@ -322,7 +322,11 @@ def train(args):
    if args.cache_text_encoder_outputs:
        # Text Encodes are eval and no grad
        text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
-            args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions
+            args.cache_text_encoder_outputs_to_disk,
+            None,
+            args.skip_cache_check,
+            args.max_token_length,
+            is_weighted=args.weighted_captions,
        )
        strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
@@ -209,7 +209,7 @@ def train(args):
    vae.requires_grad_(False)
    vae.eval()

-    train_dataset_group.new_cache_latents(vae, accelerator)
+    train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)

    vae.to("cpu")
    clean_memory_on_device(accelerator.device)

@@ -223,7 +223,11 @@ def train(args):
    if args.cache_text_encoder_outputs:
        # Text Encodes are eval and no grad
        text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
-            args.cache_text_encoder_outputs_to_disk, None, False
+            args.cache_text_encoder_outputs_to_disk,
+            None,
+            args.skip_cache_check,
+            args.max_token_length,
+            is_weighted=args.weighted_captions,
        )
        strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
@@ -181,7 +181,7 @@ def train(args):
    vae.requires_grad_(False)
    vae.eval()

-    train_dataset_group.new_cache_latents(vae, accelerator)
+    train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)

    vae.to("cpu")
    clean_memory_on_device(accelerator.device)

@@ -195,7 +195,11 @@ def train(args):
    if args.cache_text_encoder_outputs:
        # Text Encodes are eval and no grad
        text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
-            args.cache_text_encoder_outputs_to_disk, None, False
+            args.cache_text_encoder_outputs_to_disk,
+            None,
+            args.skip_cache_check,
+            args.max_token_length,
+            is_weighted=args.weighted_captions,
        )
        strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
@@ -24,6 +24,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
        self.is_sdxl = True

    def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
        super().assert_extra_args(args, train_dataset_group, val_dataset_group)
        sdxl_train_util.verify_sdxl_training_args(args)

        if args.cache_text_encoder_outputs:

@@ -82,7 +83,11 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
    def get_text_encoder_outputs_caching_strategy(self, args):
        if args.cache_text_encoder_outputs:
            return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
-                args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions
+                args.cache_text_encoder_outputs_to_disk,
+                None,
+                args.skip_cache_check,
+                args.max_token_length,
+                is_weighted=args.weighted_captions,
            )
        else:
            return None
@@ -1,220 +0,0 @@
import pytest
import torch
from unittest.mock import MagicMock, patch
from library.flux_train_utils import (
    get_noisy_model_input_and_timesteps,
)


# Mock classes and functions
class MockNoiseScheduler:
    def __init__(self, num_train_timesteps=1000):
        self.config = MagicMock()
        self.config.num_train_timesteps = num_train_timesteps
        self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long)


# Create fixtures for commonly used objects
@pytest.fixture
def args():
    args = MagicMock()
    args.timestep_sampling = "uniform"
    args.weighting_scheme = "uniform"
    args.logit_mean = 0.0
    args.logit_std = 1.0
    args.mode_scale = 1.0
    args.sigmoid_scale = 1.0
    args.discrete_flow_shift = 3.1582
    args.ip_noise_gamma = None
    args.ip_noise_gamma_random_strength = False
    return args


@pytest.fixture
def noise_scheduler():
    return MockNoiseScheduler(num_train_timesteps=1000)


@pytest.fixture
def latents():
    return torch.randn(2, 4, 8, 8)


@pytest.fixture
def noise():
    return torch.randn(2, 4, 8, 8)


@pytest.fixture
def device():
    # return "cuda" if torch.cuda.is_available() else "cpu"
    return "cpu"


# Mock the required functions
@pytest.fixture(autouse=True)
def mock_functions():
    with (
        patch("torch.sigmoid", side_effect=torch.sigmoid),
        patch("torch.rand", side_effect=torch.rand),
        patch("torch.randn", side_effect=torch.randn),
    ):
        yield


# Test different timestep sampling methods
def test_uniform_sampling(args, noise_scheduler, latents, noise, device):
    args.timestep_sampling = "uniform"
    dtype = torch.float32

    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (latents.shape[0],)
    assert sigmas.shape == (latents.shape[0], 1, 1, 1)
    assert noisy_input.dtype == dtype
    assert timesteps.dtype == dtype


def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
    args.timestep_sampling = "sigmoid"
    args.sigmoid_scale = 1.0
    dtype = torch.float32

    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (latents.shape[0],)
    assert sigmas.shape == (latents.shape[0], 1, 1, 1)


def test_shift_sampling(args, noise_scheduler, latents, noise, device):
    args.timestep_sampling = "shift"
    args.sigmoid_scale = 1.0
    args.discrete_flow_shift = 3.1582
    dtype = torch.float32

    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (latents.shape[0],)
    assert sigmas.shape == (latents.shape[0], 1, 1, 1)
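test_shift_sampling above exercises the "shift" mode with `discrete_flow_shift = 3.1582`. The standard discrete flow shift maps a uniform sample toward higher noise levels; a minimal sketch of that transform (the exact expression in `flux_train_utils` may differ, this is the commonly used form):

```python
def shift_sigma(t, shift=3.1582):
    # map a uniform sample t in [0, 1] toward higher noise via the flow shift;
    # fixed points at t=0 and t=1, monotone increasing in between
    return shift * t / (1.0 + (shift - 1.0) * t)
```

With `shift > 1` every interior point moves upward, which biases training toward noisier timesteps.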
def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
    args.timestep_sampling = "flux_shift"
    args.sigmoid_scale = 1.0
    dtype = torch.float32

    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (latents.shape[0],)
    assert sigmas.shape == (latents.shape[0], 1, 1, 1)


def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
    # Mock the necessary functions for this specific test
    with patch("library.flux_train_utils.compute_density_for_timestep_sampling",
               return_value=torch.tensor([0.3, 0.7], device=device)), \
         patch("library.flux_train_utils.get_sigmas",
               return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)):

        args.timestep_sampling = "other"  # Will trigger the weighting scheme path
        args.weighting_scheme = "uniform"
        args.logit_mean = 0.0
        args.logit_std = 1.0
        args.mode_scale = 1.0
        dtype = torch.float32

        noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
            args, noise_scheduler, latents, noise, device, dtype
        )

        assert noisy_input.shape == latents.shape
        assert timesteps.shape == (latents.shape[0],)
        assert sigmas.shape == (latents.shape[0], 1, 1, 1)


# Test IP noise options
def test_with_ip_noise(args, noise_scheduler, latents, noise, device):
    args.ip_noise_gamma = 0.5
    args.ip_noise_gamma_random_strength = False
    dtype = torch.float32

    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (latents.shape[0],)
    assert sigmas.shape == (latents.shape[0], 1, 1, 1)


def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
    args.ip_noise_gamma = 0.1
    args.ip_noise_gamma_random_strength = True
    dtype = torch.float32

    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (latents.shape[0],)
    assert sigmas.shape == (latents.shape[0], 1, 1, 1)


# Test different data types
def test_float16_dtype(args, noise_scheduler, latents, noise, device):
    dtype = torch.float16

    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    assert noisy_input.dtype == dtype
    assert timesteps.dtype == dtype


# Test different batch sizes
def test_different_batch_size(args, noise_scheduler, device):
    latents = torch.randn(5, 4, 8, 8)  # batch size of 5
    noise = torch.randn(5, 4, 8, 8)
    dtype = torch.float32

    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (5,)
    assert sigmas.shape == (5, 1, 1, 1)


# Test different image sizes
def test_different_image_size(args, noise_scheduler, device):
    latents = torch.randn(2, 4, 16, 16)  # larger image size
    noise = torch.randn(2, 4, 16, 16)
    dtype = torch.float32

    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (2,)
    assert sigmas.shape == (2, 1, 1, 1)


# Test edge cases
def test_zero_batch_size(args, noise_scheduler, device):
    with pytest.raises(AssertionError):  # expecting an error with zero batch size
        latents = torch.randn(0, 4, 8, 8)
        noise = torch.randn(0, 4, 8, 8)
        dtype = torch.float32

        get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)


def test_different_timestep_count(args, device):
    noise_scheduler = MockNoiseScheduler(num_train_timesteps=500)  # different timestep count
    latents = torch.randn(2, 4, 8, 8)
    noise = torch.randn(2, 4, 8, 8)
    dtype = torch.float32

    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (2,)
    # Check that timesteps are within the proper range
    assert torch.all(timesteps < 500)
@@ -1,295 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from library.lumina_models import (
|
||||
LuminaParams,
|
||||
to_cuda,
|
||||
to_cpu,
|
||||
RopeEmbedder,
|
||||
TimestepEmbedder,
|
||||
modulate,
|
||||
NextDiT,
|
||||
)
|
||||
|
||||
cuda_required = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
|
||||
|
||||
def test_lumina_params():
|
||||
# Test default configuration
|
||||
default_params = LuminaParams()
|
||||
assert default_params.patch_size == 2
|
||||
assert default_params.in_channels == 4
|
||||
assert default_params.axes_dims == [36, 36, 36]
|
||||
assert default_params.axes_lens == [300, 512, 512]
|
||||
|
||||
# Test 2B config
|
||||
config_2b = LuminaParams.get_2b_config()
|
||||
assert config_2b.dim == 2304
|
||||
assert config_2b.in_channels == 16
|
||||
assert config_2b.n_layers == 26
|
||||
assert config_2b.n_heads == 24
|
||||
assert config_2b.cap_feat_dim == 2304
|
||||
|
||||
# Test 7B config
|
||||
config_7b = LuminaParams.get_7b_config()
|
||||
assert config_7b.dim == 4096
|
||||
assert config_7b.n_layers == 32
|
||||
assert config_7b.n_heads == 32
|
||||
assert config_7b.axes_dims == [64, 64, 64]
|
||||
|
||||
|
||||
@cuda_required
|
||||
def test_to_cuda_to_cpu():
|
||||
# Test tensor conversion
|
||||
x = torch.tensor([1, 2, 3])
|
||||
x_cuda = to_cuda(x)
|
||||
x_cpu = to_cpu(x_cuda)
|
||||
assert x.cpu().tolist() == x_cpu.tolist()
|
||||
|
||||
# Test list conversion
|
||||
list_data = [torch.tensor([1]), torch.tensor([2])]
|
||||
list_cuda = to_cuda(list_data)
|
||||
assert all(tensor.device.type == "cuda" for tensor in list_cuda)
|
||||
|
||||
list_cpu = to_cpu(list_cuda)
|
||||
assert all(not tensor.device.type == "cuda" for tensor in list_cpu)
|
||||
|
||||
# Test dict conversion
|
||||
dict_data = {"a": torch.tensor([1]), "b": torch.tensor([2])}
|
||||
dict_cuda = to_cuda(dict_data)
|
||||
assert all(tensor.device.type == "cuda" for tensor in dict_cuda.values())
|
||||
|
||||
dict_cpu = to_cpu(dict_cuda)
|
||||
assert all(not tensor.device.type == "cuda" for tensor in dict_cpu.values())


def test_timestep_embedder():
    # Test initialization
    hidden_size = 256
    freq_emb_size = 128
    embedder = TimestepEmbedder(hidden_size, freq_emb_size)
    assert embedder.frequency_embedding_size == freq_emb_size

    # Test timestep embedding
    t = torch.tensor([0.5, 1.0, 2.0])
    emb_dim = freq_emb_size
    embeddings = TimestepEmbedder.timestep_embedding(t, emb_dim)

    assert embeddings.shape == (3, emb_dim)
    assert embeddings.dtype == torch.float32

    # Ensure embeddings are unique for different input times
    assert not torch.allclose(embeddings[0], embeddings[1])

    # Test forward pass
    t_emb = embedder(t)
    assert t_emb.shape == (3, hidden_size)
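For reference, the static embedding being tested typically follows the DiT-style sinusoidal formulation. A minimal sketch under that assumption; the exact cos/sin ordering inside the real `TimestepEmbedder.timestep_embedding` may differ:

```python
import math
import torch

def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0) -> torch.Tensor:
    # Sinusoidal embedding: half the channels are cosines, half sines,
    # with geometrically spaced frequencies from 1 down to 1/max_period.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
```

This yields one `dim`-wide float32 row per timestep, and distinct timesteps produce distinct embeddings, matching the assertions above.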


def test_rope_embedder_simple():
    rope_embedder = RopeEmbedder()
    batch_size, seq_len = 2, 10

    # Create position_ids with valid ranges for each axis
    position_ids = torch.stack(
        [
            torch.zeros(batch_size, seq_len, dtype=torch.int64),  # First axis: only 0 is valid
            torch.randint(0, 512, (batch_size, seq_len), dtype=torch.int64),  # Second axis: 0-511
            torch.randint(0, 512, (batch_size, seq_len), dtype=torch.int64),  # Third axis: 0-511
        ],
        dim=-1,
    )

    freqs_cis = rope_embedder(position_ids)
    # RoPE embeddings work in pairs, so output dimension is half of total axes_dims
    expected_dim = sum(rope_embedder.axes_dims) // 2  # 128 // 2 = 64
    assert freqs_cis.shape == (batch_size, seq_len, expected_dim)


def test_modulate():
    # Test modulation with different scales
    x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    scale = torch.tensor([1.5, 2.0])

    modulated_x = modulate(x, scale)

    # Check that modulation scales correctly.
    # The function does x * (1 + scale), so for scale [1.5, 2.0]
    # the factors (1 + scale) are [2.5, 3.0]:
    expected_x = torch.tensor([[2.5 * 1.0, 2.5 * 2.0], [3.0 * 3.0, 3.0 * 4.0]])
    # Which equals: [[2.5, 5.0], [9.0, 12.0]]

    assert torch.allclose(modulated_x, expected_x)
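A `modulate` consistent with the expectations above is one line; whether the real implementation unsqueezes the scale or relies on caller-side broadcasting is an assumption of this sketch:

```python
import torch

def modulate(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # AdaLN-style modulation: scale each row of x by (1 + scale).
    return x * (1 + scale.unsqueeze(1))
```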


def test_nextdit_parameter_count_optimized():
    # The constraint is: (dim // n_heads) == sum(axes_dims)
    # So for dim=120, n_heads=4: 120 // 4 = 30, so sum(axes_dims) must = 30
    model_small = NextDiT(
        patch_size=2,
        in_channels=4,  # Smaller
        dim=120,  # 120 // 4 = 30
        n_layers=2,  # Much fewer layers
        n_heads=4,  # Fewer heads
        n_kv_heads=2,
        axes_dims=[10, 10, 10],  # sum = 30
        axes_lens=[10, 32, 32],  # Smaller
    )
    param_count_small = model_small.parameter_count()
    assert param_count_small > 0

    # For dim=192, n_heads=6: 192 // 6 = 32, so sum(axes_dims) must = 32
    model_medium = NextDiT(
        patch_size=2,
        in_channels=4,
        dim=192,  # 192 // 6 = 32
        n_layers=4,  # More layers
        n_heads=6,
        n_kv_heads=3,
        axes_dims=[10, 11, 11],  # sum = 32
        axes_lens=[10, 32, 32],
    )
    param_count_medium = model_medium.parameter_count()
    assert param_count_medium > param_count_small
    print(f"Small model: {param_count_small:,} parameters")
    print(f"Medium model: {param_count_medium:,} parameters")


@torch.no_grad()
def test_precompute_freqs_cis():
    # Test precompute_freqs_cis
    dim = [16, 56, 56]
    end = [1, 512, 512]
    theta = 10000.0

    freqs_cis = NextDiT.precompute_freqs_cis(dim, end, theta)

    # Check number of frequency tensors
    assert len(freqs_cis) == len(dim)

    # Check each frequency tensor
    for i, (d, e) in enumerate(zip(dim, end)):
        assert freqs_cis[i].shape == (e, d // 2)
        assert freqs_cis[i].dtype == torch.complex128
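The per-axis RoPE tables checked here can be sketched with the conventional rotary precompute: an outer product of positions and inverse frequencies, turned into unit complex numbers. This is an illustration of the shape/dtype contract, not necessarily the exact body of `NextDiT.precompute_freqs_cis`:

```python
import torch

def precompute_freqs_cis(dims, ends, theta=10000.0):
    # One complex table per axis: shape (end, dim // 2), dtype complex128.
    tables = []
    for d, e in zip(dims, ends):
        inv_freq = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64) / d))
        pos = torch.arange(e, dtype=torch.float64)
        angles = torch.outer(pos, inv_freq)  # (e, d // 2)
        tables.append(torch.polar(torch.ones_like(angles), angles))
    return tables
```

Float64 magnitudes and angles make `torch.polar` return complex128, which is what the assertions above expect.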


@torch.no_grad()
def test_nextdit_patchify_and_embed():
    """Test the patchify_and_embed method which is crucial for training"""
    # Create a small NextDiT model for testing
    # The constraint is: (dim // n_heads) == sum(axes_dims)
    # For dim=120, n_heads=4: 120 // 4 = 30, so sum(axes_dims) must = 30
    model = NextDiT(
        patch_size=2,
        in_channels=4,
        dim=120,  # 120 // 4 = 30
        n_layers=1,  # Minimal layers for faster testing
        n_refiner_layers=1,  # Minimal refiner layers
        n_heads=4,
        n_kv_heads=2,
        axes_dims=[10, 10, 10],  # sum = 30
        axes_lens=[10, 32, 32],
        cap_feat_dim=120,  # Match dim for consistency
    )

    # Prepare test inputs
    batch_size = 2
    height, width = 64, 64  # Must be divisible by patch_size (2)
    caption_seq_len = 8

    # Create mock inputs
    x = torch.randn(batch_size, 4, height, width)  # Image latents
    cap_feats = torch.randn(batch_size, caption_seq_len, 120)  # Caption features
    cap_mask = torch.ones(batch_size, caption_seq_len, dtype=torch.bool)  # All valid tokens
    # Make second batch have shorter caption
    cap_mask[1, 6:] = False  # Only first 6 tokens are valid for second batch
    t = torch.randn(batch_size, 120)  # Timestep embeddings

    # Call patchify_and_embed
    joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = model.patchify_and_embed(
        x, cap_feats, cap_mask, t
    )

    # Validate outputs
    image_seq_len = (height // 2) * (width // 2)  # patch_size = 2
    expected_seq_lengths = [caption_seq_len + image_seq_len, 6 + image_seq_len]  # Second batch has shorter caption
    max_seq_len = max(expected_seq_lengths)

    # Check joint hidden states shape
    assert joint_hidden_states.shape == (batch_size, max_seq_len, 120)
    assert joint_hidden_states.dtype == torch.float32

    # Check attention mask shape and values
    assert attention_mask.shape == (batch_size, max_seq_len)
    assert attention_mask.dtype == torch.bool
    # First batch should have all positions valid up to its sequence length
    assert torch.all(attention_mask[0, : expected_seq_lengths[0]])
    assert torch.all(~attention_mask[0, expected_seq_lengths[0] :])
    # Second batch should have all positions valid up to its sequence length
    assert torch.all(attention_mask[1, : expected_seq_lengths[1]])
    assert torch.all(~attention_mask[1, expected_seq_lengths[1] :])

    # Check freqs_cis shape
    assert freqs_cis.shape == (batch_size, max_seq_len, sum(model.axes_dims) // 2)

    # Check effective caption lengths
    assert l_effective_cap_len == [caption_seq_len, 6]

    # Check sequence lengths
    assert seq_lengths == expected_seq_lengths

    # Validate that the joint hidden states contain non-zero values where attention mask is True
    for i in range(batch_size):
        valid_positions = attention_mask[i]
        # Check that valid positions have meaningful data (not all zeros)
        valid_data = joint_hidden_states[i][valid_positions]
        assert not torch.allclose(valid_data, torch.zeros_like(valid_data))

        # Check that invalid positions are zeros
        if valid_positions.sum() < max_seq_len:
            invalid_data = joint_hidden_states[i][~valid_positions]
            assert torch.allclose(invalid_data, torch.zeros_like(invalid_data))


@torch.no_grad()
def test_nextdit_patchify_and_embed_edge_cases():
    """Test edge cases for patchify_and_embed"""
    # Create minimal model
    model = NextDiT(
        patch_size=2,
        in_channels=4,
        dim=60,  # 60 // 3 = 20
        n_layers=1,
        n_refiner_layers=1,
        n_heads=3,
        n_kv_heads=1,
        axes_dims=[8, 6, 6],  # sum = 20
        axes_lens=[10, 16, 16],
        cap_feat_dim=60,
    )

    # Test with empty captions (all masked)
    batch_size = 1
    height, width = 32, 32
    caption_seq_len = 4

    x = torch.randn(batch_size, 4, height, width)
    cap_feats = torch.randn(batch_size, caption_seq_len, 60)
    cap_mask = torch.zeros(batch_size, caption_seq_len, dtype=torch.bool)  # All tokens masked
    t = torch.randn(batch_size, 60)

    joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = model.patchify_and_embed(
        x, cap_feats, cap_mask, t
    )

    # With all captions masked, effective length should be 0
    assert l_effective_cap_len == [0]

    # Sequence length should just be the image sequence length
    image_seq_len = (height // 2) * (width // 2)
    assert seq_lengths == [image_seq_len]

    # Joint hidden states should only contain image data
    assert joint_hidden_states.shape == (batch_size, image_seq_len, 60)
    assert attention_mask.shape == (batch_size, image_seq_len)
    assert torch.all(attention_mask[0])  # All image positions should be valid
@@ -1,241 +0,0 @@
import pytest
import torch
import math

from library.lumina_train_util import (
    batchify,
    time_shift,
    get_lin_function,
    get_schedule,
    compute_density_for_timestep_sampling,
    get_sigmas,
    compute_loss_weighting_for_sd3,
    get_noisy_model_input_and_timesteps,
    apply_model_prediction_type,
    retrieve_timesteps,
)
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler


def test_batchify():
    # Test case with no batch size specified
    prompts = [
        {"prompt": "test1"},
        {"prompt": "test2"},
        {"prompt": "test3"},
    ]
    batchified = list(batchify(prompts))
    assert len(batchified) == 1
    assert len(batchified[0]) == 3

    # Test case with batch size specified
    batchified_sized = list(batchify(prompts, batch_size=2))
    assert len(batchified_sized) == 2
    assert len(batchified_sized[0]) == 2
    assert len(batchified_sized[1]) == 1

    # Test batching with prompts having the same parameters
    prompts_with_params = [
        {"prompt": "test1", "width": 512, "height": 512},
        {"prompt": "test2", "width": 512, "height": 512},
        {"prompt": "test3", "width": 1024, "height": 1024},
    ]
    batchified_params = list(batchify(prompts_with_params))
    assert len(batchified_params) == 2

    # Test invalid batch sizes
    with pytest.raises(ValueError):
        list(batchify(prompts, batch_size=0))
    with pytest.raises(ValueError):
        list(batchify(prompts, batch_size=-1))
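Consistent with these assertions, `batchify` presumably groups prompts by their shared generation parameters and then chunks each group. A minimal sketch; the grouping key, error message, and generator structure are assumptions for illustration:

```python
def batchify(prompts, batch_size=None):
    # Group prompts that share all parameters other than the prompt text,
    # then yield each group in chunks of at most `batch_size`.
    if batch_size is not None and batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    groups = {}
    for prompt in prompts:
        key = tuple(sorted((k, v) for k, v in prompt.items() if k != "prompt"))
        groups.setdefault(key, []).append(prompt)
    for group in groups.values():
        if batch_size is None:
            yield group
        else:
            for i in range(0, len(group), batch_size):
                yield group[i : i + batch_size]
```

Because this is a generator, the `ValueError` surfaces on first iteration, which is why the tests wrap the call in `list(...)`.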


def test_time_shift():
    # Test standard parameters
    t = torch.tensor([0.5])
    mu = 1.0
    sigma = 1.0
    result = time_shift(mu, sigma, t)
    assert 0 <= result <= 1

    # Test with edge cases
    t_edges = torch.tensor([0.0, 1.0])
    result_edges = time_shift(1.0, 1.0, t_edges)

    # Check that results are bounded within [0, 1]
    assert torch.all(result_edges >= 0)
    assert torch.all(result_edges <= 1)


def test_get_lin_function():
    # Default parameters
    func = get_lin_function()
    assert func(256) == 0.5
    assert func(4096) == 1.15

    # Custom parameters
    custom_func = get_lin_function(x1=100, x2=1000, y1=0.1, y2=0.9)
    assert custom_func(100) == 0.1
    assert custom_func(1000) == 0.9


def test_get_schedule():
    # Basic schedule
    schedule = get_schedule(num_steps=10, image_seq_len=256)
    assert len(schedule) == 10
    assert all(0 <= x <= 1 for x in schedule)

    # Test different sequence lengths
    short_schedule = get_schedule(num_steps=5, image_seq_len=128)
    long_schedule = get_schedule(num_steps=15, image_seq_len=1024)
    assert len(short_schedule) == 5
    assert len(long_schedule) == 15

    # Test with shift disabled
    unshifted_schedule = get_schedule(num_steps=10, image_seq_len=256, shift=False)
    assert torch.allclose(
        torch.tensor(unshifted_schedule),
        torch.linspace(1, 1 / 10, 10),
    )
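These three helpers fit together as a FLUX-style flow-matching schedule: a linear map picks the shift strength `mu` from the image sequence length, and `time_shift` warps a uniform timestep grid with it. A self-contained sketch under that assumption; the library's exact signatures and defaults may differ:

```python
import math
import torch

def get_lin_function(x1=256.0, x2=4096.0, y1=0.5, y2=1.15):
    # Linear map through (x1, y1) and (x2, y2); used to pick mu from image_seq_len.
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b

def time_shift(mu, sigma, t):
    # Warp timesteps toward 1; exp(mu) controls the shift strength.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

def get_schedule(num_steps, image_seq_len, shift=True):
    # Uniform grid from 1 down to 1/num_steps, optionally shifted by sequence length.
    timesteps = torch.linspace(1, 1 / num_steps, num_steps)
    if shift:
        mu = get_lin_function()(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)
    return timesteps.tolist()
```

Longer sequences get a larger `mu`, pushing more of the schedule toward high noise levels.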


def test_compute_density_for_timestep_sampling():
    # Test uniform sampling
    uniform_samples = compute_density_for_timestep_sampling("uniform", batch_size=100)
    assert len(uniform_samples) == 100
    assert torch.all((uniform_samples >= 0) & (uniform_samples <= 1))

    # Test logit-normal sampling
    logit_normal_samples = compute_density_for_timestep_sampling(
        "logit_normal", batch_size=100, logit_mean=0.0, logit_std=1.0
    )
    assert len(logit_normal_samples) == 100
    assert torch.all((logit_normal_samples >= 0) & (logit_normal_samples <= 1))

    # Test mode sampling
    mode_samples = compute_density_for_timestep_sampling("mode", batch_size=100, mode_scale=0.5)
    assert len(mode_samples) == 100
    assert torch.all((mode_samples >= 0) & (mode_samples <= 1))


def test_get_sigmas():
    # Create a mock noise scheduler
    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
    device = torch.device("cpu")

    # Test with default parameters
    timesteps = torch.tensor([100, 500, 900])
    sigmas = get_sigmas(scheduler, timesteps, device)

    # Check shape and basic properties
    assert sigmas.shape[0] == 3
    assert torch.all(sigmas >= 0)

    # Test with different n_dim
    sigmas_4d = get_sigmas(scheduler, timesteps, device, n_dim=4)
    assert sigmas_4d.ndim == 4

    # Test with different dtype
    sigmas_float16 = get_sigmas(scheduler, timesteps, device, dtype=torch.float16)
    assert sigmas_float16.dtype == torch.float16


def test_compute_loss_weighting_for_sd3():
    # Prepare some mock sigmas
    sigmas = torch.tensor([0.1, 0.5, 1.0])

    # Test sigma_sqrt weighting
    sqrt_weighting = compute_loss_weighting_for_sd3("sigma_sqrt", sigmas)
    assert torch.allclose(sqrt_weighting, 1 / (sigmas**2), rtol=1e-5)

    # Test cosmap weighting
    cosmap_weighting = compute_loss_weighting_for_sd3("cosmap", sigmas)
    bot = 1 - 2 * sigmas + 2 * sigmas**2
    expected_cosmap = 2 / (math.pi * bot)
    assert torch.allclose(cosmap_weighting, expected_cosmap, rtol=1e-5)

    # Test default weighting
    default_weighting = compute_loss_weighting_for_sd3("unknown", sigmas)
    assert torch.all(default_weighting == 1)
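The weighting schemes checked above can be read straight off the formulas in the assertions (this mirrors the SD3/diffusers convention; treat it as an illustration rather than the library's exact code):

```python
import math
import torch

def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
    if weighting_scheme == "sigma_sqrt":
        # Inverse-variance style weighting: 1 / sigma^2
        return (sigmas**-2.0).float()
    if weighting_scheme == "cosmap":
        # Cosine-map weighting from the SD3 paper
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        return 2 / (math.pi * bot)
    # Any unknown scheme falls back to uniform weighting
    return torch.ones_like(sigmas)
```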


def test_apply_model_prediction_type():
    # Create mock args and tensors
    class MockArgs:
        model_prediction_type = "raw"
        weighting_scheme = "sigma_sqrt"

    args = MockArgs()
    model_pred = torch.tensor([1.0, 2.0, 3.0])
    noisy_model_input = torch.tensor([0.5, 1.0, 1.5])
    sigmas = torch.tensor([0.1, 0.5, 1.0])

    # Test raw prediction type
    raw_pred, raw_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
    assert torch.all(raw_pred == model_pred)
    assert raw_weighting is None

    # Test additive prediction type
    args.model_prediction_type = "additive"
    additive_pred, _ = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
    assert torch.all(additive_pred == model_pred + noisy_model_input)

    # Test sigma-scaled prediction type
    args.model_prediction_type = "sigma_scaled"
    sigma_scaled_pred, sigma_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
    assert torch.all(sigma_scaled_pred == model_pred * (-sigmas) + noisy_model_input)
    assert sigma_weighting is not None
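A sketch consistent with these three branches; the coupling of `sigma_scaled` to a weighting term follows the assertions, but the specific weighting used (assumed `sigma_sqrt`, i.e. 1/sigma^2, here) is an assumption:

```python
import torch

def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
    weighting = None
    if args.model_prediction_type == "additive":
        # Prediction is a residual on top of the noisy input
        model_pred = model_pred + noisy_model_input
    elif args.model_prediction_type == "sigma_scaled":
        # Flow-matching style: scale by -sigma, add the noisy input,
        # and return a loss weighting (assumed sigma_sqrt: 1 / sigma^2)
        model_pred = model_pred * (-sigmas) + noisy_model_input
        weighting = (sigmas**-2.0).float()
    # "raw" (and the default): return the prediction unchanged, no weighting
    return model_pred, weighting
```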


def test_retrieve_timesteps():
    # Create a mock scheduler
    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)

    # Test with num_inference_steps
    timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=50)
    assert len(timesteps) == 50
    assert n_steps == 50

    # Test error handling when passing timesteps and sigmas simultaneously
    with pytest.raises(ValueError):
        retrieve_timesteps(scheduler, timesteps=[1, 2, 3], sigmas=[0.1, 0.2, 0.3])


def test_get_noisy_model_input_and_timesteps():
    # Create mock args and setup
    class MockArgs:
        timestep_sampling = "uniform"
        weighting_scheme = "sigma_sqrt"
        sigmoid_scale = 1.0
        discrete_flow_shift = 6.0

    args = MockArgs()
    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
    device = torch.device("cpu")

    # Prepare mock latents and noise
    latents = torch.randn(4, 16, 64, 64)
    noise = torch.randn_like(latents)

    # Test uniform sampling
    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
        args, scheduler, latents, noise, device, torch.float32
    )

    # Validate output shapes and types
    assert noisy_input.shape == latents.shape
    assert timesteps.shape[0] == latents.shape[0]
    assert noisy_input.dtype == torch.float32
    assert timesteps.dtype == torch.float32

    # Test the other sampling methods
    sampling_methods = ["sigmoid", "shift", "nextdit_shift"]
    for method in sampling_methods:
        args.timestep_sampling = method
        noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(
            args, scheduler, latents, noise, device, torch.float32
        )
        assert noisy_input.shape == latents.shape
        assert timesteps.shape[0] == latents.shape[0]
@@ -1,112 +0,0 @@
import torch

from library import lumina_util


def test_unpack_latents():
    # Create a test tensor
    # Shape: [batch, height*width, channels*patch_height*patch_width]
    x = torch.randn(2, 4, 16)  # 2 batches, 4 tokens, 16 channels
    packed_latent_height = 2
    packed_latent_width = 2

    # Unpack the latents
    unpacked = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width)

    # Check output shape
    # Expected shape: [batch, channels, height*patch_height, width*patch_width]
    assert unpacked.shape == (2, 4, 4, 4)


def test_pack_latents():
    # Create a test tensor
    # Shape: [batch, channels, height*patch_height, width*patch_width]
    x = torch.randn(2, 4, 4, 4)

    # Pack the latents
    packed = lumina_util.pack_latents(x)

    # Check output shape
    # Expected shape: [batch, height*width, channels*patch_height*patch_width]
    assert packed.shape == (2, 4, 16)
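Packing and unpacking are inverse reshapes between the spatial latent layout and the 2x2-patched token layout. A sketch assuming patch size 2; the real `lumina_util` functions may take different arguments:

```python
import torch

def pack_latents(x: torch.Tensor, p: int = 2) -> torch.Tensor:
    # [B, C, H*p, W*p] -> [B, H*W, C*p*p]
    b, c, hp, wp = x.shape
    h, w = hp // p, wp // p
    x = x.view(b, c, h, p, w, p).permute(0, 2, 4, 1, 3, 5)
    return x.reshape(b, h * w, c * p * p)

def unpack_latents(x: torch.Tensor, h: int, w: int, p: int = 2) -> torch.Tensor:
    # [B, H*W, C*p*p] -> [B, C, H*p, W*p]
    b = x.shape[0]
    c = x.shape[2] // (p * p)
    x = x.view(b, h, w, c, p, p).permute(0, 3, 1, 4, 2, 5)
    return x.reshape(b, c, h * p, w * p)
```

Round-tripping `pack_latents` then `unpack_latents` with matching `h` and `w` recovers the original tensor exactly, since only element order changes.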


def test_convert_diffusers_sd_to_alpha_vllm():
    num_double_blocks = 2
    # Predefined test cases based on the actual conversion map
    test_cases = [
        # Static key conversions with possible list mappings
        {
            "original_keys": ["time_caption_embed.caption_embedder.0.weight"],
            "original_pattern": ["time_caption_embed.caption_embedder.0.weight"],
            "expected_converted_keys": ["cap_embedder.0.weight"],
        },
        {
            "original_keys": ["patch_embedder.proj.weight"],
            "original_pattern": ["patch_embedder.proj.weight"],
            "expected_converted_keys": ["x_embedder.weight"],
        },
        {
            "original_keys": ["transformer_blocks.0.norm1.weight"],
            "original_pattern": ["transformer_blocks.().norm1.weight"],
            "expected_converted_keys": ["layers.0.attention_norm1.weight"],
        },
    ]

    for test_case in test_cases:
        for original_key, original_pattern, expected_converted_key in zip(
            test_case["original_keys"], test_case["original_pattern"], test_case["expected_converted_keys"]
        ):
            # Create a test state dict
            test_sd = {original_key: torch.randn(10, 10)}

            # Convert the state dict
            converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks)

            # Verify the conversion and that the tensor is preserved
            match_found = expected_converted_key in converted_sd
            assert match_found, f"Failed to convert {original_key}"
            assert torch.allclose(converted_sd[expected_converted_key], test_sd[original_key], atol=1e-6), (
                f"Tensor mismatch for {original_key}"
            )

            # Ensure the original key is also present
            assert original_key in converted_sd

    # Test with block-specific keys
    block_specific_cases = [
        {
            "original_pattern": "transformer_blocks.().norm1.weight",
            "converted_pattern": "layers.().attention_norm1.weight",
        }
    ]

    for case in block_specific_cases:
        for block_idx in range(2):  # Test multiple block indices
            # Prepare block-specific keys
            block_original_key = case["original_pattern"].replace("()", str(block_idx))
            block_converted_key = case["converted_pattern"].replace("()", str(block_idx))
            print(block_original_key, block_converted_key)

            # Create a test state dict
            test_sd = {block_original_key: torch.randn(10, 10)}

            # Convert the state dict
            converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks)

            # Verify the conversion
            assert block_converted_key in converted_sd, f"Failed to convert block key {block_original_key}"
            assert torch.allclose(converted_sd[block_converted_key], test_sd[block_original_key], atol=1e-6), (
                f"Tensor mismatch for block key {block_original_key}"
            )

            # Ensure the original key is also present
            assert block_original_key in converted_sd
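The `()` placeholder in the patterns above stands for a block index, so expanding a pattern into concrete state-dict keys is just a string substitution per block. A tiny illustrative helper; the name `expand_block_pattern` is made up for this sketch:

```python
def expand_block_pattern(pattern: str, num_blocks: int) -> list[str]:
    # "transformer_blocks.().norm1.weight" -> one concrete key per block index
    return [pattern.replace("()", str(i)) for i in range(num_blocks)]
```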
@@ -1,241 +0,0 @@
import os
import tempfile
import torch
import numpy as np
from unittest.mock import patch
from transformers import Gemma2Model

from library.strategy_lumina import (
    LuminaTokenizeStrategy,
    LuminaTextEncodingStrategy,
    LuminaTextEncoderOutputsCachingStrategy,
    LuminaLatentsCachingStrategy,
)


class SimpleMockGemma2Model:
    """Lightweight mock that avoids initializing the actual Gemma2Model"""

    def __init__(self, hidden_size=2304):
        self.device = torch.device("cpu")
        self._hidden_size = hidden_size
        self._orig_mod = self  # For dynamic compilation compatibility

    def __call__(self, input_ids, attention_mask, output_hidden_states=False, return_dict=False):
        # Create a mock output object with hidden states
        batch_size, seq_len = input_ids.shape
        hidden_size = self._hidden_size

        class MockOutput:
            def __init__(self, hidden_states):
                self.hidden_states = hidden_states

        mock_hidden_states = [
            torch.randn(batch_size, seq_len, hidden_size, device=input_ids.device)
            for _ in range(3)  # Mimic multiple layers of hidden states
        ]

        return MockOutput(mock_hidden_states)


def test_lumina_tokenize_strategy():
    # Test default initialization
    try:
        tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
    except OSError as e:
        # If the tokenizer is not found (e.g. a gated repo), skip the test
        print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
        return
    assert tokenize_strategy.max_length == 256
    assert tokenize_strategy.tokenizer.padding_side == "right"

    # Test tokenization of a single string
    text = "Hello"
    tokens, attention_mask = tokenize_strategy.tokenize(text)

    assert tokens.ndim == 2
    assert attention_mask.ndim == 2
    assert tokens.shape == attention_mask.shape
    assert tokens.shape[1] == 256  # max_length

    # Test tokenize_with_weights
    tokens, attention_mask, weights = tokenize_strategy.tokenize_with_weights(text)
    assert len(weights) == 1
    assert torch.all(weights[0] == 1)


def test_lumina_text_encoding_strategy():
    # Create strategies
    try:
        tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
    except OSError as e:
        # If the tokenizer is not found (e.g. a gated repo), skip the test
        print(f"Skipping LuminaTextEncodingStrategy test due to OSError: {e}")
        return
    encoding_strategy = LuminaTextEncodingStrategy()

    # Create a mock model
    mock_model = SimpleMockGemma2Model()

    # Patch the isinstance check to accept our simple mock
    original_isinstance = isinstance
    with patch("library.strategy_lumina.isinstance") as mock_isinstance:

        def custom_isinstance(obj, class_or_tuple):
            if obj is mock_model and class_or_tuple is Gemma2Model:
                return True
            if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model:
                return True
            return original_isinstance(obj, class_or_tuple)

        mock_isinstance.side_effect = custom_isinstance

        # Prepare sample text
        text = "Test encoding strategy"
        tokens, attention_mask = tokenize_strategy.tokenize(text)

        # Perform encoding
        hidden_states, input_ids, attention_masks = encoding_strategy.encode_tokens(
            tokenize_strategy, [mock_model], (tokens, attention_mask)
        )

        # Validate outputs
        assert original_isinstance(hidden_states, torch.Tensor)
        assert original_isinstance(input_ids, torch.Tensor)
        assert original_isinstance(attention_masks, torch.Tensor)

        # Check the shape of the second-to-last hidden state
        assert hidden_states.ndim == 3

        # Test weighted encoding (which falls back to standard encoding for Lumina)
        weights = [torch.ones_like(tokens)]
        hidden_states_w, input_ids_w, attention_masks_w = encoding_strategy.encode_tokens_with_weights(
            tokenize_strategy, [mock_model], (tokens, attention_mask), weights
        )

        # For the mock, we can't guarantee identical outputs since each call returns random tensors.
        # Instead, check that the outputs have the same shape and are tensors.
        assert hidden_states_w.shape == hidden_states.shape
        assert original_isinstance(hidden_states_w, torch.Tensor)
        assert torch.allclose(input_ids, input_ids_w)  # Input IDs should be the same
        assert torch.allclose(attention_masks, attention_masks_w)  # Attention masks should be the same


def test_lumina_text_encoder_outputs_caching_strategy():
    # Create a temporary directory for caching
    with tempfile.TemporaryDirectory() as tmpdir:
        # Create a cache file path
        cache_file = os.path.join(tmpdir, "test_outputs.npz")

        # Create the caching strategy
        caching_strategy = LuminaTextEncoderOutputsCachingStrategy(
            cache_to_disk=True,
            batch_size=1,
            skip_disk_cache_validity_check=False,
        )

        # Create a mock class for ImageInfo
        class MockImageInfo:
            def __init__(self, caption, cache_path):
                self.caption = caption
                self.text_encoder_outputs_npz = cache_path

        # Create a sample input info
        image_info = MockImageInfo("Test caption", cache_file)

        # Simulate a batch
        batch = [image_info]

        # Create mock strategies and model
        try:
            tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
        except OSError as e:
            # If the tokenizer is not found (e.g. a gated repo), skip the test
            print(f"Skipping LuminaTextEncoderOutputsCachingStrategy test due to OSError: {e}")
            return
        encoding_strategy = LuminaTextEncodingStrategy()
        mock_model = SimpleMockGemma2Model()

        # Patch the isinstance check to accept our simple mock
        original_isinstance = isinstance
        with patch("library.strategy_lumina.isinstance") as mock_isinstance:

            def custom_isinstance(obj, class_or_tuple):
                if obj is mock_model and class_or_tuple is Gemma2Model:
                    return True
                if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model:
                    return True
                return original_isinstance(obj, class_or_tuple)

            mock_isinstance.side_effect = custom_isinstance

            # Call cache_batch_outputs
            caching_strategy.cache_batch_outputs(tokenize_strategy, [mock_model], encoding_strategy, batch)

            # Verify the npz file was created
            assert os.path.exists(cache_file), f"Cache file not created at {cache_file}"

            # Verify the is_disk_cached_outputs_expected method
            assert caching_strategy.is_disk_cached_outputs_expected(cache_file)

            # Test loading from npz
            loaded_data = caching_strategy.load_outputs_npz(cache_file)
            assert len(loaded_data) == 3  # hidden_state, input_ids, attention_mask


def test_lumina_latents_caching_strategy():
    # Create a temporary directory for caching
    with tempfile.TemporaryDirectory() as tmpdir:
        # Prepare a mock absolute path
        abs_path = os.path.join(tmpdir, "test_image.png")

        # Use a smaller image size for faster testing
        image_size = (64, 64)

        # Create a small dummy image for testing
        test_image = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)

        # Create the caching strategy
        caching_strategy = LuminaLatentsCachingStrategy(cache_to_disk=True, batch_size=1, skip_disk_cache_validity_check=False)

        # Create a simple mock VAE
        class MockVAE:
            def __init__(self):
                self.device = torch.device("cpu")
                self.dtype = torch.float32

            def encode(self, x):
                # Return a smaller encoded tensor for faster processing
                encoded = torch.randn(1, 4, 8, 8, device=x.device)
                return type("EncodedLatents", (), {"to": lambda *args, **kwargs: encoded})

        # Prepare a mock batch
        class MockImageInfo:
            def __init__(self, path, image):
                self.absolute_path = path
                self.image = image
                self.image_path = path
                self.bucket_reso = image_size
                self.resized_size = image_size
                self.resize_interpolation = "lanczos"
                # Specify the full path to the latents npz file
                self.latents_npz = os.path.join(tmpdir, f"{os.path.splitext(os.path.basename(path))[0]}_0064x0064_lumina.npz")

        batch = [MockImageInfo(abs_path, test_image)]

        # Call cache_batch_latents
        mock_vae = MockVAE()
        caching_strategy.cache_batch_latents(mock_vae, batch, flip_aug=False, alpha_mask=False, random_crop=False)

        # Generate the expected npz path
        npz_path = caching_strategy.get_latents_npz_path(abs_path, image_size)

        # Verify the file was created
        assert os.path.exists(npz_path), f"NPZ file not created at {npz_path}"

        # Verify is_disk_cached_latents_expected
        assert caching_strategy.is_disk_cached_latents_expected(image_size, npz_path, False, False)

        # Test loading from disk
        loaded_data = caching_strategy.load_latents_from_disk(npz_path, image_size)
        assert len(loaded_data) == 5  # Check for 5 expected elements
@@ -1,408 +0,0 @@
import pytest
import torch
import torch.nn as nn
from unittest.mock import patch, MagicMock

from library.custom_offloading_utils import (
    synchronize_device,
    swap_weight_devices_cuda,
    swap_weight_devices_no_cuda,
    weighs_to_device,
    Offloader,
    ModelOffloader,
)


class TransformerBlock(nn.Module):
    def __init__(self, block_idx: int):
        super().__init__()
        self.block_idx = block_idx
        self.linear1 = nn.Linear(10, 5)
        self.linear2 = nn.Linear(5, 10)
        self.seq = nn.Sequential(nn.SiLU(), nn.Linear(10, 10))

    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        x = self.seq(x)
        return x


class SimpleModel(nn.Module):
    def __init__(self, num_blocks=16):
        super().__init__()
        self.blocks = nn.ModuleList([TransformerBlock(i) for i in range(num_blocks)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

    @property
    def device(self):
        return next(self.parameters()).device


# Device Synchronization Tests
@patch("torch.cuda.synchronize")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_synchronize(mock_cuda_sync):
    device = torch.device("cuda")
    synchronize_device(device)
    mock_cuda_sync.assert_called_once()


@patch("torch.xpu.synchronize")
@pytest.mark.skipif(not torch.xpu.is_available(), reason="XPU not available")
def test_xpu_synchronize(mock_xpu_sync):
    device = torch.device("xpu")
    synchronize_device(device)
    mock_xpu_sync.assert_called_once()


@patch("torch.mps.synchronize")
@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
def test_mps_synchronize(mock_mps_sync):
    device = torch.device("mps")
    synchronize_device(device)
    mock_mps_sync.assert_called_once()


# Weights to Device Tests
def test_weights_to_device():
    # Create a simple model with weights
    model = nn.Sequential(
        nn.Linear(10, 5),
        nn.ReLU(),
        nn.Linear(5, 2),
    )

    # Start with CPU tensors
|
||||
device = torch.device('cpu')
|
||||
for module in model.modules():
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
assert module.weight.device == device
|
||||
|
||||
# Move to mock CUDA device
|
||||
mock_device = torch.device('cuda')
|
||||
with patch('torch.Tensor.to', return_value=torch.zeros(1).to(device)):
|
||||
weighs_to_device(model, mock_device)
|
||||
|
||||
# Since we mocked the to() function, we can only verify modules were processed
|
||||
# but can't check actual device movement
|
||||
|
||||
|
||||
# Swap Weight Devices Tests
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_swap_weight_devices_cuda():
|
||||
device = torch.device('cuda')
|
||||
layer_to_cpu = SimpleModel()
|
||||
layer_to_cuda = SimpleModel()
|
||||
|
||||
# Move layer to CUDA to move to CPU
|
||||
layer_to_cpu.to(device)
|
||||
|
||||
with patch('torch.Tensor.to', return_value=torch.zeros(1)):
|
||||
with patch('torch.Tensor.copy_'):
|
||||
swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda)
|
||||
|
||||
assert layer_to_cpu.device.type == 'cpu'
|
||||
assert layer_to_cuda.device.type == 'cuda'
|
||||
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.synchronize_device')
|
||||
def test_swap_weight_devices_no_cuda(mock_sync_device):
|
||||
device = torch.device('cpu')
|
||||
layer_to_cpu = SimpleModel()
|
||||
layer_to_cuda = SimpleModel()
|
||||
|
||||
with patch('torch.Tensor.to', return_value=torch.zeros(1)):
|
||||
with patch('torch.Tensor.copy_'):
|
||||
swap_weight_devices_no_cuda(device, layer_to_cpu, layer_to_cuda)
|
||||
|
||||
# Verify synchronize_device was called twice
|
||||
assert mock_sync_device.call_count == 2
|
||||
|
||||
|
||||
# Offloader Tests
|
||||
@pytest.fixture
|
||||
def offloader():
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
return Offloader(
|
||||
num_blocks=4,
|
||||
blocks_to_swap=2,
|
||||
device=device,
|
||||
debug=False
|
||||
)
|
||||
|
||||
|
||||
def test_offloader_init(offloader):
|
||||
assert offloader.num_blocks == 4
|
||||
assert offloader.blocks_to_swap == 2
|
||||
assert hasattr(offloader, 'thread_pool')
|
||||
assert offloader.futures == {}
|
||||
assert offloader.cuda_available == (offloader.device.type == 'cuda')
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.swap_weight_devices_cuda')
|
||||
@patch('library.custom_offloading_utils.swap_weight_devices_no_cuda')
|
||||
def test_swap_weight_devices(mock_no_cuda, mock_cuda, offloader: Offloader):
|
||||
block_to_cpu = SimpleModel()
|
||||
block_to_cuda = SimpleModel()
|
||||
|
||||
# Force test for CUDA device
|
||||
offloader.cuda_available = True
|
||||
offloader.swap_weight_devices(block_to_cpu, block_to_cuda)
|
||||
mock_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda)
|
||||
mock_no_cuda.assert_not_called()
|
||||
|
||||
# Reset mocks
|
||||
mock_cuda.reset_mock()
|
||||
mock_no_cuda.reset_mock()
|
||||
|
||||
# Force test for non-CUDA device
|
||||
offloader.cuda_available = False
|
||||
offloader.swap_weight_devices(block_to_cpu, block_to_cuda)
|
||||
mock_no_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda)
|
||||
mock_cuda.assert_not_called()
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.Offloader.swap_weight_devices')
|
||||
def test_submit_move_blocks(mock_swap, offloader):
|
||||
blocks = [SimpleModel() for _ in range(4)]
|
||||
block_idx_to_cpu = 0
|
||||
block_idx_to_cuda = 2
|
||||
|
||||
# Mock the thread pool to execute synchronously
|
||||
future = MagicMock()
|
||||
future.result.return_value = (block_idx_to_cpu, block_idx_to_cuda)
|
||||
offloader.thread_pool.submit = MagicMock(return_value=future)
|
||||
|
||||
offloader._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
|
||||
|
||||
# Check that the future is stored with the correct key
|
||||
assert block_idx_to_cuda in offloader.futures
|
||||
|
||||
|
||||
def test_wait_blocks_move(offloader):
|
||||
block_idx = 2
|
||||
|
||||
# Test with no future for the block
|
||||
offloader._wait_blocks_move(block_idx) # Should not raise
|
||||
|
||||
# Create a fake future and test waiting
|
||||
future = MagicMock()
|
||||
future.result.return_value = (0, block_idx)
|
||||
offloader.futures[block_idx] = future
|
||||
|
||||
offloader._wait_blocks_move(block_idx)
|
||||
|
||||
# Check that the future was removed
|
||||
assert block_idx not in offloader.futures
|
||||
future.result.assert_called_once()
|
||||
|
||||
|
||||
# ModelOffloader Tests
|
||||
@pytest.fixture
|
||||
def model_offloader():
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
blocks_to_swap = 2
|
||||
blocks = SimpleModel(4).blocks
|
||||
return ModelOffloader(
|
||||
blocks=blocks,
|
||||
blocks_to_swap=blocks_to_swap,
|
||||
device=device,
|
||||
debug=False
|
||||
)
|
||||
|
||||
|
||||
def test_model_offloader_init(model_offloader):
|
||||
assert model_offloader.num_blocks == 4
|
||||
assert model_offloader.blocks_to_swap == 2
|
||||
assert hasattr(model_offloader, 'thread_pool')
|
||||
assert model_offloader.futures == {}
|
||||
assert len(model_offloader.remove_handles) > 0 # Should have registered hooks
|
||||
|
||||
|
||||
def test_create_backward_hook():
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
blocks_to_swap = 2
|
||||
blocks = SimpleModel(4).blocks
|
||||
model_offloader = ModelOffloader(
|
||||
blocks=blocks,
|
||||
blocks_to_swap=blocks_to_swap,
|
||||
device=device,
|
||||
debug=False
|
||||
)
|
||||
|
||||
# Test hook creation for swapping case (block 0)
|
||||
hook_swap = model_offloader.create_backward_hook(blocks, 0)
|
||||
assert hook_swap is None
|
||||
|
||||
# Test hook creation for waiting case (block 1)
|
||||
hook_wait = model_offloader.create_backward_hook(blocks, 1)
|
||||
assert hook_wait is not None
|
||||
|
||||
# Test hook creation for no action case (block 3)
|
||||
hook_none = model_offloader.create_backward_hook(blocks, 3)
|
||||
assert hook_none is None
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks')
|
||||
@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move')
|
||||
def test_backward_hook_execution(mock_wait, mock_submit):
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
blocks_to_swap = 2
|
||||
model = SimpleModel(4)
|
||||
blocks = model.blocks
|
||||
model_offloader = ModelOffloader(
|
||||
blocks=blocks,
|
||||
blocks_to_swap=blocks_to_swap,
|
||||
device=device,
|
||||
debug=False
|
||||
)
|
||||
|
||||
# Test swapping hook (block 1)
|
||||
hook_swap = model_offloader.create_backward_hook(blocks, 1)
|
||||
assert hook_swap is not None
|
||||
hook_swap(model, torch.zeros(1), torch.zeros(1))
|
||||
mock_submit.assert_called_once()
|
||||
|
||||
mock_submit.reset_mock()
|
||||
|
||||
# Test waiting hook (block 2)
|
||||
hook_wait = model_offloader.create_backward_hook(blocks, 2)
|
||||
assert hook_wait is not None
|
||||
hook_wait(model, torch.zeros(1), torch.zeros(1))
|
||||
assert mock_wait.call_count == 2
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.weighs_to_device')
|
||||
@patch('library.custom_offloading_utils.synchronize_device')
|
||||
@patch('library.custom_offloading_utils.clean_memory_on_device')
|
||||
def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weights_to_device, model_offloader):
|
||||
model = SimpleModel(4)
|
||||
blocks = model.blocks
|
||||
|
||||
with patch.object(nn.Module, 'to'):
|
||||
model_offloader.prepare_block_devices_before_forward(blocks)
|
||||
|
||||
# Check that weighs_to_device was called for each block
|
||||
assert mock_weights_to_device.call_count == 4
|
||||
|
||||
# Check that synchronize_device and clean_memory_on_device were called
|
||||
mock_sync.assert_called_once_with(model_offloader.device)
|
||||
mock_clean.assert_called_once_with(model_offloader.device)
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move')
|
||||
def test_wait_for_block(mock_wait, model_offloader):
|
||||
# Test with blocks_to_swap=0
|
||||
model_offloader.blocks_to_swap = 0
|
||||
model_offloader.wait_for_block(1)
|
||||
mock_wait.assert_not_called()
|
||||
|
||||
# Test with blocks_to_swap=2
|
||||
model_offloader.blocks_to_swap = 2
|
||||
block_idx = 1
|
||||
model_offloader.wait_for_block(block_idx)
|
||||
mock_wait.assert_called_once_with(block_idx)
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks')
|
||||
def test_submit_move_blocks(mock_submit, model_offloader):
|
||||
model = SimpleModel()
|
||||
blocks = model.blocks
|
||||
|
||||
# Test with blocks_to_swap=0
|
||||
model_offloader.blocks_to_swap = 0
|
||||
model_offloader.submit_move_blocks(blocks, 1)
|
||||
mock_submit.assert_not_called()
|
||||
|
||||
mock_submit.reset_mock()
|
||||
model_offloader.blocks_to_swap = 2
|
||||
|
||||
# Test within swap range
|
||||
block_idx = 1
|
||||
model_offloader.submit_move_blocks(blocks, block_idx)
|
||||
mock_submit.assert_called_once()
|
||||
|
||||
mock_submit.reset_mock()
|
||||
|
||||
# Test outside swap range
|
||||
block_idx = 3
|
||||
model_offloader.submit_move_blocks(blocks, block_idx)
|
||||
mock_submit.assert_not_called()
|
||||
|
||||
|
||||
# Integration test for offloading in a realistic scenario
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_offloading_integration():
|
||||
device = torch.device('cuda')
|
||||
# Create a mini model with 4 blocks
|
||||
model = SimpleModel(5)
|
||||
model.to(device)
|
||||
blocks = model.blocks
|
||||
|
||||
# Initialize model offloader
|
||||
offloader = ModelOffloader(
|
||||
blocks=blocks,
|
||||
blocks_to_swap=2,
|
||||
device=device,
|
||||
debug=True
|
||||
)
|
||||
|
||||
# Prepare blocks for forward pass
|
||||
offloader.prepare_block_devices_before_forward(blocks)
|
||||
|
||||
# Simulate forward pass with offloading
|
||||
input_tensor = torch.randn(1, 10, device=device)
|
||||
x = input_tensor
|
||||
|
||||
for i, block in enumerate(blocks):
|
||||
# Wait for the current block to be ready
|
||||
offloader.wait_for_block(i)
|
||||
|
||||
# Process through the block
|
||||
x = block(x)
|
||||
|
||||
# Schedule moving weights for future blocks
|
||||
offloader.submit_move_blocks(blocks, i)
|
||||
|
||||
# Verify we get a valid output
|
||||
assert x.shape == (1, 10)
|
||||
assert not torch.isnan(x).any()
|
||||
|
||||
|
||||
# Error handling tests
|
||||
def test_offloader_assertion_error():
|
||||
with pytest.raises(AssertionError):
|
||||
device = torch.device('cpu')
|
||||
layer_to_cpu = SimpleModel()
|
||||
layer_to_cuda = nn.Linear(10, 5) # Different class
|
||||
swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run all tests when file is executed directly
|
||||
import sys
|
||||
|
||||
# Configure pytest command line arguments
|
||||
pytest_args = [
|
||||
"-v", # Verbose output
|
||||
"--color=yes", # Colored output
|
||||
__file__, # Run tests in this file
|
||||
]
|
||||
|
||||
# Add optional arguments from command line
|
||||
if len(sys.argv) > 1:
|
||||
pytest_args.extend(sys.argv[1:])
|
||||
|
||||
# Print info about test execution
|
||||
print(f"Running tests with PyTorch {torch.__version__}")
|
||||
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||
if torch.cuda.is_available():
|
||||
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
|
||||
|
||||
# Run the tests
|
||||
sys.exit(pytest.main(pytest_args))
|
||||
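The scheduling contract these tests exercise — wait for a block's weights to arrive before running it, then kick off the next transfer on a background thread — can be sketched without the library. Everything here (`TinyOffloader`, the single-worker pool, the in-place `p.data` moves) is an illustrative stand-in, not the `Offloader` implementation:

```python
from concurrent.futures import ThreadPoolExecutor

import torch
import torch.nn as nn


def move_block(block: nn.Module, device: torch.device) -> None:
    # Move every parameter of the block to the target device in place.
    for p in block.parameters():
        p.data = p.data.to(device)


class TinyOffloader:
    """Prefetch the next block on a worker thread while the current one runs."""

    def __init__(self, blocks, device):
        self.blocks = blocks
        self.device = device
        self.pool = ThreadPoolExecutor(max_workers=1)
        self.futures = {}

    def prefetch(self, idx):
        if 0 <= idx < len(self.blocks):
            self.futures[idx] = self.pool.submit(move_block, self.blocks[idx], self.device)

    def wait(self, idx):
        fut = self.futures.pop(idx, None)
        if fut is not None:
            fut.result()


blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
off = TinyOffloader(blocks, torch.device("cpu"))

x = torch.randn(1, 8)
off.prefetch(0)
for i, block in enumerate(blocks):
    off.wait(i)          # ensure this block's weights have arrived
    off.prefetch(i + 1)  # start moving the next block in the background
    x = block(x)

print(x.shape)  # torch.Size([1, 8])
```

The real `ModelOffloader` additionally swaps evicted blocks back to CPU and drives the reverse order via backward hooks; this sketch only shows the wait/submit handshake around each block.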
@@ -1,6 +0,0 @@
import fine_tune


def test_syntax():
    # Very simply testing that fine_tune imports without syntax errors
    assert True
@@ -1,6 +0,0 @@
import flux_train


def test_syntax():
    # Very simply testing that flux_train imports without syntax errors
    assert True
@@ -1,5 +0,0 @@
import flux_train_network

def test_syntax():
    # Very simply testing that flux_train_network imports without syntax errors
    assert True
@@ -1,177 +0,0 @@
import pytest
import torch
from unittest.mock import MagicMock, patch
import argparse

from library import lumina_models, lumina_util
from lumina_train_network import LuminaNetworkTrainer


@pytest.fixture
def lumina_trainer():
    return LuminaNetworkTrainer()


@pytest.fixture
def mock_args():
    args = MagicMock()
    args.pretrained_model_name_or_path = "test_path"
    args.disable_mmap_load_safetensors = False
    args.use_flash_attn = False
    args.use_sage_attn = False
    args.fp8_base = False
    args.blocks_to_swap = None
    args.gemma2 = "test_gemma2_path"
    args.ae = "test_ae_path"
    args.cache_text_encoder_outputs = True
    args.cache_text_encoder_outputs_to_disk = False
    args.network_train_unet_only = False
    return args


@pytest.fixture
def mock_accelerator():
    accelerator = MagicMock()
    accelerator.device = torch.device("cpu")
    accelerator.prepare.side_effect = lambda x, **kwargs: x
    accelerator.unwrap_model.side_effect = lambda x: x
    return accelerator


def test_assert_extra_args(lumina_trainer, mock_args):
    train_dataset_group = MagicMock()
    train_dataset_group.verify_bucket_reso_steps = MagicMock()
    val_dataset_group = MagicMock()
    val_dataset_group.verify_bucket_reso_steps = MagicMock()

    # Test with default settings
    lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group)

    # Verify verify_bucket_reso_steps was called for both groups
    assert train_dataset_group.verify_bucket_reso_steps.call_count > 0
    assert val_dataset_group.verify_bucket_reso_steps.call_count > 0

    # Check text encoder output caching
    assert lumina_trainer.train_gemma2 is (not mock_args.network_train_unet_only)
    assert mock_args.cache_text_encoder_outputs is True


def test_load_target_model(lumina_trainer, mock_args, mock_accelerator):
    # Patch lumina_util methods
    with (
        patch("library.lumina_util.load_lumina_model") as mock_load_lumina_model,
        patch("library.lumina_util.load_gemma2") as mock_load_gemma2,
        patch("library.lumina_util.load_ae") as mock_load_ae,
    ):
        # Create mock models
        mock_model = MagicMock(spec=lumina_models.NextDiT)
        mock_model.dtype = torch.float32
        mock_gemma2 = MagicMock()
        mock_ae = MagicMock()

        mock_load_lumina_model.return_value = mock_model
        mock_load_gemma2.return_value = mock_gemma2
        mock_load_ae.return_value = mock_ae

        # Test load_target_model
        version, gemma2_list, ae, model = lumina_trainer.load_target_model(mock_args, torch.float32, mock_accelerator)

        # Verify calls and return values
        assert version == lumina_util.MODEL_VERSION_LUMINA_V2
        assert gemma2_list == [mock_gemma2]
        assert ae == mock_ae
        assert model == mock_model

        # Verify load calls
        mock_load_lumina_model.assert_called_once()
        mock_load_gemma2.assert_called_once()
        mock_load_ae.assert_called_once()


def test_get_strategies(lumina_trainer, mock_args):
    # Test tokenize strategy
    try:
        tokenize_strategy = lumina_trainer.get_tokenize_strategy(mock_args)
        assert tokenize_strategy.__class__.__name__ == "LuminaTokenizeStrategy"
    except OSError as e:
        # If the tokenizer is not found (due to gated repo), we can skip the test
        print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")

    # Test latents caching strategy
    latents_strategy = lumina_trainer.get_latents_caching_strategy(mock_args)
    assert latents_strategy.__class__.__name__ == "LuminaLatentsCachingStrategy"

    # Test text encoding strategy
    text_encoding_strategy = lumina_trainer.get_text_encoding_strategy(mock_args)
    assert text_encoding_strategy.__class__.__name__ == "LuminaTextEncodingStrategy"


def test_text_encoder_output_caching_strategy(lumina_trainer, mock_args):
    # Call assert_extra_args to set train_gemma2
    train_dataset_group = MagicMock()
    train_dataset_group.verify_bucket_reso_steps = MagicMock()
    val_dataset_group = MagicMock()
    val_dataset_group.verify_bucket_reso_steps = MagicMock()
    lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group)

    # With text encoder caching enabled
    mock_args.skip_cache_check = False
    mock_args.text_encoder_batch_size = 16
    strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args)

    assert strategy.__class__.__name__ == "LuminaTextEncoderOutputsCachingStrategy"
    assert strategy.cache_to_disk is False  # based on mock_args

    # With text encoder caching disabled
    mock_args.cache_text_encoder_outputs = False
    strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args)
    assert strategy is None


def test_noise_scheduler(lumina_trainer, mock_args):
    device = torch.device("cpu")
    noise_scheduler = lumina_trainer.get_noise_scheduler(mock_args, device)

    assert noise_scheduler.__class__.__name__ == "FlowMatchEulerDiscreteScheduler"
    assert noise_scheduler.num_train_timesteps == 1000
    assert hasattr(lumina_trainer, "noise_scheduler_copy")


def test_sai_model_spec(lumina_trainer, mock_args):
    with patch("library.train_util.get_sai_model_spec") as mock_get_spec:
        mock_get_spec.return_value = "test_spec"
        spec = lumina_trainer.get_sai_model_spec(mock_args)
        assert spec == "test_spec"
        mock_get_spec.assert_called_once_with(None, mock_args, False, True, False, lumina="lumina2")


def test_update_metadata(lumina_trainer, mock_args):
    metadata = {}
    lumina_trainer.update_metadata(metadata, mock_args)

    assert "ss_weighting_scheme" in metadata
    assert "ss_logit_mean" in metadata
    assert "ss_logit_std" in metadata
    assert "ss_mode_scale" in metadata
    assert "ss_timestep_sampling" in metadata
    assert "ss_sigmoid_scale" in metadata
    assert "ss_model_prediction_type" in metadata
    assert "ss_discrete_flow_shift" in metadata


def test_is_text_encoder_not_needed_for_training(lumina_trainer, mock_args):
    # Test with text encoder output caching, but not training the text encoder
    mock_args.cache_text_encoder_outputs = True
    with patch.object(lumina_trainer, "is_train_text_encoder", return_value=False):
        result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
        assert result is True

    # Test with text encoder output caching and training the text encoder
    with patch.object(lumina_trainer, "is_train_text_encoder", return_value=True):
        result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
        assert result is False

    # Test with no text encoder output caching
    mock_args.cache_text_encoder_outputs = False
    result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
    assert result is False
@@ -1,6 +0,0 @@
import sd3_train


def test_syntax():
    # Very simply testing that sd3_train imports without syntax errors
    assert True
@@ -1,5 +0,0 @@
import sd3_train_network

def test_syntax():
    # Very simply testing that sd3_train_network imports without syntax errors
    assert True
@@ -1,6 +0,0 @@
import sdxl_train


def test_syntax():
    # Very simply testing that sdxl_train imports without syntax errors
    assert True
@@ -1,6 +0,0 @@
import sdxl_train_network


def test_syntax():
    # Very simply testing that sdxl_train_network imports without syntax errors
    assert True
@@ -1,6 +0,0 @@
import train_db


def test_syntax():
    # Very simply testing that train_db imports without syntax errors
    assert True
@@ -1,5 +0,0 @@
import train_network

def test_syntax():
    # Very simply testing that train_network imports without syntax errors
    assert True
@@ -1,5 +0,0 @@
import train_textual_inversion

def test_syntax():
    # Very simply testing that train_textual_inversion imports without syntax errors
    assert True
@@ -150,7 +150,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
 
     # cache latents with dataset
     # TODO use DataLoader to speed up
-    train_dataset_group.new_cache_latents(vae, accelerator)
+    train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
 
     accelerator.wait_for_everyone()
     accelerator.print(f"Finished caching latents to disk.")
@@ -15,7 +15,7 @@ import os
 from anime_face_detector import create_detector
 from tqdm import tqdm
 import numpy as np
-from library.utils import setup_logging, resize_image
+from library.utils import setup_logging, pil_resize
 setup_logging()
 import logging
 logger = logging.getLogger(__name__)
@@ -170,9 +170,12 @@ def process(args):
         scale = max(cur_crop_width / w, cur_crop_height / h)
 
         if scale != 1.0:
-            rw = int(w * scale + .5)
-            rh = int(h * scale + .5)
-            face_img = resize_image(face_img, w, h, rw, rh)
+            w = int(w * scale + .5)
+            h = int(h * scale + .5)
+            if scale < 1.0:
+                face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA)
+            else:
+                face_img = pil_resize(face_img, (w, h))
             cx = int(cx * scale + .5)
             cy = int(cy * scale + .5)
             fw = int(fw * scale + .5)
@@ -1,166 +0,0 @@
import argparse
import os
import gc
from typing import Dict, Optional, Union
import torch
from safetensors.torch import safe_open

from library.utils import setup_logging
from library.utils import load_safetensors, mem_eff_save_file, str_to_dtype

setup_logging()
import logging

logger = logging.getLogger(__name__)


def merge_safetensors(
    dit_path: str,
    vae_path: Optional[str] = None,
    clip_l_path: Optional[str] = None,
    clip_g_path: Optional[str] = None,
    t5xxl_path: Optional[str] = None,
    output_path: str = "merged_model.safetensors",
    device: str = "cpu",
    save_precision: Optional[str] = None,
):
    """
    Merge multiple safetensors files into a single file

    Args:
        dit_path: Path to the DiT/MMDiT model
        vae_path: Path to the VAE model
        clip_l_path: Path to the CLIP-L model
        clip_g_path: Path to the CLIP-G model
        t5xxl_path: Path to the T5-XXL model
        output_path: Path to save the merged model
        device: Device to load tensors to
        save_precision: Target dtype for model weights (e.g. 'fp16', 'bf16')
    """
    logger.info("Starting to merge safetensors files...")

    # Convert save_precision string to torch dtype if specified
    if save_precision:
        target_dtype = str_to_dtype(save_precision)
    else:
        target_dtype = None

    # 1. Get DiT metadata if available
    metadata = None
    try:
        with safe_open(dit_path, framework="pt") as f:
            metadata = f.metadata()  # may be None
        if metadata:
            logger.info(f"Found metadata in DiT model: {metadata}")
    except Exception as e:
        logger.warning(f"Failed to read metadata from DiT model: {e}")

    # 2. Create empty merged state dict
    merged_state_dict = {}

    # 3. Load and merge each model with memory management

    # DiT/MMDiT - prefix: model.diffusion_model.
    # This state dict may have VAE keys.
    logger.info(f"Loading DiT model from {dit_path}")
    dit_state_dict = load_safetensors(dit_path, device=device, disable_mmap=True, dtype=target_dtype)
    logger.info(f"Adding DiT model with {len(dit_state_dict)} keys")
    for key, value in dit_state_dict.items():
        if key.startswith("model.diffusion_model.") or key.startswith("first_stage_model."):
            merged_state_dict[key] = value
        else:
            merged_state_dict[f"model.diffusion_model.{key}"] = value
    # Free memory
    del dit_state_dict
    gc.collect()

    # VAE - prefix: first_stage_model.
    # May be omitted if VAE is already included in DiT model.
    if vae_path:
        logger.info(f"Loading VAE model from {vae_path}")
        vae_state_dict = load_safetensors(vae_path, device=device, disable_mmap=True, dtype=target_dtype)
        logger.info(f"Adding VAE model with {len(vae_state_dict)} keys")
        for key, value in vae_state_dict.items():
            if key.startswith("first_stage_model."):
                merged_state_dict[key] = value
            else:
                merged_state_dict[f"first_stage_model.{key}"] = value
        # Free memory
        del vae_state_dict
        gc.collect()

    # CLIP-L - prefix: text_encoders.clip_l.
    if clip_l_path:
        logger.info(f"Loading CLIP-L model from {clip_l_path}")
        clip_l_state_dict = load_safetensors(clip_l_path, device=device, disable_mmap=True, dtype=target_dtype)
        logger.info(f"Adding CLIP-L model with {len(clip_l_state_dict)} keys")
        for key, value in clip_l_state_dict.items():
            if key.startswith("text_encoders.clip_l.transformer."):
                merged_state_dict[key] = value
            else:
                merged_state_dict[f"text_encoders.clip_l.transformer.{key}"] = value
        # Free memory
        del clip_l_state_dict
        gc.collect()

    # CLIP-G - prefix: text_encoders.clip_g.
    if clip_g_path:
        logger.info(f"Loading CLIP-G model from {clip_g_path}")
        clip_g_state_dict = load_safetensors(clip_g_path, device=device, disable_mmap=True, dtype=target_dtype)
        logger.info(f"Adding CLIP-G model with {len(clip_g_state_dict)} keys")
        for key, value in clip_g_state_dict.items():
            if key.startswith("text_encoders.clip_g.transformer."):
                merged_state_dict[key] = value
            else:
                merged_state_dict[f"text_encoders.clip_g.transformer.{key}"] = value
        # Free memory
        del clip_g_state_dict
        gc.collect()

    # T5-XXL - prefix: text_encoders.t5xxl.
    if t5xxl_path:
        logger.info(f"Loading T5-XXL model from {t5xxl_path}")
        t5xxl_state_dict = load_safetensors(t5xxl_path, device=device, disable_mmap=True, dtype=target_dtype)
        logger.info(f"Adding T5-XXL model with {len(t5xxl_state_dict)} keys")
        for key, value in t5xxl_state_dict.items():
            if key.startswith("text_encoders.t5xxl.transformer."):
                merged_state_dict[key] = value
            else:
                merged_state_dict[f"text_encoders.t5xxl.transformer.{key}"] = value
        # Free memory
        del t5xxl_state_dict
        gc.collect()

    # 4. Save merged state dict
    logger.info(f"Saving merged model to {output_path} with {len(merged_state_dict)} keys total")
    mem_eff_save_file(merged_state_dict, output_path, metadata)
    logger.info("Successfully merged safetensors files")


def main():
    parser = argparse.ArgumentParser(description="Merge Stable Diffusion 3.5 model components into a single safetensors file")
    parser.add_argument("--dit", required=True, help="Path to the DiT/MMDiT model")
    parser.add_argument("--vae", help="Path to the VAE model. May be omitted if VAE is included in DiT model")
    parser.add_argument("--clip_l", help="Path to the CLIP-L model")
    parser.add_argument("--clip_g", help="Path to the CLIP-G model")
    parser.add_argument("--t5xxl", help="Path to the T5-XXL model")
    parser.add_argument("--output", default="merged_model.safetensors", help="Path to save the merged model")
    parser.add_argument("--device", default="cpu", help="Device to load tensors to")
    parser.add_argument("--save_precision", type=str, help="Precision to save the model in (e.g., 'fp16', 'bf16', 'float16', etc.)")

    args = parser.parse_args()

    merge_safetensors(
        dit_path=args.dit,
        vae_path=args.vae,
        clip_l_path=args.clip_l,
        clip_g_path=args.clip_g,
        t5xxl_path=args.t5xxl,
        output_path=args.output,
        device=args.device,
        save_precision=args.save_precision,
    )


if __name__ == "__main__":
    main()
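Each component loop in the removed script applies the same rule: keep keys that already carry the component's prefix, and prepend it otherwise. Factored out as a standalone helper (the function name is ours; the prefix strings are the script's):

```python
def with_prefix(state_dict: dict, prefix: str) -> dict:
    """Return a copy whose keys all start with `prefix`.

    Keys that already carry the prefix are kept as-is; others get it prepended.
    """
    return {k if k.startswith(prefix) else f"{prefix}{k}": v for k, v in state_dict.items()}


sd = {"text_model.embeddings.weight": 0, "text_encoders.clip_l.transformer.final.bias": 1}
merged = with_prefix(sd, "text_encoders.clip_l.transformer.")
# Both keys now share the text_encoders.clip_l.transformer. prefix
print(sorted(merged))
```

The script inlines this per component (with different prefixes, plus the DiT special case that also passes `first_stage_model.` keys through) so each state dict can be freed immediately after merging.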
@@ -6,7 +6,7 @@ import shutil
 import math
 from PIL import Image
 import numpy as np
-from library.utils import setup_logging, resize_image
+from library.utils import setup_logging, pil_resize
 setup_logging()
 import logging
 logger = logging.getLogger(__name__)
@@ -22,6 +22,14 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
     if not os.path.exists(dst_img_folder):
         os.makedirs(dst_img_folder)
 
+    # Select interpolation method
+    if interpolation == 'lanczos4':
+        pil_interpolation = Image.LANCZOS
+    elif interpolation == 'cubic':
+        pil_interpolation = Image.BICUBIC
+    else:
+        cv2_interpolation = cv2.INTER_AREA
+
     # Iterate through all files in src_img_folder
     img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp")  # copy from train_util.py
     for filename in os.listdir(src_img_folder):
@@ -55,7 +63,11 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
||||
new_height = int(img.shape[0] * math.sqrt(scale_factor))
|
||||
new_width = int(img.shape[1] * math.sqrt(scale_factor))
|
||||
|
||||
img = resize_image(img, img.shape[0], img.shape[1], new_height, new_width, interpolation)
|
||||
# Resize image
|
||||
if cv2_interpolation:
|
||||
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
|
||||
else:
|
||||
img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation)
|
||||
else:
|
||||
new_height, new_width = img.shape[0:2]
|
||||
|
||||
@@ -101,8 +113,8 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
|
||||
parser.add_argument('--divisible_by', type=int,
|
||||
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
|
||||
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4', 'nearest', 'linear', 'box'],
|
||||
default=None, help='Interpolation method for resizing. Default to area if smaller, lanczos if larger / サイズ変更の補間方法。小さい場合はデフォルトでエリア、大きい場合はランチョスになります。')
|
||||
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
|
||||
default='area', help='Interpolation method for resizing / リサイズ時の補完方法')
|
||||
parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
|
||||
parser.add_argument('--copy_associated_files', action='store_true',
|
||||
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
|
||||
|
||||
@@ -157,7 +157,7 @@ def train(args):
        vae.requires_grad_(False)
        vae.eval()

        train_dataset_group.new_cache_latents(vae, accelerator)
        train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)

        vae.to("cpu")
        clean_memory_on_device(accelerator.device)
train_network.py
@@ -9,7 +9,6 @@ import random
import time
import json
from multiprocessing import Value
import numpy as np
import toml

from tqdm import tqdm
@@ -69,20 +68,13 @@ class NetworkTrainer:
        keys_scaled=None,
        mean_norm=None,
        maximum_norm=None,
        mean_grad_norm=None,
        mean_combined_norm=None,
    ):
        logs = {"loss/current": current_loss, "loss/average": avr_loss}

        if keys_scaled is not None:
            logs["max_norm/keys_scaled"] = keys_scaled
            logs["max_norm/average_key_norm"] = mean_norm
            logs["max_norm/max_key_norm"] = maximum_norm
        if mean_norm is not None:
            logs["norm/avg_key_norm"] = mean_norm
        if mean_grad_norm is not None:
            logs["norm/avg_grad_norm"] = mean_grad_norm
        if mean_combined_norm is not None:
            logs["norm/avg_combined_norm"] = mean_combined_norm

        lrs = lr_scheduler.get_last_lr()
        for i, lr in enumerate(lrs):
@@ -108,7 +100,9 @@ class NetworkTrainer:
                if (
                    args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
                ):  # tracking d*lr value of unet.
                    logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
                    logs["lr/d*lr"] = (
                        optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
                    )
            else:
                idx = 0
                if not args.network_train_unet_only:
@@ -121,61 +115,21 @@ class NetworkTrainer:
                    logs[f"lr/d*lr/group{i}"] = (
                        lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                    )
                if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
                    logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
                if (
                    args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
                ):
                    logs[f"lr/d*lr/group{i}"] = (
                        optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
                    )

        return logs

    def step_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int):
        self.accelerator_logging(accelerator, logs, global_step, global_step, epoch)

    def epoch_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int):
        self.accelerator_logging(accelerator, logs, epoch, global_step, epoch)

    def val_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int, val_step: int):
        self.accelerator_logging(accelerator, logs, global_step + val_step, global_step, epoch, val_step)

    def accelerator_logging(
        self, accelerator: Accelerator, logs: dict, step_value: int, global_step: int, epoch: int, val_step: Optional[int] = None
    ):
        """
        step_value is for tensorboard, other values are for wandb
        """
        tensorboard_tracker = None
        wandb_tracker = None
        other_trackers = []
        for tracker in accelerator.trackers:
            if tracker.name == "tensorboard":
                tensorboard_tracker = accelerator.get_tracker("tensorboard")
            elif tracker.name == "wandb":
                wandb_tracker = accelerator.get_tracker("wandb")
            else:
                other_trackers.append(accelerator.get_tracker(tracker.name))

        if tensorboard_tracker is not None:
            tensorboard_tracker.log(logs, step=step_value)

        if wandb_tracker is not None:
            logs["global_step"] = global_step
            logs["epoch"] = epoch
            if val_step is not None:
                logs["val_step"] = val_step
            wandb_tracker.log(logs)

        for tracker in other_trackers:
            tracker.log(logs, step=step_value)

    def assert_extra_args(
        self,
        args,
        train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
        val_dataset_group: Optional[train_util.DatasetGroup],
    ):
    def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
        train_dataset_group.verify_bucket_reso_steps(64)
        if val_dataset_group is not None:
            val_dataset_group.verify_bucket_reso_steps(64)

    def load_target_model(self, args, weight_dtype, accelerator) -> tuple:
    def load_target_model(self, args, weight_dtype, accelerator):
        text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)

        # incorporate xformers or memory efficient attention into the model
@@ -265,7 +219,7 @@ class NetworkTrainer:
        network,
        weight_dtype,
        train_unet,
        is_train=True,
        is_train=True
    ):
        # Sample noise, sample a random timestep for each image, and add noise to the latents,
        # with noise offset and/or multires noise if specified
@@ -355,31 +309,28 @@ class NetworkTrainer:
    ) -> torch.nn.Module:
        return accelerator.prepare(unet)

    def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True):
        pass

    def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
    def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
        pass

    # endregion

    def process_batch(
        self,
        batch,
        text_encoders,
        unet,
        network,
        vae,
        noise_scheduler,
        vae_dtype,
        weight_dtype,
        accelerator,
        args,
        text_encoding_strategy: strategy_base.TextEncodingStrategy,
        tokenize_strategy: strategy_base.TokenizeStrategy,
        is_train=True,
        train_text_encoder=True,
        train_unet=True,
        self,
        batch,
        text_encoders,
        unet,
        network,
        vae,
        noise_scheduler,
        vae_dtype,
        weight_dtype,
        accelerator,
        args,
        text_encoding_strategy: strategy_base.TextEncodingStrategy,
        tokenize_strategy: strategy_base.TokenizeStrategy,
        is_train=True,
        train_text_encoder=True,
        train_unet=True
    ) -> torch.Tensor:
        """
        Process a batch for the network
@@ -389,18 +340,7 @@ class NetworkTrainer:
            latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
        else:
            # convert to latents
            if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
                latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
            else:
                chunks = [
                    batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
                ]
                list_latents = []
                for chunk in chunks:
                    with torch.no_grad():
                        chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
                    list_latents.append(chunk)
                latents = torch.cat(list_latents, dim=0)
            latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))

            # if NaN is present, show a warning and replace it with 0
            if torch.any(torch.isnan(latents)):
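The chunked VAE path in the hunk above amounts to "split, encode, concatenate". A minimal stand-alone sketch of the same pattern, assuming plain Python with a toy `encode` standing in for `encode_images_to_latents` and a list standing in for the image tensor:

```python
def encode_in_chunks(images, encode, chunk_size=None):
    """Encode `images` in slices of `chunk_size`, mirroring how process_batch
    falls back to chunked encoding when the batch exceeds --vae_batch_size."""
    if chunk_size is None or len(images) <= chunk_size:
        return encode(images)
    results = []
    for i in range(0, len(images), chunk_size):
        # the real code wraps this call in torch.no_grad() and uses torch.cat
        results.extend(encode(images[i : i + chunk_size]))
    return results

# toy encoder: the "latent" is just the doubled pixel value
doubled = encode_in_chunks(list(range(10)), lambda xs: [2 * x for x in xs], chunk_size=4)
print(doubled)  # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
```

Chunking bounds peak VAE memory at `chunk_size` images while the concatenated result is identical to encoding the whole batch at once.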
@@ -414,13 +354,12 @@ class NetworkTrainer:
            if text_encoder_outputs_list is not None:
                text_encoder_conds = text_encoder_outputs_list  # List of text encoder outputs

        if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
            # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
            with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
                # Get the text embedding for conditioning
                if args.weighted_captions:
                    input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch['captions'])
                    input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
                    encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
                        tokenize_strategy,
                        self.get_models_for_text_encoding(args, accelerator, text_encoders),
@@ -458,7 +397,7 @@ class NetworkTrainer:
            network,
            weight_dtype,
            train_unet,
            is_train=is_train,
            is_train=is_train
        )

        huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
@@ -545,7 +484,7 @@ class NetworkTrainer:
        else:
            # use arbitrary dataset class
            train_dataset_group = train_util.load_arbitrary_dataset(args)
            val_dataset_group = None  # placeholder until validation dataset supported for arbitrary
            val_dataset_group = None  # placeholder until validation dataset supported for arbitrary

        current_epoch = Value("i", 0)
        current_step = Value("i", 0)
@@ -620,9 +559,9 @@ class NetworkTrainer:
            vae.requires_grad_(False)
            vae.eval()

            train_dataset_group.new_cache_latents(vae, accelerator)
            train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
            if val_dataset_group is not None:
                val_dataset_group.new_cache_latents(vae, accelerator)
                val_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)

            vae.to("cpu")
            clean_memory_on_device(accelerator.device)
@@ -670,10 +609,6 @@ class NetworkTrainer:
            return
        network_has_multiplier = hasattr(network, "set_multiplier")

        # TODO remove `hasattr`s by setting up methods if not defined in the network like (hacky but works):
        # if not hasattr(network, "prepare_network"):
        #     network.prepare_network = lambda args: None

        if hasattr(network, "prepare_network"):
            network.prepare_network(args)
        if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
@@ -766,7 +701,7 @@ class NetworkTrainer:
            num_workers=n_workers,
            persistent_workers=args.persistent_data_loader_workers,
        )


        val_dataloader = torch.utils.data.DataLoader(
            val_dataset_group if val_dataset_group is not None else [],
            shuffle=False,
@@ -965,9 +900,7 @@ class NetworkTrainer:

        accelerator.print("running training / 学習開始")
        accelerator.print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
        accelerator.print(
            f"  num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}"
        )
        accelerator.print(f"  num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}")
        accelerator.print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
        accelerator.print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
        accelerator.print(f"  num epochs / epoch数: {num_train_epochs}")
@@ -1035,12 +968,11 @@ class NetworkTrainer:
            "ss_huber_c": args.huber_c,
            "ss_fp8_base": bool(args.fp8_base),
            "ss_fp8_base_unet": bool(args.fp8_base_unet),
            "ss_validation_seed": args.validation_seed,
            "ss_validation_split": args.validation_split,
            "ss_max_validation_steps": args.max_validation_steps,
            "ss_validate_every_n_epochs": args.validate_every_n_epochs,
            "ss_validate_every_n_steps": args.validate_every_n_steps,
            "ss_resize_interpolation": args.resize_interpolation,
            "ss_validation_seed": args.validation_seed,
            "ss_validation_split": args.validation_split,
            "ss_max_validation_steps": args.max_validation_steps,
            "ss_validate_every_n_epochs": args.validate_every_n_epochs,
            "ss_validate_every_n_steps": args.validate_every_n_steps,
        }

        self.update_metadata(metadata, args)  # architecture specific metadata
@@ -1066,7 +998,6 @@ class NetworkTrainer:
                "max_bucket_reso": dataset.max_bucket_reso,
                "tag_frequency": dataset.tag_frequency,
                "bucket_info": dataset.bucket_info,
                "resize_interpolation": dataset.resize_interpolation,
            }

            subsets_metadata = []
@@ -1084,7 +1015,6 @@ class NetworkTrainer:
                    "enable_wildcard": bool(subset.enable_wildcard),
                    "caption_prefix": subset.caption_prefix,
                    "caption_suffix": subset.caption_suffix,
                    "resize_interpolation": subset.resize_interpolation,
                }

                image_dir_or_metadata_file = None
@@ -1233,6 +1163,10 @@ class NetworkTrainer:
            args.max_train_steps > initial_step
        ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"

        progress_bar = tqdm(
            range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
        )

        epoch_to_start = 0
        if initial_step > 0:
            if args.skip_until_initial_step:
@@ -1313,6 +1247,12 @@ class NetworkTrainer:
            # log empty object to commit the sample images to wandb
            accelerator.log({}, step=0)

        validation_steps = (
            min(args.max_validation_steps, len(val_dataloader))
            if args.max_validation_steps is not None
            else len(val_dataloader)
        )

        # training loop
        if initial_step > 0:  # only if skip_until_initial_step is specified
            for skip_epoch in range(epoch_to_start):  # skip epochs
@@ -1331,57 +1271,13 @@ class NetworkTrainer:

        clean_memory_on_device(accelerator.device)

        progress_bar = tqdm(
            range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
        )

        validation_steps = (
            min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
        )
        NUM_VALIDATION_TIMESTEPS = 4  # 200, 400, 600, 800 TODO make this configurable
        min_timestep = 0 if args.min_timestep is None else args.min_timestep
        max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep
        validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1]
        validation_total_steps = validation_steps * len(validation_timesteps)
        original_args_min_timestep = args.min_timestep
        original_args_max_timestep = args.max_timestep
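The `np.linspace` call above samples `NUM_VALIDATION_TIMESTEPS + 2` evenly spaced points and drops both endpoints, which with the default 0-1000 range yields exactly the four timesteps named in the comment. A minimal check, assuming only numpy:

```python
import numpy as np

NUM_VALIDATION_TIMESTEPS = 4
min_timestep, max_timestep = 0, 1000  # defaults when --min_timestep / --max_timestep are unset

# sample NUM + 2 evenly spaced points, then drop the two endpoints
validation_timesteps = np.linspace(min_timestep, max_timestep, NUM_VALIDATION_TIMESTEPS + 2, dtype=int)[1:-1]
print(validation_timesteps.tolist())  # [200, 400, 600, 800]
```

Dropping the endpoints avoids validating at timestep 0 (no noise) and at the scheduler's maximum, where the loss is least informative.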

        def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
            cpu_rng_state = torch.get_rng_state()
            if accelerator.device.type == "cuda":
                gpu_rng_state = torch.cuda.get_rng_state()
            elif accelerator.device.type == "xpu":
                gpu_rng_state = torch.xpu.get_rng_state()
            elif accelerator.device.type == "mps":
                gpu_rng_state = torch.cuda.get_rng_state()
            else:
                gpu_rng_state = None
            python_rng_state = random.getstate()

            torch.manual_seed(seed)
            random.seed(seed)

            return (cpu_rng_state, gpu_rng_state, python_rng_state)

        def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]):
            cpu_rng_state, gpu_rng_state, python_rng_state = rng_states
            torch.set_rng_state(cpu_rng_state)
            if gpu_rng_state is not None:
                if accelerator.device.type == "cuda":
                    torch.cuda.set_rng_state(gpu_rng_state)
                elif accelerator.device.type == "xpu":
                    torch.xpu.set_rng_state(gpu_rng_state)
                elif accelerator.device.type == "mps":
                    torch.cuda.set_rng_state(gpu_rng_state)
            random.setstate(python_rng_state)
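The switch/restore pair above makes validation deterministic without disturbing the training RNG stream. The same save-seed-restore pattern can be exercised with Python's `random` module alone; a minimal sketch, assuming CPU state only (no torch or accelerator state):

```python
import random

def switch_rng_state(seed):
    # snapshot the current RNG state, then reseed deterministically
    state = random.getstate()
    random.seed(seed)
    return state

def restore_rng_state(state):
    random.setstate(state)

random.seed(123)
_ = random.random()            # advances the "training" stream
state = switch_rng_state(42)   # first validation pass under a fixed seed
val_a = random.random()
restore_rng_state(state)
state = switch_rng_state(42)   # second validation pass...
val_b = random.random()
restore_rng_state(state)
assert val_a == val_b          # ...sees identical random numbers
```

Because the pre-validation state is restored afterward, training continues exactly as if validation had never consumed any random numbers.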
        for epoch in range(epoch_to_start, num_train_epochs):
            accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
            current_epoch.value = epoch + 1

            metadata["ss_epoch"] = str(epoch + 1)

            accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)  # network.train() is called here
            accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)

            # TRAINING
            skipped_dataloader = None
@@ -1398,25 +1294,25 @@ class NetworkTrainer:
                with accelerator.accumulate(training_model):
                    on_step_start_for_network(text_encoder, unet)

                    # preprocess batch for each model
                    self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
                    # temporary, for batch processing
                    self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)

                    loss = self.process_batch(
                        batch,
                        text_encoders,
                        unet,
                        network,
                        vae,
                        noise_scheduler,
                        vae_dtype,
                        weight_dtype,
                        accelerator,
                        args,
                        text_encoding_strategy,
                        tokenize_strategy,
                        is_train=True,
                        train_text_encoder=train_text_encoder,
                        train_unet=train_unet,
                        batch,
                        text_encoders,
                        unet,
                        network,
                        vae,
                        noise_scheduler,
                        vae_dtype,
                        weight_dtype,
                        accelerator,
                        args,
                        text_encoding_strategy,
                        tokenize_strategy,
                        is_train=True,
                        train_text_encoder=train_text_encoder,
                        train_unet=train_unet
                    )

                    accelerator.backward(loss)
@@ -1426,11 +1322,6 @@ class NetworkTrainer:
                        params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                    if hasattr(network, "update_grad_norms"):
                        network.update_grad_norms()
                    if hasattr(network, "update_norms"):
                        network.update_norms()

                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
@@ -1439,25 +1330,9 @@ class NetworkTrainer:
                        keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
                            args.scale_weight_norms, accelerator.device
                        )
                        mean_grad_norm = None
                        mean_combined_norm = None
                        max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
                    else:
                        if hasattr(network, "weight_norms"):
                            weight_norms = network.weight_norms()
                            mean_norm = weight_norms.mean().item() if weight_norms is not None else None
                            grad_norms = network.grad_norms()
                            mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None
                            combined_weight_norms = network.combined_weight_norms()
                            mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None
                            maximum_norm = weight_norms.max().item() if weight_norms is not None else None
                            keys_scaled = None
                            max_mean_logs = {}
                        else:
                            keys_scaled, mean_norm, maximum_norm = None, None, None
                            mean_grad_norm = None
                            mean_combined_norm = None
                            max_mean_logs = {}
                        keys_scaled, mean_norm, maximum_norm = None, None, None

                    # Checks if the accelerator has performed an optimization step behind the scenes
                    if accelerator.sync_gradients:
@@ -1468,7 +1343,6 @@ class NetworkTrainer:
                            self.sample_images(
                                accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
                            )
                            progress_bar.unpause()

                        # save the model every specified number of steps
                        if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -1490,179 +1364,153 @@ class NetworkTrainer:
                    loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
                    avr_loss: float = loss_recorder.moving_average
                    logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
                    progress_bar.set_postfix(**{**max_mean_logs, **logs})
                    progress_bar.set_postfix(**logs)

                    if args.scale_weight_norms:
                        progress_bar.set_postfix(**{**max_mean_logs, **logs})


                    if is_tracking:
                        logs = self.generate_step_logs(
                            args,
                            current_loss,
                            avr_loss,
                            lr_scheduler,
                            lr_descriptions,
                            optimizer,
                            keys_scaled,
                            mean_norm,
                            maximum_norm,
                            mean_grad_norm,
                            mean_combined_norm,
                            args,
                            current_loss,
                            avr_loss,
                            lr_scheduler,
                            lr_descriptions,
                            optimizer,
                            keys_scaled,
                            mean_norm,
                            maximum_norm
                        )
                        self.step_logging(accelerator, logs, global_step, epoch + 1)
                        accelerator.log(logs, step=global_step)
                    # VALIDATION PER STEP: global_step is already incremented
                    # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ...
                    should_validate_step = args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0
                    # VALIDATION PER STEP
                    should_validate_step = (
                        args.validate_every_n_steps is not None
                        and global_step != 0  # Skip first step
                        and global_step % args.validate_every_n_steps == 0
                    )
                    if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
                        optimizer_eval_fn()
                        accelerator.unwrap_model(network).eval()
                        rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed)

                        val_progress_bar = tqdm(
                            range(validation_total_steps),
                            smoothing=0,
                            disable=not accelerator.is_local_main_process,
                            desc="validation steps",
                            range(validation_steps), smoothing=0,
                            disable=not accelerator.is_local_main_process,
                            desc="validation steps"
                        )
                        val_timesteps_step = 0
                        for val_step, batch in enumerate(val_dataloader):
                            if val_step >= validation_steps:
                                break

                            for timestep in validation_timesteps:
                                self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
                                # temporary, for batch processing
                                self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)

                                args.min_timestep = args.max_timestep = timestep  # dirty hack to change timestep
                                loss = self.process_batch(
                                    batch,
                                    text_encoders,
                                    unet,
                                    network,
                                    vae,
                                    noise_scheduler,
                                    vae_dtype,
                                    weight_dtype,
                                    accelerator,
                                    args,
                                    text_encoding_strategy,
                                    tokenize_strategy,
                                    is_train=False,
                                    train_text_encoder=False,
                                    train_unet=False
                                )
                            loss = self.process_batch(
                                batch,
                                text_encoders,
                                unet,
                                network,
                                vae,
                                noise_scheduler,
                                vae_dtype,
                                weight_dtype,
                                accelerator,
                                args,
                                text_encoding_strategy,
                                tokenize_strategy,
                                is_train=False,
                                train_text_encoder=train_text_encoder,  # this is needed for validation because Text Encoders must be called if train_text_encoder is True
                                train_unet=train_unet,
                            )
                            current_loss = loss.detach().item()
                            val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
                            val_progress_bar.update(1)
                            val_progress_bar.set_postfix({"val_avg_loss": val_step_loss_recorder.moving_average})

                                current_loss = loss.detach().item()
                                val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
                                val_progress_bar.update(1)
                                val_progress_bar.set_postfix(
                                    {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep}
                                )

                                # if is_tracking:
                                #     logs = {f"loss/validation/step_current_{timestep}": current_loss}
                                #     self.val_logging(accelerator, logs, global_step, epoch + 1, val_step)

                                self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
                                val_timesteps_step += 1
                            if is_tracking:
                                logs = {
                                    "loss/validation/step_current": current_loss,
                                    "val_step": (epoch * validation_steps) + val_step,
                                }
                                accelerator.log(logs, step=global_step)

                        if is_tracking:
                            loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
                            logs = {
                                "loss/validation/step_average": val_step_loss_recorder.moving_average,
                                "loss/validation/step_divergence": loss_validation_divergence,
                                "loss/validation/step_average": val_step_loss_recorder.moving_average,
                                "loss/validation/step_divergence": loss_validation_divergence,
                            }
                            self.step_logging(accelerator, logs, global_step, epoch=epoch + 1)

                        restore_rng_state(rng_states)
                        args.min_timestep = original_args_min_timestep
                        args.max_timestep = original_args_max_timestep
                        optimizer_train_fn()
                        accelerator.unwrap_model(network).train()
                        progress_bar.unpause()

                            accelerator.log(logs, step=global_step)

                if global_step >= args.max_train_steps:
                    break
            # EPOCH VALIDATION
            should_validate_epoch = (
                (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True
                (epoch + 1) % args.validate_every_n_epochs == 0
                if args.validate_every_n_epochs is not None
                else True
            )

            if should_validate_epoch and len(val_dataloader) > 0:
                optimizer_eval_fn()
                accelerator.unwrap_model(network).eval()
                rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed)

                val_progress_bar = tqdm(
                    range(validation_total_steps),
                    smoothing=0,
                    disable=not accelerator.is_local_main_process,
                    desc="epoch validation steps",
                    range(validation_steps), smoothing=0,
                    disable=not accelerator.is_local_main_process,
                    desc="epoch validation steps"
                )

                val_timesteps_step = 0
                for val_step, batch in enumerate(val_dataloader):
                    if val_step >= validation_steps:
                        break

                    for timestep in validation_timesteps:
                        args.min_timestep = args.max_timestep = timestep
                        # temporary, for batch processing
                        self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)

                        # temporary, for batch processing
                        self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
                        loss = self.process_batch(
                            batch,
                            text_encoders,
                            unet,
                            network,
                            vae,
                            noise_scheduler,
                            vae_dtype,
                            weight_dtype,
                            accelerator,
                            args,
                            text_encoding_strategy,
                            tokenize_strategy,
                            is_train=False,
                            train_text_encoder=False,
                            train_unet=False
                        )
                    loss = self.process_batch(
                        batch,
                        text_encoders,
                        unet,
                        network,
                        vae,
                        noise_scheduler,
                        vae_dtype,
                        weight_dtype,
                        accelerator,
                        args,
                        text_encoding_strategy,
                        tokenize_strategy,
                        is_train=False,
                        train_text_encoder=train_text_encoder,
                        train_unet=train_unet,
                    )
                    current_loss = loss.detach().item()
                    val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
                    val_progress_bar.update(1)
                    val_progress_bar.set_postfix({"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average})

                        current_loss = loss.detach().item()
                        val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
                        val_progress_bar.update(1)
                        val_progress_bar.set_postfix(
                            {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep}
                        )
                        # if is_tracking:
                        #     logs = {f"loss/validation/epoch_current_{timestep}": current_loss}
                        #     self.val_logging(accelerator, logs, global_step, epoch + 1, val_step)

                        self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
                        val_timesteps_step += 1
                    if is_tracking:
                        logs = {
                            "loss/validation/epoch_current": current_loss,
                            "epoch": epoch + 1,
                            "val_step": (epoch * validation_steps) + val_step
                        }
                        accelerator.log(logs, step=global_step)

                if is_tracking:
                    avr_loss: float = val_epoch_loss_recorder.moving_average
                    loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average
                    loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average
                    logs = {
                        "loss/validation/epoch_average": avr_loss,
                        "loss/validation/epoch_divergence": loss_validation_divergence,
                        "loss/validation/epoch_average": avr_loss,
                        "loss/validation/epoch_divergence": loss_validation_divergence,
                        "epoch": epoch + 1
                    }
                    self.epoch_logging(accelerator, logs, global_step, epoch + 1)
                restore_rng_state(rng_states)
                args.min_timestep = original_args_min_timestep
                args.max_timestep = original_args_max_timestep
                optimizer_train_fn()
                accelerator.unwrap_model(network).train()
                progress_bar.unpause()
                    accelerator.log(logs, step=global_step)

            # END OF EPOCH
            if is_tracking:
                logs = {"loss/epoch_average": loss_recorder.moving_average}
                self.epoch_logging(accelerator, logs, global_step, epoch + 1)

                logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1}
                accelerator.log(logs, step=global_step)

            accelerator.wait_for_everyone()

            # save the model every specified number of epochs
@@ -1682,7 +1530,6 @@ class NetworkTrainer:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)

self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
progress_bar.unpause()
optimizer_train_fn()

# end of epoch
@@ -1849,31 +1696,31 @@ def setup_parser() -> argparse.ArgumentParser:
    "--validation_seed",
    type=int,
    default=None,
    help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する",
    help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する"
)
parser.add_argument(
    "--validation_split",
    type=float,
    default=0.0,
    help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合",
    help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合"
)
parser.add_argument(
    "--validate_every_n_steps",
    type=int,
    default=None,
    help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます",
    help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます"
)
parser.add_argument(
    "--validate_every_n_epochs",
    type=int,
    default=None,
    help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます",
    help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます"
)
parser.add_argument(
    "--max_validation_steps",
    type=int,
    default=None,
    help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します",
    help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します"
)
return parser
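Taken together, the flags above gate validation either by step or by epoch, with the documented defaults: step-based validation is off unless `--validate_every_n_steps` is set, and epoch-based validation runs every epoch when a validation dataset exists. A hypothetical helper illustrating that logic (names and signatures are assumptions, not the branch's actual code):

```python
from typing import Optional


def should_validate_step(global_step: int, validate_every_n_steps: Optional[int], has_val_dataset: bool) -> bool:
    # Step-based validation runs only when --validate_every_n_steps is set.
    if not has_val_dataset or validate_every_n_steps is None:
        return False
    return global_step > 0 and global_step % validate_every_n_steps == 0


def should_validate_epoch(epoch: int, validate_every_n_epochs: Optional[int], has_val_dataset: bool) -> bool:
    # By default (None), validation runs every epoch when a validation dataset exists.
    if not has_val_dataset:
        return False
    n = validate_every_n_epochs if validate_every_n_epochs is not None else 1
    return epoch % n == 0
```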
@@ -382,7 +382,7 @@ class TextualInversionTrainer:
vae.requires_grad_(False)
vae.eval()

train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)

clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
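The second hunk threads a new `args.force_cache_precision` argument into `new_cache_latents`, presumably letting the caller override the dtype used when caching latents. A hypothetical sketch of what resolving such an override could look like (the flag name's semantics and the precision table are assumptions, not the branch's actual implementation):

```python
from typing import Optional


def resolve_cache_dtype(force_cache_precision: Optional[str], default: str = "float32") -> str:
    # Hypothetical mapping from a CLI precision flag to a dtype name;
    # the real --force_cache_precision behavior may differ.
    table = {"fp16": "float16", "bf16": "bfloat16", "fp32": "float32"}
    if force_cache_precision is None:
        return default  # fall back to the current caching dtype
    return table[force_cache_precision]
```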