Compare commits

..

3 Commits

Author SHA1 Message Date
Kohya S
8f4ee8fc34 doc: update README for latest 2025-03-21 22:05:48 +09:00
Kohya S.
367f348430 Merge pull request #1964 from Nekotekina/main
Fix missing text encoder attn modules
2025-03-21 21:59:03 +09:00
Ivan Chikish
acdca2abb7 Fix [occasionally] missing text encoder attn modules
Should fix #1952
I added alternative name for CLIPAttention.
I have no idea why this name changed.
Now it should accept both names.
2025-03-01 20:35:45 +03:00
72 changed files with 2471 additions and 22768 deletions


@@ -1,48 +0,0 @@
name: Test with pytest
on:
push:
branches:
- main
- dev
- sd3
pull_request:
branches:
- main
- dev
- sd3
jobs:
build:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest]
python-version: ["3.10"] # Python versions to test
pytorch-version: ["2.4.0"] # PyTorch versions to test
steps:
- uses: actions/checkout@v4
with:
# https://woodruffw.github.io/zizmor/audits/#artipacked
persist-credentials: false
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
- name: Install and update pip, setuptools, wheel
run: |
# Setuptools, wheel for compiling some packages
python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
run: |
# Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch)
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision==0.19.0 pytest==8.3.4
pip install -r requirements.txt
- name: Test with pytest
run: pytest # See pytest.ini for configuration


@@ -1,11 +1,9 @@
---
# yamllint disable rule:line-length
name: Typos
on:
on: # yamllint disable-line rule:truthy
push:
branches:
- main
- dev
pull_request:
types:
- opened
@@ -18,9 +16,6 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
# https://woodruffw.github.io/zizmor/audits/#artipacked
persist-credentials: false
- name: typos-action
uses: crate-ci/typos@v1.28.1
uses: crate-ci/typos@v1.24.3

README.md

@@ -1,744 +1,10 @@
This repository contains training, generation and utility scripts for Stable Diffusion.
## FLUX.1 and SD3 training (WIP)
This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training.
__Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchvision==0.19.0` with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__
The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
- [FLUX.1 training](#flux1-training)
- [SD3 training](#sd3-training)
### Recent Updates
Jan 25, 2025:
- `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thank you to rockerBOO!
- For details on how to set it up, please refer to the PR. The documentation will be updated as needed.
- It will be added to other scripts as well.
- As a current limitation, validation loss is not supported when `--blocks_to_swap` is specified, or when a schedule-free optimizer is used.
Dec 15, 2024:
- RAdamScheduleFree optimizer is supported. PR [#1830](https://github.com/kohya-ss/sd-scripts/pull/1830) Thanks to nhamanasu!
- Update to `schedulefree==1.4` is required. Please update individually or with `pip install --use-pep517 --upgrade -r requirements.txt`.
- Available with `--optimizer_type=RAdamScheduleFree`. No need to specify warm up steps as well as learning rate scheduler.
Dec 7, 2024:
- The option to specify the model name during ControlNet training was different in each script. It has been unified. Please specify `--controlnet_model_name_or_path`. PR [#1821](https://github.com/kohya-ss/sd-scripts/pull/1821) Thanks to sdbds!
<!--
Also, the ControlNet training script for SD has been changed from `train_controlnet.py` to `train_control_net.py`.
- `train_controlnet.py` is still available, but it will be removed in the future.
-->
- Fixed an issue where the saved model would be corrupted (pos_embed would not be saved) when `--enable_scaled_pos_embed` was specified in `sd3_train.py`.
Dec 3, 2024:
- `--blocks_to_swap` now works in FLUX.1 ControlNet training. Sample commands for 24GB VRAM and 16GB VRAM are added [here](#flux1-controlnet-training).
Dec 2, 2024:
- FLUX.1 ControlNet training is supported. PR [#1813](https://github.com/kohya-ss/sd-scripts/pull/1813). Thanks to minux302! See PR and [here](#flux1-controlnet-training) for details.
- Not fully tested. Feedback is welcome.
- 80GB VRAM is required for 1024x1024 resolution, and 48GB VRAM is required for 512x512 resolution.
- Currently, it only works in Linux environment (or Windows WSL2) because DeepSpeed is required.
- Multi-GPU training is not tested.
Dec 1, 2024:
- Pseudo Huber loss is now available for FLUX.1 and SD3.5 training. See PR [#1808](https://github.com/kohya-ss/sd-scripts/pull/1808) for details. Thanks to recris!
- Specify `--loss_type huber` or `--loss_type smooth_l1` to use it. `--huber_c` and `--huber_scale` are also available.
- [Prodigy + ScheduleFree](https://github.com/LoganBooker/prodigy-plus-schedule-free) is supported. See PR [#1811](https://github.com/kohya-ss/sd-scripts/pull/1811) for details. Thanks to rockerBOO!
Nov 14, 2024:
- Improved the implementation of block swap and made it available for both FLUX.1 and SD3 LoRA training. See [FLUX.1 LoRA training](#flux1-lora-training) etc. for how to use the new options. Training is possible with about 8-10GB of VRAM.
- During fine-tuning, the memory usage when specifying the same number of blocks has increased slightly, but the training speed when specifying block swap has been significantly improved.
- There may be bugs due to the significant changes. Feedback is welcome.
## FLUX.1 training
- [FLUX.1 LoRA training](#flux1-lora-training)
- [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training)
- [Distribution of timesteps](#distribution-of-timesteps)
- [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training)
- [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1)
- [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training)
- [FLUX.1 ControlNet training](#flux1-controlnet-training)
- [FLUX.1 OFT training](#flux1-oft-training)
- [Inference for FLUX.1 with LoRA model](#inference-for-flux1-with-lora-model)
- [FLUX.1 fine-tuning](#flux1-fine-tuning)
- [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning)
- [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models)
- [Convert FLUX LoRA](#convert-flux-lora)
- [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint)
- [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training)
- [Convert Diffusers to FLUX.1](#convert-diffusers-to-flux1)
### FLUX.1 LoRA training
We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options.
FLUX.1 model, CLIP-L, and T5XXL models are recommended to be in bf16/fp16 format. If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with `float8_e4m3fn` format.
Sample command is below. It will work with 24GB VRAM GPUs.
```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py
--pretrained_model_name_or_path flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors
--ae ae.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers
--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
--network_module networks.lora_flux --network_dim 4 --network_train_unet_only
--optimizer_type adamw8bit --learning_rate 1e-4
--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base
--highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml
--output_dir path/to/output/dir --output_name flux-lora-name
--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0
```
(The command is multi-line for readability. Please combine it into one line.)
We are also not sure how many epochs are needed for convergence, or how the learning rate should be adjusted.
The trained LoRA model can be used with ComfyUI.
When training LoRA for Text Encoder (without `--network_train_unet_only`), more VRAM is required. Please refer to the settings below to reduce VRAM usage.
__Options for GPUs with less VRAM:__
By specifying `--blocks_to_swap`, you can save VRAM by swapping some blocks between CPU and GPU. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details.
Specify a number like `--blocks_to_swap 10`. A larger number will swap more blocks, saving more VRAM, but training will be slower. In FLUX.1, you can swap up to 35 blocks.
`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--blocks_to_swap`.
The Adafactor optimizer may use less VRAM than 8bit AdamW. Please use settings like below:
```
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0
```
The training can be done with 16GB VRAM GPUs with the batch size of 1. Please change your dataset configuration.
The training can be done with 12GB VRAM GPUs with `--blocks_to_swap 16` with 8bit AdamW. Please use settings like below:
```
--blocks_to_swap 16
```
For GPUs with less than 10GB of VRAM, it is recommended to use an fp8 checkpoint for T5XXL. You can download `t5xxl_fp8_e4m3fn.safetensors` from [comfyanonymous/flux_text_encoders](https://huggingface.co/comfyanonymous/flux_text_encoders) (please use without `scaled`).
10GB VRAM GPUs will work with 22 blocks swapped, and 8GB VRAM GPUs will work with 28 blocks swapped.
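The VRAM guidance above can be summarized as a small lookup. This is only a sketch: `suggest_lora_flags` is a hypothetical helper, not part of sd-scripts, and reading the 16GB note as applying to the Adafactor settings is an assumption. The thresholds and flag values are the ones quoted in this section.

```python
def suggest_lora_flags(vram_gb: int) -> list[str]:
    """Return extra flags for flux_train_network.py for a given VRAM budget.

    Hypothetical helper that only restates the guidance in this README.
    """
    if vram_gb >= 24:
        # The sample command is reported to work as-is on 24GB GPUs.
        return []
    if vram_gb >= 16:
        # Assumption: the 16GB note refers to Adafactor with batch size 1
        # (the batch size itself is set in the dataset configuration).
        return [
            "--optimizer_type", "adafactor",
            "--optimizer_args", "relative_step=False",
            "scale_parameter=False", "warmup_init=False",
            "--lr_scheduler", "constant_with_warmup",
            "--max_grad_norm", "0.0",
        ]
    if vram_gb >= 12:
        return ["--blocks_to_swap", "16"]  # with 8bit AdamW
    if vram_gb >= 10:
        return ["--blocks_to_swap", "22"]
    if vram_gb >= 8:
        # Below 10GB an fp8 T5XXL checkpoint is also recommended.
        return ["--blocks_to_swap", "28"]
    raise ValueError("no setting is documented for less than 8GB of VRAM")


print(" ".join(suggest_lora_flags(12)))
```

Treat the result as a starting point; actual usage depends on resolution, batch size, and whether the text encoders are trained.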
__`--split_mode` is deprecated. This option is still available, but it will be removed in the future. Please use `--blocks_to_swap` instead. If this option is specified and `--blocks_to_swap` is not specified, `--blocks_to_swap 18` is automatically enabled.__
#### Key Options for FLUX.1 LoRA training
There are many unknown points in FLUX.1 training, so some settings can be specified by arguments. Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome.
- `--pretrained_model_name_or_path` is the path to the pretrained model (FLUX.1). bf16 (original BFL model) is recommended (`flux1-dev.safetensors` or `flux1-dev.sft`). If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with `float8_e4m3fn` format.
- `--clip_l` is the path to the CLIP-L model.
- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (float8_e4m3fn) models for T5XXL. However, it is recommended to use fp16 models for caching.
- `--ae` is the path to the autoencoder model (`ae.safetensors` or `ae.sft`).
- `--timestep_sampling` is the method to sample timesteps (0-1):
  - `sigma`: sigma-based, same as SD3
  - `uniform`: uniform random
  - `sigmoid`: sigmoid of random normal, same as x-flux, AI-toolkit etc.
  - `shift`: shifts the value of sigmoid of normal distribution random number
  - `flux_shift`: shifts the value of sigmoid of normal distribution random number, depending on the resolution (same as FLUX.1 dev inference). `--discrete_flow_shift` is ignored when `flux_shift` is specified.
- `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (used when `--timestep_sampling` is `sigmoid`). The default is 1.0. Larger values will make the sampling more uniform.
  - This option is also effective when `--timestep_sampling shift` is specified.
  - Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution.
- `--model_prediction_type` is how to interpret and process the model prediction:
  - `raw`: use as is, same as x-flux
  - `additive`: add to noisy input
  - `sigma_scaled`: apply sigma scaling, same as SD3
- `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3).
- `--blocks_to_swap`. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details.
The existing `--loss_type` option may be useful for FLUX.1 training. The default is `l2`.
~~In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted.~~
In our experiments, `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type) seems to work better.
The settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seem to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type).
Other settings may work better, so please try different settings.
Other options are described below.
#### Distribution of timesteps
`--timestep_sampling`, `--sigmoid_scale`, and `--discrete_flow_shift` adjust the distribution of timesteps. The distribution is shown in the figures below.
The effect of `--discrete_flow_shift` with `--timestep_sampling shift` (when `--sigmoid_scale` is not specified, the default is 1.0):
![Figure_2](https://github.com/user-attachments/assets/d9de42f9-f17d-40da-b88d-d964402569c6)
The difference between `--timestep_sampling sigmoid` and `--timestep_sampling uniform` (when `--timestep_sampling sigmoid` or `uniform` is specified, `--discrete_flow_shift` is ignored):
![Figure_3](https://github.com/user-attachments/assets/27029009-1f5d-4dc0-bb24-13d02ac4fdad)
The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--timestep_sampling sigmoid` is specified, `--discrete_flow_shift` is ignored):
![Figure_4](https://github.com/user-attachments/assets/08a2267c-e47e-48b7-826e-f9a080787cdc)
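To get a feel for the distributions in the figures above, the following sketch draws timesteps for the `uniform`, `sigmoid`, and `shift` methods. It is not the training script's code. Two things are assumptions: that `--sigmoid_scale` multiplies the normal sample before the sigmoid, and that `shift` uses the common flow-shift map `s*t / (1 + (s-1)*t)`. The function name `sample_timestep` is hypothetical.

```python
import math
import random


def sample_timestep(method: str, rng: random.Random,
                    sigmoid_scale: float = 1.0,
                    discrete_flow_shift: float = 3.0) -> float:
    """Draw one timestep in [0, 1] (illustrative sketch only)."""
    if method == "uniform":
        return rng.random()
    # "sigmoid" and "shift" both start from sigmoid(scale * N(0, 1)).
    t = 1.0 / (1.0 + math.exp(-sigmoid_scale * rng.gauss(0.0, 1.0)))
    if method == "sigmoid":
        return t
    if method == "shift":
        # Assumed flow-shift map; s > 1 pushes samples toward t = 1.
        s = discrete_flow_shift
        return (s * t) / (1.0 + (s - 1.0) * t)
    raise ValueError(f"unsupported method: {method}")


rng = random.Random(0)
samples = [sample_timestep("shift", rng, discrete_flow_shift=3.1582)
           for _ in range(10000)]
print("mean timestep with shift 3.1582:", sum(samples) / len(samples))
```

Under these assumptions, a shift of 1.0 reduces `shift` to plain `sigmoid`, which is a quick sanity check when comparing settings.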
#### Key Features for FLUX.1 LoRA training
1. CLIP-L and T5XXL LoRA Support:
- FLUX.1 LoRA training now supports CLIP-L and T5XXL LoRA training.
- Remove `--network_train_unet_only` from your command.
- Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time.
- T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- The learning rates for CLIP-L and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5`. The first value is the learning rate for CLIP-L, and the second value is for T5XXL. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL.
- The trained LoRA can be used with ComfyUI.
- Note: `flux_extract_lora.py`, `convert_flux_lora.py` and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet.
| trained LoRA|option|network_args|cache_text_encoder_outputs (*1)|
|---|---|---|---|
|FLUX.1|`--network_train_unet_only`|-|o|
|FLUX.1 + CLIP-L|-|-|o (*2)|
|FLUX.1 + CLIP-L + T5XXL|-|`train_t5xxl=True`|-|
|CLIP-L (*3)|`--network_train_text_encoder_only`|-|o (*2)|
|CLIP-L + T5XXL (*3)|`--network_train_text_encoder_only`|`train_t5xxl=True`|-|
- *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- *2: T5XXL output can be cached for CLIP-L LoRA training.
- *3: Not tested yet.
2. Experimental FP8/FP16 mixed training:
- `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L/T5XXL.
- FLUX can be trained with fp8, and CLIP-L/T5XXL can be trained with bf16/fp16.
- When specifying this option, the `--fp8_base` option is automatically enabled.
3. Split Q/K/V Projection Layers (Experimental):
- Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them.
- Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available).
- May increase expressiveness but also training time.
- The trained model is compatible with normal LoRA models in sd-scripts and can be used in environments like ComfyUI.
- Converting to AI-toolkit (Diffusers) format with `convert_flux_lora.py` will reduce the size.
4. T5 Attention Mask Application:
- T5 attention mask is applied when `--apply_t5_attn_mask` is specified.
- The mask is now applied when encoding T5 and in the attention of Double and Single Blocks.
- Affects fine-tuning, LoRA training, and inference in `flux_minimal_inference.py`.
5. Multi-resolution Training Support:
- FLUX.1 now supports multi-resolution training, even with caching latents to disk.
Technical details of Q/K/V split:
In the implementation of Black Forest Labs' model, the projection layers of q/k/v (and txt in single blocks) are concatenated into one. If LoRA is added there as it is, the LoRA module is only one, and the dimension is large. In contrast, in the implementation of Diffusers, the projection layers of q/k/v/txt are separated. Therefore, the LoRA module is applied to q/k/v/txt separately, and the dimension is smaller. This option is for training LoRA similar to the latter.
The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large.
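The concatenation described above can be illustrated with toy matrices. This is a sketch of the general idea, not the repository's implementation. It assumes each LoRA contributes `up @ down` to its projection and that the fused layer stacks the q/k/v outputs row-wise. `concat_split_lora` is a hypothetical name.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def concat_split_lora(downs, ups):
    """Merge per-projection LoRA pairs into one pair for the fused layer.

    The down matrices are stacked row-wise; the up matrices are placed on
    a block diagonal, so every off-diagonal entry is a stored zero.
    """
    total_rank = sum(len(down) for down in downs)
    merged_down = [list(row) for down in downs for row in down]
    merged_up = []
    offset = 0
    for down, up in zip(downs, ups):
        rank = len(down)
        for row in up:
            left = [0.0] * offset
            right = [0.0] * (total_rank - offset - rank)
            merged_up.append(left + list(row) + right)
        offset += rank
    return merged_down, merged_up


# Toy q/k/v LoRA pairs: rank 1, input width 2, output width 2 each.
downs = [[[1.0, 2.0]], [[0.5, -1.0]], [[3.0, 0.0]]]
ups = [[[1.0], [2.0]], [[-1.0], [0.5]], [[2.0], [4.0]]]

merged_down, merged_up = concat_split_lora(downs, ups)
fused_delta = matmul(merged_up, merged_down)
separate_delta = [row for down, up in zip(downs, ups)
                  for row in matmul(up, down)]
zero_entries = sum(value == 0.0 for row in merged_up for value in row)
print(fused_delta == separate_delta, zero_entries)
```

In this toy case the merged up matrix has 18 entries, of which 12 are zero padding, which is why the concatenated file is larger than the separate LoRAs.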
#### Specify rank for each layer in FLUX.1
You can specify the rank for each layer in FLUX.1 by specifying the following network_args. If you specify `0`, LoRA will not be applied to that layer.
When network_args is not specified, the default value (`network_dim`) is applied, same as before.
|network_args|target layer|
|---|---|
|img_attn_dim|img_attn in DoubleStreamBlock|
|txt_attn_dim|txt_attn in DoubleStreamBlock|
|img_mlp_dim|img_mlp in DoubleStreamBlock|
|txt_mlp_dim|txt_mlp in DoubleStreamBlock|
|img_mod_dim|img_mod in DoubleStreamBlock|
|txt_mod_dim|txt_mod in DoubleStreamBlock|
|single_dim|linear1 and linear2 in SingleStreamBlock|
|single_mod_dim|modulation in SingleStreamBlock|
`"verbose=True"` is also available for debugging. It shows the rank of each layer.
example:
```
--network_args "img_attn_dim=4" "img_mlp_dim=8" "txt_attn_dim=2" "txt_mlp_dim=2"
"img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2" "verbose=True"
```
You can apply LoRA to the conditioning layers of Flux by specifying `in_dims` in network_args. When specifying, be sure to specify 5 numbers in `[]` as a comma-separated list.
example:
```
--network_args "in_dims=[4,2,2,2,4]"
```
Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, `txt_in`. The above example applies LoRA to all conditioning layers, with rank 4 for `img_in`, 2 for `time_in`, `vector_in`, `guidance_in`, and 4 for `txt_in`.
If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,0,4]` applies LoRA only to `img_in` and `txt_in`.
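As an illustration of how the five numbers map to layers, here is a small parser. It is a sketch only: `parse_in_dims` and `layers_with_lora` are hypothetical helpers, and the real script may parse the value differently.

```python
IN_LAYER_NAMES = ("img_in", "time_in", "vector_in", "guidance_in", "txt_in")


def parse_in_dims(value: str) -> dict[str, int]:
    """Map an in_dims string such as "[4,2,2,2,4]" to per-layer ranks."""
    text = value.strip()
    if not (text.startswith("[") and text.endswith("]")):
        raise ValueError("in_dims must be a list wrapped in []")
    ranks = [int(part) for part in text[1:-1].split(",")]
    if len(ranks) != len(IN_LAYER_NAMES):
        raise ValueError("in_dims needs exactly 5 numbers")
    return dict(zip(IN_LAYER_NAMES, ranks))


def layers_with_lora(value: str) -> list[str]:
    """Return the conditioning layers that receive LoRA (rank > 0)."""
    return [name for name, rank in parse_in_dims(value).items() if rank > 0]


print(parse_in_dims("[4,2,2,2,4]"))
print(layers_with_lora("[4,0,0,0,4]"))
```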
#### Specify blocks to train in FLUX.1 LoRA training
You can specify the blocks to train in FLUX.1 LoRA training by specifying `train_double_block_indices` and `train_single_block_indices` in network_args. The indices are 0-based. The default (when omitted) is to train all blocks. The indices are specified as a list of integers or a range of integers, like `0,1,5,8` or `0,1,4-5,7`. The number of double blocks is 19, and the number of single blocks is 38, so the valid range is 0-18 and 0-37, respectively. Specify `all` to train all blocks, or `none` to train no blocks.
example:
```
--network_args "train_double_block_indices=0,1,8-12,18" "train_single_block_indices=3,10,20-25,37"
```
```
--network_args "train_double_block_indices=none" "train_single_block_indices=10-15"
```
If you specify one of `train_double_block_indices` or `train_single_block_indices`, the other will be trained as usual.
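The index syntax can be sketched as follows. `parse_block_indices` is a hypothetical helper, and treating ranges such as `4-5` as inclusive on both ends is an assumption based on the examples above.

```python
NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38


def parse_block_indices(spec: str, num_blocks: int) -> list[int]:
    """Expand a spec like "0,1,4-5,7", "all" or "none" into block indices."""
    spec = spec.strip()
    if spec == "all":
        return list(range(num_blocks))
    if spec == "none":
        return []
    selected = set()
    for part in spec.split(","):
        if "-" in part:
            start, end = (int(x) for x in part.split("-"))
        else:
            start = end = int(part)
        if not 0 <= start <= end < num_blocks:
            raise ValueError(f"index out of range 0-{num_blocks - 1}: {part}")
        selected.update(range(start, end + 1))  # ranges assumed inclusive
    return sorted(selected)


print(parse_block_indices("0,1,8-12,18", NUM_DOUBLE_BLOCKS))
print(parse_block_indices("3,10,20-25,37", NUM_SINGLE_BLOCKS))
```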
### FLUX.1 ControlNet training
We have added a new training script for ControlNet training. The script is `flux_train_control_net.py`. See `--help` for options.
Sample command is below. It will work with 80GB VRAM GPUs.
```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_control_net.py
--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors
--ae ae.safetensors --save_model_as safetensors --sdpa --persistent_data_loader_workers
--max_data_loader_n_workers 1 --seed 42 --gradient_checkpointing --mixed_precision bf16
--optimizer_type adamw8bit --learning_rate 2e-5
--highvram --max_train_epochs 1 --save_every_n_steps 1000 --dataset_config dataset.toml
--output_dir /path/to/output/dir --output_name flux-cn
--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 --deepspeed
```
For 24GB VRAM GPUs, you can train with 16 blocks swapped and caching latents and text encoder outputs with the batch size of 1. Remove `--deepspeed`. Sample command is below. Not fully tested.
```
--blocks_to_swap 16 --cache_latents_to_disk --cache_text_encoder_outputs_to_disk
```
The training can be done with 16GB VRAM GPUs with around 30 blocks swapped.
`--gradient_accumulation_steps` is also available. The default value is 1 (no accumulation), but according to the original PR, 8 is used.
### FLUX.1 OFT training
You can train OFT with almost the same options as LoRA, such as `--timestep_sampling`. The following points are different.
- Change `--network_module` from `networks.lora_flux` to `networks.oft_flux`.
- `--network_dim` is the number of OFT blocks. Unlike LoRA rank, the smaller the dim, the larger the model. We recommend about 64 or 128. Please make the output dimension of the target layer of OFT divisible by the value of `--network_dim` (an error will occur if it is not divisible). Valid values are 64, 128, 256, 512, 1024, etc.
- `--network_alpha` is treated as a constraint for OFT. We recommend about 1e-2 to 1e-4. The default value when omitted is 1, which is too large, so be sure to specify it.
- CLIP/T5XXL is not supported. Specify `--network_train_unet_only`.
- `--network_args` specifies the hyperparameters of OFT. The following are valid:
- Specify `enable_all_linear=True` to target all linear connections in the MLP layer. The default is False, which targets only attention.
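The divisibility rule and the "smaller dim, larger model" behaviour can be checked with a few lines. This is a sketch under stated assumptions: each OFT block is taken to be a square matrix of size `out_features / network_dim`, and the width 3072 is only an example value, not something this README states.

```python
def oft_block_size(out_features: int, network_dim: int) -> int:
    """Return the size of each OFT block, or raise if not divisible."""
    if out_features % network_dim != 0:
        raise ValueError(
            f"out_features={out_features} is not divisible by "
            f"network_dim={network_dim}"
        )
    return out_features // network_dim


def oft_param_count(out_features: int, network_dim: int) -> int:
    """Parameter count for one layer, assuming square blocks."""
    size = oft_block_size(out_features, network_dim)
    return network_dim * size * size


EXAMPLE_WIDTH = 3072  # example output width, assumed for illustration
for dim in (64, 128):
    print(dim, oft_block_size(EXAMPLE_WIDTH, dim),
          oft_param_count(EXAMPLE_WIDTH, dim))
```

With these assumptions, halving `--network_dim` doubles the block size and doubles the parameter count per layer.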
Currently, no other inference environment supports FLUX.1 OFT. Inference is only possible with `flux_minimal_inference.py` (specify OFT model with `--lora`).
Sample command is below. It will work with 24GB VRAM GPUs with the batch size of 1.
```
--network_module networks.oft_flux --network_dim 128 --network_alpha 1e-3
--network_args "enable_all_linear=True" --learning_rate 1e-5
```
The training can be done with 16GB VRAM GPUs without `enable_all_linear=True` and with the Adafactor optimizer.
### Inference for FLUX.1 with LoRA model
The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.
```
python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.safetensors --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora "lora-flux-name.safetensors;1.0"
```
### FLUX.1 fine-tuning
The memory-efficient training with block swap is based on 2kpr's implementation. Thanks to 2kpr!
__`--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. These options are still available, but they will be removed in the future. Please use `--blocks_to_swap` instead. These options are equivalent to specifying `double_blocks_to_swap + single_blocks_to_swap // 2` in `--blocks_to_swap`.__
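As a worked example of the conversion above (read as a Python expression, `//` binds tighter than `+`, so only the single-block count is halved), consider this sketch. `to_blocks_to_swap` is a hypothetical helper name.

```python
def to_blocks_to_swap(double_blocks_to_swap: int,
                      single_blocks_to_swap: int) -> int:
    """Convert the deprecated pair of options to a --blocks_to_swap value."""
    # Integer division applies to the single-block count only.
    return double_blocks_to_swap + single_blocks_to_swap // 2


# 6 double blocks and 4 single blocks correspond to --blocks_to_swap 8.
print(to_blocks_to_swap(6, 4))
```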
Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GPUs, and 64GB main memory is recommended.
```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py
--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.safetensors
--save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2
--seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
--dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name
--learning_rate 5e-5 --max_train_epochs 4 --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"
--lr_scheduler constant_with_warmup --max_grad_norm 0.0
--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0
--fused_backward_pass --blocks_to_swap 8 --full_bf16
```
(The command is multi-line for readability. Please combine it into one line.)
Options are almost the same as LoRA training. The differences are `--full_bf16`, `--fused_backward_pass` and `--blocks_to_swap`. `--cpu_offload_checkpointing` is also available.
`--full_bf16` enables the training with bf16 (weights and gradients).
`--fused_backward_pass` enables the fusing of the optimizer step into the backward pass for each parameter. This reduces the memory usage during training. Only Adafactor optimizer is supported for now. Stochastic rounding is also enabled when `--fused_backward_pass` and `--full_bf16` are specified.
`--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency and stochastic rounding. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now.
`--blocks_to_swap` is the number of blocks to swap. The default is None (no swap). The maximum value is 35.
`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. This option cannot be used with `--blocks_to_swap`.
All these options are experimental and may change in the future.
Increasing the number of blocks to swap reduces memory usage, but training will be slower. `--cpu_offload_checkpointing` also slows down the training.
Swapping 8 blocks without CPU offload checkpointing may be a good starting point for 24GB VRAM GPUs. Please try different settings according to VRAM usage and training speed.
The learning rate and the number of epochs are not optimized yet. Please adjust them according to the training results.
#### How to use block swap
There are two possible ways to use block swap. It is unknown which is better.
1. Swap the minimum number of blocks that fit in VRAM with batch size 1 and shorten the time per training step.
The above command example is for this usage.
2. Swap many blocks to increase the batch size and shorten the training time per sample.
For example, swapping 35 blocks seems to increase the batch size to about 5. In this case, the training time per sample will be relatively shorter than with option 1.
#### Training with <24GB VRAM GPUs
Swapping 28 blocks without CPU offload checkpointing may work with 12GB VRAM GPUs. Please try different settings according to the VRAM size of your GPU.
T5XXL requires about 10GB of VRAM, so 10GB of VRAM is the minimum requirement for FLUX.1 fine-tuning.
#### Key Features for FLUX.1 fine-tuning
1. Technical details of block swap:
- Reduce memory usage by transferring double and single blocks of FLUX.1 from GPU to CPU when they are not needed.
- During forward pass, the weights of the blocks that have finished calculation are transferred to CPU, and the weights of the blocks to be calculated are transferred to GPU.
- The same is true for the backward pass, but the order is reversed. The gradients remain on the GPU.
- Since the transfer between CPU and GPU takes time, the training will be slower.
- `--blocks_to_swap` specifies the number of blocks to swap.
- About 640MB of memory can be saved per block.
- (Update 1: Nov 12, 2024)
- The maximum number of blocks that can be swapped is 35.
- We are exchanging only the data of the weights (weight.data) in reference to the implementation of OneTrainer (thanks to OneTrainer). However, the mechanism of the exchange is a custom implementation.
- Since it takes time to free CUDA memory (torch.cuda.empty_cache()), we reuse the CUDA memory allocated to weight.data as it is and exchange the weights between modules.
- This shortens the time it takes to exchange weights between modules.
- Since the weights must be almost identical to be exchanged, FLUX.1 exchanges the weights between double blocks and single blocks.
- In SD3, all blocks are similar, but some weights are different, so there are weights that always remain on the GPU.
2. Sample Image Generation:
- Sample image generation during training is now supported.
- The prompts are cached and used for generation if `--cache_latents` is specified. So changing the prompts during training will not affect the generated images.
- Specify options such as `--sample_prompts` and `--sample_every_n_epochs`.
- Note: It will be very slow when `--blocks_to_swap` is specified.
3. Experimental Memory-Efficient Saving:
- `--mem_eff_save` option can further reduce memory consumption during model saving (about 22GB).
- This is a custom implementation and may cause unexpected issues. Use with caution.
4. T5XXL Token Length Control:
- Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL.
- Default is 512 in dev and 256 in schnell models.
5. Multi-GPU Training Support:
- Note: `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training.
6. Disable mmap Load for Safetensors:
- `--disable_mmap_load_safetensors` option now works in `flux_train.py`.
- Speeds up model loading during training in WSL2.
- Effective in reducing memory usage when loading models during multi-GPU training.
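The block swap figures in item 1 above (about 640MB saved per block, at most 35 blocks) give a rough estimate of the VRAM saved for a given `--blocks_to_swap` value. The following is a back-of-the-envelope sketch, assuming the saving scales linearly and taking 1GB as 1024MB. `estimated_vram_saving_gb` is a hypothetical helper.

```python
MB_PER_BLOCK = 640        # approximate saving per swapped block
MAX_BLOCKS_TO_SWAP = 35   # maximum stated for FLUX.1


def estimated_vram_saving_gb(blocks_to_swap: int) -> float:
    """Estimate the VRAM saved by swapping the given number of blocks."""
    if not 0 <= blocks_to_swap <= MAX_BLOCKS_TO_SWAP:
        raise ValueError(f"blocks_to_swap must be 0-{MAX_BLOCKS_TO_SWAP}")
    return blocks_to_swap * MB_PER_BLOCK / 1024


for n in (8, 16, 28, 35):
    print(n, "blocks ->", estimated_vram_saving_gb(n), "GB")
```

Real savings will differ, since activations, optimizer state, and the text encoders also occupy VRAM.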
### Extract LoRA from FLUX.1 Models
Script: `networks/flux_extract_lora.py`
Extracts LoRA from the difference between two FLUX.1 models.
Offers memory-efficient option with `--mem_eff_safe_open`.
CLIP-L LoRA is not supported.
### Convert FLUX LoRA
Script: `convert_flux_lora.py`
Converts LoRA between sd-scripts format (BFL-based) and AI-toolkit format (Diffusers-based).
If you use LoRA in the inference environment, converting it to AI-toolkit format may reduce temporary memory usage.
Note that re-conversion will increase the size of LoRA.
CLIP-L/T5XXL LoRA is not supported.
### Merge LoRA to FLUX.1 checkpoint
`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint, CLIP-L or T5XXL models. __The script is experimental.__
```
python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu
```
You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`.
CLIP-L and T5XXL LoRA are supported. `--clip_l` and `--clip_l_save_to` are for CLIP-L, `--t5xxl` and `--t5xxl_save_to` are for T5XXL. Sample command is below.
```
--clip_l clip_l.safetensors --clip_l_save_to merged_clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --t5xxl_save_to merged_t5xxl.safetensors
```
FLUX.1, CLIP-L, and T5XXL can be merged together or separately for memory efficiency.
An experimental option `--mem_eff_load_save` is available. This option is for memory-efficient loading and saving. It may also speed up loading and saving.
`--loading_device` is the device to load the LoRA models. `--working_device` is the device to merge (calculate) the models. Default is `cpu` for both. Loading / working device examples are below (in the case of `--save_precision fp16` or `--save_precision bf16`, `float32` will consume more memory):
- 'cpu' / 'cpu': Uses >50GB of RAM, but works on any machine.
- 'cuda' / 'cpu': Uses 24GB of VRAM, but requires 30GB of RAM.
- 'cpu' / 'cuda': Uses 4GB of VRAM, but requires 50GB of RAM, faster than 'cpu' / 'cpu' or 'cuda' / 'cpu'.
- 'cuda' / 'cuda': Uses 30GB of VRAM, but requires 30GB of RAM, faster than 'cpu' / 'cpu' or 'cuda' / 'cpu'.
`--save_precision` is the precision used to save the merged model. When the LoRA models were trained with `bf16`, we are not sure which is better for `--save_precision`, `fp16` or `bf16`.
The script can also merge multiple LoRA models. When doing so, specify the `--concat` option so that the merged LoRA model works properly.
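To illustrate what merging at a given ratio means, the sketch below applies the textbook LoRA update `W' = W + ratio * (alpha / rank) * (up @ down)` to one tiny weight matrix in pure Python. This is not the code of `flux_merge_lora.py`; the function names and the `alpha` handling are assumptions made for illustration only.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight, lora_up, lora_down, alpha, ratio):
    """Return weight + ratio * (alpha / rank) * (lora_up @ lora_down)."""
    rank = len(lora_down)  # lora_down has shape (rank, in_features)
    scale = ratio * alpha / rank
    delta = matmul(lora_up, lora_down)  # shape (out_features, in_features)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


# 2x2 weight with a rank-1 LoRA, merged at ratio 2.0 (as in `--ratios 2.0`)
weight = [[1.0, 0.0], [0.0, 1.0]]
lora_up = [[1.0], [2.0]]   # (out_features, rank)
lora_down = [[0.5, 0.25]]  # (rank, in_features)
merged = merge_lora(weight, lora_up, lora_down, alpha=1.0, ratio=2.0)
print(merged)
```

Under this formulation, merging several LoRA models amounts to applying the same update once per model, each with its own ratio.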
### FLUX.1 Multi-resolution training
You can define multiple resolutions in the dataset configuration file.
An example dataset configuration file is shown below. You can define multiple resolutions with different batch sizes. The resolutions are defined in the `[[datasets]]` section. The `[[datasets.subsets]]` section is for the dataset directory. Please specify the same directory for each resolution.
```
[general]
# define common settings here
flip_aug = true
color_aug = false
keep_tokens_separator= "|||"
shuffle_caption = false
caption_tag_dropout_rate = 0
caption_extension = ".txt"
[[datasets]]
# define the first resolution here
batch_size = 2
enable_bucket = true
resolution = [1024, 1024]
[[datasets.subsets]]
image_dir = "path/to/image/dir"
num_repeats = 1
[[datasets]]
# define the second resolution here
batch_size = 3
enable_bucket = true
resolution = [768, 768]
[[datasets.subsets]]
image_dir = "path/to/image/dir"
num_repeats = 1
[[datasets]]
# define the third resolution here
batch_size = 4
enable_bucket = true
resolution = [512, 512]
[[datasets.subsets]]
image_dir = "path/to/image/dir"
num_repeats = 1
```
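Because the three `[[datasets]]` blocks differ only in resolution and batch size, the file can also be generated. The following is a minimal sketch that builds the same structure from a list of (size, batch size) pairs; `build_dataset_config` is a hypothetical helper, not part of sd-scripts, and it only emits keys that appear in the example above.

```python
def build_dataset_config(image_dir, resolutions):
    """Build TOML text with one [[datasets]] block per (size, batch_size) pair."""
    lines = ["[general]", 'caption_extension = ".txt"', ""]
    for size, batch_size in resolutions:
        lines += [
            "[[datasets]]",
            f"batch_size = {batch_size}",
            "enable_bucket = true",
            f"resolution = [{size}, {size}]",
            "",
            "  [[datasets.subsets]]",
            f'  image_dir = "{image_dir}"',  # same directory for every resolution
            "  num_repeats = 1",
            "",
        ]
    return "\n".join(lines)


config_text = build_dataset_config("path/to/image/dir", [(1024, 2), (768, 3), (512, 4)])
print(config_text)
```

Write `config_text` to a `.toml` file and pass that file with `--dataset_config`.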
### Convert Diffusers to FLUX.1
Script: `convert_diffusers_to_flux1.py`
Converts Diffusers models to FLUX.1 models. The script is experimental. See `--help` for options. schnell and dev models are supported. AE/CLIP/T5XXL are not supported. The diffusers folder is a parent folder of `rmer` folder.
```
python tools/convert_diffusers_to_flux.py --diffusers_path path/to/diffusers_folder_or_00001_safetensors --save_to path/to/flux1.safetensors --mem_eff_load_save --save_precision bf16
```
## SD3 training
SD3.5L/M training is now available.
### SD3 LoRA training
The script is `sd3_train_network.py`. See `--help` for options.
SD3 model, CLIP-L, CLIP-G, and T5XXL models are recommended to be in float/fp16 format. If you specify `--fp8_base`, you can use fp8 models for SD3. The fp8 model is only compatible with `float8_e4m3fn` format.
Sample command is below. It will work with 16GB VRAM GPUs (SD3.5L).
```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 sd3_train_network.py
--pretrained_model_name_or_path path/to/sd3.5_large.safetensors --clip_l sd3/clip_l.safetensors --clip_g sd3/clip_g.safetensors --t5xxl sd3/t5xxl_fp16.safetensors
--cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers
--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
--network_module networks.lora_sd3 --network_dim 4 --network_train_unet_only
--optimizer_type adamw8bit --learning_rate 1e-4
--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base
--highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml
--output_dir path/to/output/dir --output_name sd3-lora-name
```
(The command is multi-line for readability. Please combine it into one line.)
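If you keep the command in a text file in the multi-line form, a small helper can collapse it into one line. This is only a convenience sketch using the standard `shlex` module; `to_single_line` is a hypothetical name and the command is shortened here for brevity.

```python
import shlex

multiline_command = """
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 sd3_train_network.py
--network_module networks.lora_sd3 --network_dim 4 --network_train_unet_only
--output_dir path/to/output/dir --output_name sd3-lora-name
"""


def to_single_line(command_text):
    """Split on any whitespace (including newlines) and re-join with shell quoting."""
    return shlex.join(shlex.split(command_text))


one_line = to_single_line(multiline_command)
print(one_line)
```

On a POSIX shell, ending each line with a backslash is an alternative that keeps the command readable without joining it.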
Like FLUX.1 training, the `--blocks_to_swap` option for memory reduction is available. The maximum number of blocks that can be swapped is 36 for SD3.5L and 22 for SD3.5M.
Adafactor optimizer is also available.
`--cpu_offload_checkpointing` option is not available.
We are also not sure how many epochs are needed for convergence, or how the learning rate should be adjusted.
The trained LoRA model can be used with ComfyUI.
#### Key Options for SD3 LoRA training
Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome.
- `--network_module` is the module for LoRA training. Specify `networks.lora_sd3` for SD3 LoRA training.
- `--pretrained_model_name_or_path` is the path to the pretrained model (SD3/3.5). If you specify `--fp8_base`, you can use fp8 models for SD3/3.5. The fp8 model is only compatible with `float8_e4m3fn` format.
- `--clip_l` is the path to the CLIP-L model.
- `--clip_g` is the path to the CLIP-G model.
- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (float8_e4m3fn) models for T5XXL. However, it is recommended to use fp16 models for caching.
- `--vae` is the path to the autoencoder model. __This option is not necessary for SD3.__ VAE is included in the standard SD3 model.
- `--disable_mmap_load_safetensors` is to disable memory mapping when loading safetensors. __This option significantly reduces the memory usage when loading models for Windows users.__
- `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are the dropout rates for the embeddings of CLIP-L, CLIP-G, and T5XXL, described in the [SAI research paper](http://arxiv.org/pdf/2403.03206). The default is 0.0. For LoRA training, it seems better to leave them at 0.0.
- `--pos_emb_random_crop_rate` is the rate of random cropping of positional embeddings, described in the [SD3.5M model card](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium). The default is 0. It seems better to leave it at 0.0 for LoRA training.
- `--enable_scaled_pos_embed` is to enable the scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M. Details are described below.
- `--training_shift` is the shift value for the training distribution of timesteps. The default is 1.0 (uniform distribution, no shift). If less than 1.0, the side closer to the image is more sampled, and if more than 1.0, the side closer to noise is more sampled.
Other options are described below.
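To make the effect of `--training_shift` concrete, the sketch below uses one common form of timestep shift, `shift * u / (1 + (shift - 1) * u)`, where `u` is a uniform sample in [0, 1] with 0 on the image side and 1 on the noise side. Whether the training script uses exactly this expression is an assumption; the function name is hypothetical.

```python
def shift_timestep(u, shift):
    """Map u in [0, 1] with the assumed shift formula; shift=1.0 leaves u unchanged."""
    return (shift * u) / (1.0 + (shift - 1.0) * u)


# Show how evenly spaced samples move for a few shift values.
for shift in (0.5, 1.0, 3.0):
    shifted = [round(shift_timestep(i / 4, shift), 3) for i in range(5)]
    print(shift, shifted)
```

With this formula, a shift above 1.0 pushes samples toward 1 (noise) and a shift below 1.0 pushes them toward 0 (image), which matches the behavior described for the option.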
#### Key Features for SD3 LoRA training
1. CLIP-L, G and T5XXL LoRA Support:
- SD3 LoRA training now supports CLIP-L, CLIP-G and T5XXL LoRA training.
- Remove `--network_train_unet_only` from your command.
- Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L and G are also trained at the same time.
- T5XXL output can be cached for CLIP-L and G LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- The learning rates for CLIP-L, CLIP-G and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5 5e-6`. The first value is the learning rate for CLIP-L, the second value is for CLIP-G, and the third value is for T5XXL. If you specify only one, the learning rates for CLIP-L, CLIP-G and T5XXL will be the same. If the third value is not specified, the second value is used for T5XXL. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for CLIP-L, CLIP-G and T5XXL.
- The trained LoRA can be used with ComfyUI.
| trained LoRA|option|network_args|cache_text_encoder_outputs (*1)|
|---|---|---|---|
|MMDiT|`--network_train_unet_only`|-|o|
|MMDiT + CLIP-L + CLIP-G|-|-|o (*2)|
|MMDiT + CLIP-L + CLIP-G + T5XXL|-|`train_t5xxl=True`|-|
|CLIP-L + CLIP-G (*3)|`--network_train_text_encoder_only`|-|o (*2)|
|CLIP-L + CLIP-G + T5XXL (*3)|`--network_train_text_encoder_only`|`train_t5xxl=True`|-|
- *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- *2: T5XXL output can be cached for CLIP-L and G LoRA training.
- *3: Not tested yet.
2. Experimental FP8/FP16 mixed training:
- `--fp8_base_unet` enables training with fp8 for MMDiT and bf16/fp16 for CLIP-L/G/T5XXL.
- When specifying this option, the `--fp8_base` option is automatically enabled.
3. Split Q/K/V Projection Layers (Experimental):
- Same as FLUX.1.
4. CLIP-L/G and T5 Attention Mask Application:
- This function is planned to be implemented in the future.
5. Multi-resolution Training Support:
- Only for SD3.5M.
- Same as FLUX.1 for data preparation.
- If you train with multiple resolutions, you can enable the scaled positional embeddings with `--enable_scaled_pos_embed`. The default is False. __This option is an experimental feature.__
6. Weighting scheme and training shift:
- The weighting scheme is described in section 3.1 of the [SD3 paper](https://arxiv.org/abs/2403.03206v1).
- The uniform distribution is the default. If you want to change the distribution, see `--help` for options.
- `--training_shift` is the shift value for the training distribution of timesteps.
- The effect of a shift in uniform distribution is shown in the figure below.
- ![Figure_1](https://github.com/user-attachments/assets/99a72c67-adfb-4440-81d4-a718985ff350)
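The fallback rules for `--text_encoder_lr` listed under item 1 can be summarized in a few lines. This is a sketch of the rules as documented, not the script's own code; `resolve_text_encoder_lrs` is a hypothetical name, and applying the `--learning_rate` default to CLIP-G as well is an assumption.

```python
def resolve_text_encoder_lrs(text_encoder_lr, learning_rate):
    """Return (clip_l_lr, clip_g_lr, t5xxl_lr) following the documented fallback rules."""
    if not text_encoder_lr:
        # Option not given: fall back to --learning_rate (assumed for all three encoders).
        return (learning_rate, learning_rate, learning_rate)
    if len(text_encoder_lr) == 1:
        # One value: shared by all three encoders.
        return (text_encoder_lr[0],) * 3
    if len(text_encoder_lr) == 2:
        # Third value missing: T5XXL reuses the second (CLIP-G) value.
        clip_l, clip_g = text_encoder_lr
        return (clip_l, clip_g, clip_g)
    return tuple(text_encoder_lr[:3])


print(resolve_text_encoder_lrs([1e-4, 1e-5, 5e-6], learning_rate=1e-4))
print(resolve_text_encoder_lrs([1e-4, 1e-5], learning_rate=1e-4))
print(resolve_text_encoder_lrs(None, learning_rate=2e-4))
```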
Technical details of multi-resolution training for SD3.5M:
SD3.5M does not use scaled positional embeddings for multi-resolution training, and is trained with a single positional embedding. Therefore, this feature is very experimental.
Generally, in multi-resolution training, the values of the positional embeddings must be the same for each resolution. That is, the same value must be in the same position for 512x512, 768x768, and 1024x1024. To achieve this, the positional embeddings for each resolution are calculated in advance and switched according to the resolution of the training data. This feature is enabled by `--enable_scaled_pos_embed`.
This idea and the code for calculating scaled positional embeddings are contributed by KohakuBlueleaf. Thanks to KohakuBlueleaf!
#### Specify rank for each layer in SD3 LoRA
You can specify the rank for each layer in SD3 by specifying the following network_args. If you specify `0`, LoRA will not be applied to that layer.
When network_args is not specified, the default value (`network_dim`) is applied, same as before.
|network_args|target layer|
|---|---|
|context_attn_dim|attn in context_block|
|context_mlp_dim|mlp in context_block|
|context_mod_dim|adaLN_modulation in context_block|
|x_attn_dim|attn in x_block|
|x_mlp_dim|mlp in x_block|
|x_mod_dim|adaLN_modulation in x_block|
`"verbose=True"` is also available for debugging. It shows the rank of each layer.
example:
```
--network_args "context_attn_dim=2" "context_mlp_dim=3" "context_mod_dim=4" "x_attn_dim=5" "x_mlp_dim=6" "x_mod_dim=7" "verbose=True"
```
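The way these arguments combine with `--network_dim` can be sketched as follows: each `key=value` pair overrides the rank for one layer type, unspecified layer types keep the default, and `0` disables LoRA for that type. The helper `resolve_layer_ranks` is hypothetical and only mirrors the rules described above; it is not the network module's own parser.

```python
LAYER_KEYS = [
    "context_attn_dim",
    "context_mlp_dim",
    "context_mod_dim",
    "x_attn_dim",
    "x_mlp_dim",
    "x_mod_dim",
]


def resolve_layer_ranks(network_args, network_dim):
    """Parse "key=value" strings; unspecified layers use network_dim, 0 disables LoRA."""
    parsed = dict(arg.split("=", 1) for arg in network_args)
    ranks = {}
    for key in LAYER_KEYS:
        rank = int(parsed.get(key, network_dim))
        if rank > 0:  # rank 0 means LoRA is not applied to this layer type
            ranks[key] = rank
    return ranks


ranks = resolve_layer_ranks(["context_attn_dim=2", "x_mlp_dim=0", "verbose=True"], network_dim=4)
print(ranks)
```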
You can apply LoRA to the conditioning layers of SD3 by specifying `emb_dims` in network_args. When specifying, be sure to specify 6 numbers in `[]` as a comma-separated list.
example:
```
--network_args "emb_dims=[2,3,4,5,6,7]"
```
Each number corresponds to `context_embedder`, `t_embedder`, `x_embedder`, `y_embedder`, `final_layer_adaLN_modulation`, `final_layer_linear`. The above example applies LoRA to all conditioning layers, with rank 2 for `context_embedder`, 3 for `t_embedder`, 4 for `x_embedder`, 5 for `y_embedder`, 6 for `final_layer_adaLN_modulation`, and 7 for `final_layer_linear`.
If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,4,0,0]` applies LoRA only to `context_embedder` and `y_embedder`.
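The mapping from the six numbers to layer names can be sketched like this. `parse_emb_dims` is a hypothetical helper written for illustration; the layer order is the one listed above.

```python
EMB_LAYERS = [
    "context_embedder",
    "t_embedder",
    "x_embedder",
    "y_embedder",
    "final_layer_adaLN_modulation",
    "final_layer_linear",
]


def parse_emb_dims(value):
    """Parse a string like "[4,0,0,4,0,0]" into {layer_name: rank}, skipping rank 0."""
    numbers = [int(n) for n in value.strip().strip("[]").split(",")]
    if len(numbers) != len(EMB_LAYERS):
        raise ValueError(f"emb_dims needs {len(EMB_LAYERS)} numbers, got {len(numbers)}")
    return {name: rank for name, rank in zip(EMB_LAYERS, numbers) if rank > 0}


print(parse_emb_dims("[2,3,4,5,6,7]"))
print(parse_emb_dims("[4,0,0,4,0,0]"))
```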
#### Specify blocks to train in SD3 LoRA training
You can specify the blocks to train in SD3 LoRA training by specifying `train_block_indices` in network_args. The indices are 0-based. The default (when omitted) is to train all blocks. The indices are specified as a list of integers or a range of integers, like `0,1,5,8` or `0,1,4-5,7`.
The number of blocks depends on the model. The valid range is 0 to (the number of blocks - 1). You can also specify `all` to train all blocks, or `none` to train no blocks.
example:
```
--network_args "train_block_indices=1,2,6-8"
```
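The index syntax can be expanded as in the sketch below. It assumes ranges such as `6-8` are inclusive, which the example `4-5` suggests; `parse_block_indices` and the block count of 24 are hypothetical and used only for illustration.

```python
def parse_block_indices(spec, num_blocks):
    """Expand a spec like "1,2,6-8", "all" or "none" into a sorted list of block indices."""
    if spec == "all":
        return list(range(num_blocks))
    if spec == "none":
        return []
    indices = set()
    for part in spec.split(","):
        if "-" in part:
            start, end = (int(p) for p in part.split("-"))
            indices.update(range(start, end + 1))  # assumed inclusive range
        else:
            indices.add(int(part))
    out_of_range = sorted(i for i in indices if not 0 <= i < num_blocks)
    if out_of_range:
        raise ValueError(f"block indices out of range 0-{num_blocks - 1}: {out_of_range}")
    return sorted(indices)


print(parse_block_indices("1,2,6-8", num_blocks=24))
```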
### Inference for SD3 with LoRA model
The inference script is also available. The script is `sd3_minimal_inference.py`. See `--help` for options.
### SD3 fine-tuning
Documentation is not available yet. Please refer to the FLUX.1 fine-tuning guide for now. The major differences are as follows:
- `--clip_g` is also available for SD3 fine-tuning.
- `--timestep_sampling`, `--discrete_flow_shift`, `--model_prediction_type` and `--guidance_scale` are not necessary for SD3 fine-tuning.
- Use `--vae` instead of `--ae` if necessary. __This option is not necessary for SD3.__ VAE is included in the standard SD3 model.
- `--disable_mmap_load_safetensors` is available. __This option significantly reduces the memory usage when loading models for Windows users.__
- `--cpu_offload_checkpointing` is not available for SD3 fine-tuning.
- `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are available, the same as in LoRA training.
- `--pos_emb_random_crop_rate` and `--enable_scaled_pos_embed` are available for SD3.5M fine-tuning.
- Training text encoders is available with `--train_text_encoder` option, similar to SDXL training.
- CLIP-L and G can be trained with `--train_text_encoder` option. Training T5XXL needs `--train_t5xxl` option.
- If you want to use cached text encoder outputs for T5XXL while training CLIP-L and G, specify `--use_t5xxl_cache_only`. With this option, cached text encoder outputs are used for T5XXL only.
- The learning rates for CLIP-L, CLIP-G and T5XXL can be specified separately. `--text_encoder_lr1`, `--text_encoder_lr2` and `--text_encoder_lr3` are available.
### Extract LoRA from SD3 Models
Not available yet.
### Convert SD3 LoRA
Not available yet.
### Merge LoRA to SD3 checkpoint
Not available yet.
---
[__Change History__](#change-history) is moved to the bottom of the page.
更新履歴は[ページ末尾](#change-history)に移しました。
Latest update: 2025-03-21 (Version 0.9.1)
[The Japanese README is here](./README-ja.md)
The development version is in the `dev` branch. Please check the dev branch for the latest changes.
@@ -882,6 +148,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History
### Mar 21, 2025 / 2025-03-21 Version 0.9.1
- Fixed a bug where some of LoRA modules for CLIP Text Encoder were not trained. Thank you Nekotekina for PR [#1964](https://github.com/kohya-ss/sd-scripts/pull/1964)
- The LoRA modules for CLIP Text Encoder are now 264 modules, which is the same as before. Only 88 modules were trained in the previous version.
### Jan 17, 2025 / 2025-01-17 Version 0.9.0
- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
@@ -942,7 +213,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
- Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr!
- The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower.
- Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only Adafactor is supported. Gradient accumulation is not available.
- Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available.
- Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`.
- Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size.
- PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`.
@@ -952,7 +223,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
- Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer.
- Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect is limited to a certain number, it is recommended to specify 4-10.
- Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available.
- `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using Adafactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required.
- `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required.
- Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side.
- LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO!
@@ -1011,7 +282,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!
- SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。
- optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。
- `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は Adafactor のみ対応しています。また gradient accumulation は使えません。
- `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。
- mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。
- バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。
- PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。

View File

@@ -185,7 +185,7 @@ for img_file in img_files:
### Creating a dataset configuration file
You can use the command line argument `--conditioning_data_dir` of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`.
You can use the command line arguments of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`.
```toml
[general]

View File

@@ -10,7 +10,7 @@ import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils, strategy_base
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
@@ -39,7 +39,6 @@ from library.custom_train_functions import (
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
import library.strategy_sd as strategy_sd
def train(args):
@@ -53,15 +52,7 @@ def train(args):
if args.seed is not None:
set_seed(args.seed) # initialize the random number sequence
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
if cache_latents:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
tokenizer = train_util.load_tokenizer(args)
# prepare the dataset
if args.dataset_class is None:
@@ -90,11 +81,10 @@ def train(args):
]
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -177,9 +167,8 @@ def train(args):
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -205,9 +194,6 @@ def train(args):
else:
text_encoder.eval()
text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
@@ -230,11 +216,7 @@ def train(args):
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# prepare the dataloader
# note: persistent_workers cannot be used when the number of DataLoader processes is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
@@ -337,12 +319,7 @@ def train(args):
)
# For --sample_at_first
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -367,21 +344,25 @@ def train(args):
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
if args.weighted_captions:
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, [text_encoder], input_ids_list, weights_list
)[0]
encoder_hidden_states = get_weighted_text_embeddings(
tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
input_ids = batch["input_ids_list"][0].to(accelerator.device)
encoder_hidden_states = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder], [input_ids]
)[0]
if args.full_fp16:
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
# Predict the noise residual
with accelerator.autocast():
@@ -393,10 +374,11 @@ def train(args):
else:
target = noise
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
@@ -408,7 +390,9 @@ def train(args):
loss = loss.mean() # mean over batch dimension
else:
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
)
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -427,7 +411,7 @@ def train(args):
global_step += 1
train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
)
# save the model every specified number of steps
@@ -452,7 +436,7 @@ def train(args):
)
current_loss = loss.detach().item() # this is a mean, so the batch size should not matter
if len(accelerator.trackers) > 0:
if args.logging_dir is not None:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
@@ -465,7 +449,7 @@ def train(args):
if global_step >= args.max_train_steps:
break
if len(accelerator.trackers) > 0:
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
@@ -490,9 +474,7 @@ def train(args):
vae,
)
train_util.sample_images(
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
is_main_process = accelerator.is_main_process
if is_main_process:

View File

@@ -1,232 +0,0 @@
# add caption to images by Florence-2
import argparse
import json
import os
import glob
from pathlib import Path
from typing import Any, Optional
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForCausalLM
from library import device_utils, train_util, dataset_metadata_utils
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
import tagger_utils
TASK_PROMPT = "<MORE_DETAILED_CAPTION>"
def main(args):
assert args.load_archive == (
args.metadata is not None
), "load_archive must be used with metadata / load_archiveはmetadataと一緒に使う必要があります"
device = args.device if args.device is not None else device_utils.get_preferred_device()
if type(device) is str:
device = torch.device(device)
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
logger.info(f"device: {device}, dtype: {torch_dtype}")
logger.info("Loading Florence-2-large model / Florence-2-largeモデルをロード中")
support_flash_attn = False
try:
import flash_attn
support_flash_attn = True
except ImportError:
pass
if support_flash_attn:
model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True
).to(device)
else:
logger.info(
"flash_attn is not available. Trying to load without it / flash_attnが利用できません。flash_attnを使わずにロードを試みます"
)
# https://github.com/huggingface/transformers/issues/31793#issuecomment-2295797330
# Removing the unnecessary flash_attn import which causes issues on CPU or MPS backends
from transformers.dynamic_module_utils import get_imports
from unittest.mock import patch
def fixed_get_imports(filename) -> list[str]:
if not str(filename).endswith("modeling_florence2.py"):
return get_imports(filename)
imports = get_imports(filename)
imports.remove("flash_attn")
return imports
# workaround for unnecessary flash_attn requirement
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True
).to(device)
model.eval()
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
# 画像を読み込む
if not args.load_archive:
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
else:
archive_files = glob.glob(os.path.join(args.train_data_dir, "*.zip")) + glob.glob(
os.path.join(args.train_data_dir, "*.tar")
)
image_paths = [Path(archive_file) for archive_file in archive_files]
# load metadata if needed
if args.metadata is not None:
metadata = dataset_metadata_utils.load_metadata(args.metadata, create_new=True)
images_metadata = metadata["images"]
else:
images_metadata = metadata = None
# define preprocess_image function
def preprocess_image(image: Image.Image):
inputs = processor(text=TASK_PROMPT, images=image, return_tensors="pt").to(device, torch_dtype)
return inputs
# prepare DataLoader or something similar :)
# Loader returns: list of (image_path, processed_image_or_something, image_size)
if args.load_archive:
loader = tagger_utils.ArchiveImageLoader([str(p) for p in image_paths], args.batch_size, preprocess_image, args.debug)
else:
# we cannot use DataLoader with ImageLoadingPrepDataset because processor is not pickleable
loader = tagger_utils.ImageLoader(image_paths, args.batch_size, preprocess_image, args.debug)
def run_batch(
list_of_path_inputs_size: list[tuple[str, dict[str, torch.Tensor], tuple[int, int]]],
images_metadata: Optional[dict[str, Any]],
caption_index: Optional[int] = None,
):
input_ids = torch.cat([inputs["input_ids"] for _, inputs, _ in list_of_path_inputs_size])
pixel_values = torch.cat([inputs["pixel_values"] for _, inputs, _ in list_of_path_inputs_size])
if args.debug:
logger.info(f"input_ids: {input_ids.shape}, pixel_values: {pixel_values.shape}")
with torch.no_grad():
generated_ids = model.generate(
input_ids=input_ids,
pixel_values=pixel_values,
max_new_tokens=args.max_new_tokens,
num_beams=args.num_beams,
)
if args.debug:
logger.info(f"generate done: {generated_ids.shape}")
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=False)
if args.debug:
logger.info(f"decode done: {len(generated_texts)}")
for generated_text, (image_path, _, image_size) in zip(generated_texts, list_of_path_inputs_size):
parsed_answer = processor.post_process_generation(generated_text, task=TASK_PROMPT, image_size=image_size)
caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"]
caption_text = caption_text.strip().replace("<pad>", "")
original_caption_text = caption_text
if args.remove_mood:
p = caption_text.find("The overall ")
if p != -1:
caption_text = caption_text[:p].strip()
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
if images_metadata is None:
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(caption_text + "\n")
else:
image_md = images_metadata.get(image_path, None)
if image_md is None:
image_md = {"image_size": list(image_size)}
images_metadata[image_path] = image_md
if "caption" not in image_md:
image_md["caption"] = []
if caption_index is None:
image_md["caption"].append(caption_text)
else:
while len(image_md["caption"]) <= caption_index:
image_md["caption"].append("")
image_md["caption"][caption_index] = caption_text
if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tCaption: {caption_text}")
if args.remove_mood and original_caption_text != caption_text:
logger.info(f"\tCaption (prior to removing mood): {original_caption_text}")
for data_entry in tqdm(loader, smoothing=0.0):
b_imgs = data_entry
b_imgs = [(str(image_path), image, size) for image_path, image, size in b_imgs] # Convert image_path to string
run_batch(b_imgs, images_metadata, args.caption_index)
if args.metadata is not None:
logger.info(f"saving metadata file: {args.metadata}")
with open(args.metadata, "wt", encoding="utf-8") as f:
json.dump(metadata, f, ensure_ascii=False, indent=2)
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子"
)
parser.add_argument("--recursive", action="store_true", help="search images recursively / 画像を再帰的に検索する")
parser.add_argument(
"--remove_mood", action="store_true", help="remove mood from the caption / キャプションからムードを削除する"
)
parser.add_argument(
"--max_new_tokens",
type=int,
default=1024,
help="maximum number of tokens to generate. default is 1024 / 生成するトークンの最大数。デフォルトは1024",
)
parser.add_argument(
"--num_beams",
type=int,
default=3,
help="number of beams for beam search. default is 3 / ビームサーチのビーム数。デフォルトは3",
)
parser.add_argument(
"--device",
type=str,
default=None,
help="device for model. default is None, which means using an appropriate device / モデルのデバイス。デフォルトはNoneで、適切なデバイスを使用する",
)
parser.add_argument(
"--caption_index",
type=int,
default=None,
help="index of the caption in the metadata file. default is None, which means adding caption to the existing captions. specify 0 or greater to replace the caption"
" / メタデータファイル内のキャプションのインデックス。デフォルトはNoneで、新しく追加する。0以上でキャプションを置き換える",
)
parser.add_argument("--debug", action="store_true", help="debug mode")
tagger_utils.add_archive_arguments(parser)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
main(args)
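The metadata branch of `run_batch` above either appends a caption or overwrites a fixed slot, padding the list with empty strings when the slot does not exist yet. A minimal standalone sketch of that rule follows; the helper name `set_caption` and the sample `image_md` values are hypothetical and are not part of the script.

```python
from typing import Optional


def set_caption(captions: list[str], text: str, index: Optional[int] = None) -> list[str]:
    """Append text when index is None; otherwise overwrite the slot at index."""
    if index is None:
        captions.append(text)
        return captions
    # Pad with empty strings so that captions[index] exists before assigning.
    while len(captions) <= index:
        captions.append("")
    captions[index] = text
    return captions


# Hypothetical metadata entry, shaped like the dict built in run_batch.
image_md = {"image_size": [512, 512], "caption": ["first"]}
set_caption(image_md["caption"], "second")     # index None: appended
set_caption(image_md["caption"], "fourth", 3)  # padded, then written at index 3
print(image_md["caption"])  # ['first', 'second', '', 'fourth']
```

Padding keeps the positions stable, so an index always refers to the same slot across images even when earlier slots were never filled.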

View File

@@ -180,7 +180,7 @@ def main(args):
# add to the batch
image_info = train_util.ImageInfo(image_key, 1, "", False, image_path)
image_info.latents_cache_path = npz_file_name
image_info.latents_npz = npz_file_name
image_info.bucket_reso = reso
image_info.resized_size = resized_size
image_info.image = image

View File

@@ -1,10 +1,7 @@
import argparse
import csv
import glob
import json
import os
from pathlib import Path
from typing import Any, Optional
import cv2
import numpy as np
@@ -13,18 +10,14 @@ from huggingface_hub import hf_hub_download
from PIL import Image
from tqdm import tqdm
from library import dataset_metadata_utils
from library.utils import setup_logging
import library.train_util as train_util
from library.utils import setup_logging, pil_resize
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.train_util as train_util
from library.utils import pil_resize
import tagger_utils
# from wd14 tagger
IMAGE_SIZE = 448
@@ -70,14 +63,13 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
try:
image = Image.open(img_path).convert("RGB")
size = image.size
image = preprocess_image(image)
# tensor = torch.tensor(image) # no need to convert this to a Tensor
except Exception as e:
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (image, img_path, size)
return (image, img_path)
def collate_fn_remove_corrupted(batch):
@@ -91,10 +83,6 @@ def collate_fn_remove_corrupted(batch):
def main(args):
assert args.load_archive == (
args.metadata is not None
), "load_archive must be used with metadata / load_archiveはmetadataと一緒に使う必要があります"
# model location is model_dir + repo_id
# repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash
model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_"))
@@ -161,19 +149,15 @@ def main(args):
ort_sess = ort.InferenceSession(
onnx_path,
providers=(["OpenVINOExecutionProvider"]),
provider_options=[{"device_type": "GPU_FP32"}],
provider_options=[{'device_type' : "GPU_FP32"}],
)
else:
ort_sess = ort.InferenceSession(
onnx_path,
providers=(
["CUDAExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers()
else (
["ROCMExecutionProvider"]
if "ROCMExecutionProvider" in ort.get_available_providers()
else ["CPUExecutionProvider"]
)
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
["CPUExecutionProvider"]
),
)
else:
@@ -219,9 +203,7 @@ def main(args):
tag_replacements = escaped_tag_replacements.split(";")
for tag_replacement in tag_replacements:
tags = tag_replacement.split(",") # source, target
assert (
len(tags) == 2
), f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags]
logger.info(f"replacing tag: {source} -> {target}")
@@ -234,15 +216,9 @@ def main(args):
rating_tags[rating_tags.index(source)] = target
# load images
if not args.load_archive:
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
else:
archive_files = glob.glob(os.path.join(args.train_data_dir, "*.zip")) + glob.glob(
os.path.join(args.train_data_dir, "*.tar")
)
image_paths = [Path(archive_file) for archive_file in archive_files]
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
tag_freq = {}
@@ -255,23 +231,19 @@ def main(args):
if args.always_first_tags is not None:
always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""]
def run_batch(
list_of_path_img_size: list[tuple[str, np.ndarray, tuple[int, int]]],
images_metadata: Optional[dict[str, Any]],
tags_index: Optional[int] = None,
):
imgs = np.array([im for _, im, _ in list_of_path_img_size])
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
if args.onnx:
# if len(imgs) < args.batch_size:
# imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
probs = probs[: len(list_of_path_img_size)]
probs = probs[: len(path_imgs)]
else:
probs = model(imgs, training=False)
probs = probs.numpy()
for (image_path, _, image_size), prob in zip(list_of_path_img_size, probs):
for (image_path, _), prob in zip(path_imgs, probs):
combined_tags = []
rating_tag_text = ""
character_tag_text = ""
@@ -293,7 +265,7 @@ def main(args):
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += caption_separator + tag_name
if args.character_tags_first: # insert to the beginning
if args.character_tags_first: # insert to the beginning
combined_tags.insert(0, tag_name)
else:
combined_tags.append(tag_name)
@@ -309,7 +281,7 @@ def main(args):
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
rating_tag_text = found_rating
if args.use_rating_tags:
combined_tags.insert(0, found_rating) # insert to the beginning
combined_tags.insert(0, found_rating) # insert to the beginning
else:
combined_tags.append(found_rating)
@@ -332,24 +304,12 @@ def main(args):
tag_text = caption_separator.join(combined_tags)
if args.append_tags:
existing_content = None
if images_metadata is None:
# Check if file exists
if os.path.exists(caption_file):
with open(caption_file, "rt", encoding="utf-8") as f:
# Read file and remove new lines
existing_content = f.read().strip("\n") # Remove newlines
else:
image_md = images_metadata.get(image_path, None)
if image_md is not None:
tags = image_md.get("tags", None)
if tags is not None:
if tags_index is None and len(tags) > 0:
existing_content = tags[-1]
elif tags_index is not None and tags_index < len(tags):
existing_content = tags[tags_index]
# Check if file exists
if os.path.exists(caption_file):
with open(caption_file, "rt", encoding="utf-8") as f:
# Read file and remove new lines
existing_content = f.read().strip("\n") # Remove newlines
if existing_content is not None:
# Split the content into tags and store them in a list
existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()]
@@ -359,46 +319,19 @@ def main(args):
# Create new tag_text
tag_text = caption_separator.join(existing_tags + new_tags)
if images_metadata is None:
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
else:
image_md = images_metadata.get(image_path, None)
if image_md is None:
image_md = {"image_size": list(image_size)}
images_metadata[image_path] = image_md
if "tags" not in image_md:
image_md["tags"] = []
if tags_index is None:
image_md["tags"].append(tag_text)
else:
while len(image_md["tags"]) <= tags_index:
image_md["tags"].append("")
image_md["tags"][tags_index] = tag_text
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")
if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")
# load metadata if needed
if args.metadata is not None:
metadata = dataset_metadata_utils.load_metadata(args.metadata, create_new=True)
images_metadata = metadata["images"]
else:
images_metadata = metadata = None
# prepare DataLoader or something similar :)
use_loader = False
if args.load_archive:
loader = tagger_utils.ArchiveImageLoader([str(p) for p in image_paths], args.batch_size, preprocess_image, args.debug)
use_loader = True
elif args.max_data_loader_n_workers is not None:
# option to use DataLoader to speed up image loading
# option to use DataLoader to speed up image loading
if args.max_data_loader_n_workers is not None:
dataset = ImageLoadingPrepDataset(image_paths)
loader = torch.utils.data.DataLoader(
data = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=False,
@@ -406,37 +339,35 @@ def main(args):
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
)
use_loader = True
else:
# make batch of image paths
loader = []
for i in range(0, len(image_paths), args.batch_size):
loader.append(image_paths[i : i + args.batch_size])
data = [[(None, ip)] for ip in image_paths]
for data_entry in tqdm(loader, smoothing=0.0):
if use_loader:
b_imgs = data_entry
else:
b_imgs = []
for image_path in data_entry:
b_imgs = []
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
image, image_path = data
if image is None:
try:
image = Image.open(image_path)
if image.mode != "RGB":
image = image.convert("RGB")
size = image.size
image = preprocess_image(image)
except Exception as e:
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image, size))
b_imgs.append((image_path, image))
b_imgs = [(str(image_path), image, size) for image_path, image, size in b_imgs] # Convert image_path to string
run_batch(b_imgs, images_metadata, args.tags_index)
if len(b_imgs) >= args.batch_size:
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
run_batch(b_imgs)
b_imgs.clear()
if args.metadata is not None:
logger.info(f"saving metadata file: {args.metadata}")
with open(args.metadata, "wt", encoding="utf-8") as f:
json.dump(metadata, f, ensure_ascii=False, indent=2)
if len(b_imgs) > 0:
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
run_batch(b_imgs)
if args.frequency_tags:
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
@@ -449,7 +380,9 @@ def main(args):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument(
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
)
parser.add_argument(
"--repo_id",
type=str,
@@ -467,7 +400,9 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
)
parser.add_argument(
"--max_data_loader_n_workers",
type=int,
@@ -506,7 +441,9 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
)
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument(
"--debug", action="store_true", help="debug mode"
)
parser.add_argument(
"--undesired_tags",
type=str,
@@ -516,24 +453,20 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する"
)
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument(
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
)
parser.add_argument(
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
)
parser.add_argument(
"--use_rating_tags",
action="store_true",
help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
)
parser.add_argument(
"--use_rating_tags_as_last_tag",
action="store_true",
help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
"--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
)
parser.add_argument(
"--character_tags_first",
action="store_true",
help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
)
parser.add_argument(
"--always_first_tags",
@@ -562,15 +495,6 @@ def setup_parser() -> argparse.ArgumentParser:
+ " / キャラクタタグの末尾の括弧を別のタグに展開する。`chara_name_(series)` は `chara_name, series` になる",
)
parser.add_argument(
"--tags_index",
type=int,
default=None,
help="index of the tags in the metadata file. default is None, which means adding tags to the existing tags. specify 0 or greater to replace the tags"
" / メタデータファイル内のタグのインデックス。デフォルトはNoneで、既存のタグにタグを追加する。0以上でタグを置き換える",
)
tagger_utils.add_archive_arguments(parser)
return parser
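The `--append_tags` path in this file splits the existing caption on the stripped separator and then joins the existing and new tags. A minimal sketch of that merge is below. The function name `merge_tags` is hypothetical, and the duplicate filter is an assumption, because the lines between the two hunks are not shown in this diff.

```python
def merge_tags(existing_content: str, new_tags: list[str], caption_separator: str = ", ") -> str:
    """Merge newly predicted tags into the tags already stored for an image."""
    # Note: an all-whitespace separator would strip to "" and make split() raise.
    stripped_separator = caption_separator.strip()
    # Split the stored text into tags, dropping surrounding whitespace and empties.
    existing_tags = [t.strip() for t in existing_content.split(stripped_separator) if t.strip()]
    # Assumption: only tags that are not already present are added.
    added = [t for t in new_tags if t not in existing_tags]
    return caption_separator.join(existing_tags + added)


print(merge_tags("1girl, solo", ["solo", "smile"]))  # 1girl, solo, smile
```

Existing tags keep their original order and new tags follow them, which matches the `existing_tags + new_tags` join in the hunk above.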

View File

@@ -1,150 +0,0 @@
import argparse
import json
import math
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Union
import zipfile
import tarfile
from PIL import Image
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import dataset_metadata_utils, train_util
class ArchiveImageLoader:
def __init__(self, archive_paths: list[str], batch_size: int, preprocess: Callable, debug: bool = False):
self.archive_paths = archive_paths
self.batch_size = batch_size
self.preprocess = preprocess
self.debug = debug
self.current_archive = None
self.archive_index = 0
self.image_index = 0
self.files = None
self.executor = ThreadPoolExecutor()
self.image_exts = set(train_util.IMAGE_EXTENSIONS)
def __iter__(self):
return self
def __next__(self):
images = []
while len(images) < self.batch_size:
if self.current_archive is None:
if self.archive_index >= len(self.archive_paths):
if len(images) == 0:
raise StopIteration
else:
break # return the remaining images
if self.debug:
logger.info(f"loading archive: {self.archive_paths[self.archive_index]}")
current_archive_path = self.archive_paths[self.archive_index]
if current_archive_path.endswith(".zip"):
self.current_archive = zipfile.ZipFile(current_archive_path)
self.files = self.current_archive.namelist()
elif current_archive_path.endswith(".tar"):
self.current_archive = tarfile.open(current_archive_path, "r")
self.files = self.current_archive.getnames()
else:
raise ValueError(f"unsupported archive file: {current_archive_path}")
self.image_index = 0
# filter by image extensions
self.files = [file for file in self.files if os.path.splitext(file)[1].lower() in self.image_exts]
if self.debug:
logger.info(f"found {len(self.files)} images in the archive")
new_images = []
while len(images) + len(new_images) < self.batch_size:
if self.image_index >= len(self.files):
break
file = self.files[self.image_index]
archive_and_image_path = (
f"{self.archive_paths[self.archive_index]}{dataset_metadata_utils.ARCHIVE_PATH_SEPARATOR}{file}"
)
self.image_index += 1
def load_image(file, archive: Union[zipfile.ZipFile, tarfile.TarFile]):
with archive.open(file) as f:
image = Image.open(f).convert("RGB")
size = image.size
image = self.preprocess(image)
return image, size
new_images.append((archive_and_image_path, self.executor.submit(load_image, file, self.current_archive)))
# wait for all new_images to load to close the archive
new_images = [(image_path, future.result()) for image_path, future in new_images]
if self.image_index >= len(self.files):
self.current_archive.close()
self.current_archive = None
self.archive_index += 1
images.extend(new_images)
return [(image_path, image, size) for image_path, (image, size) in images]
class ImageLoader:
def __init__(self, image_paths: list[str], batch_size: int, preprocess: Callable, debug: bool = False):
self.image_paths = image_paths
self.batch_size = batch_size
self.preprocess = preprocess
self.debug = debug
self.image_index = 0
self.executor = ThreadPoolExecutor()
def __len__(self):
return math.ceil(len(self.image_paths) / self.batch_size)
def __iter__(self):
return self
def __next__(self):
if self.image_index >= len(self.image_paths):
raise StopIteration
images = []
while len(images) < self.batch_size and self.image_index < len(self.image_paths):
def load_image(file):
image = Image.open(file).convert("RGB")
size = image.size
image = self.preprocess(image)
return image, size
image_path = self.image_paths[self.image_index]
images.append((image_path, self.executor.submit(load_image, image_path)))
self.image_index += 1
images = [(image_path, future.result()) for image_path, future in images]
return [(image_path, image, size) for image_path, (image, size) in images]
def add_archive_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--metadata",
type=str,
default=None,
help="metadata file for the dataset. write tags to this file instead of the caption file / データセットのメタデータファイル。キャプションファイルの代わりにこのファイルにタグを書き込む",
)
parser.add_argument(
"--load_archive",
action="store_true",
help="load archive file such as .zip instead of image files. currently .zip and .tar are supported. must be used with --metadata"
" / 画像ファイルではなく.zipなどのアーカイブファイルを読み込む。現在.zipと.tarをサポート。--metadataと一緒に使う必要があります",
)
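`ImageLoader` above is an iterator that submits one load task per item to a `ThreadPoolExecutor` and waits for the whole batch before returning it. The sketch below shows the same pattern with the image decoding swapped for an arbitrary callable so that it runs without PIL. The class name `BatchLoader` and its `load` parameter are hypothetical.

```python
import math
from concurrent.futures import ThreadPoolExecutor
from typing import Callable


class BatchLoader:
    """Yield batches of (item, loaded) pairs, loading each batch concurrently."""

    def __init__(self, items: list[str], batch_size: int, load: Callable[[str], object]):
        self.items = items
        self.batch_size = batch_size
        self.load = load
        self.index = 0
        self.executor = ThreadPoolExecutor()

    def __len__(self) -> int:
        # Number of batches, which lets tqdm show a total.
        return math.ceil(len(self.items) / self.batch_size)

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= len(self.items):
            raise StopIteration
        chunk = self.items[self.index : self.index + self.batch_size]
        self.index += len(chunk)
        # Submit every load first, then wait, so the loads overlap in time.
        futures = [(item, self.executor.submit(self.load, item)) for item in chunk]
        return [(item, future.result()) for item, future in futures]


loader = BatchLoader(["a.png", "b.png", "c.png"], batch_size=2, load=len)
batches = list(loader)
print(len(loader), batches)  # 2 [[('a.png', 5), ('b.png', 5)], [('c.png', 5)]]
```

Like the original, this iterator is single-use: `index` is never reset, so a second pass over the same object yields nothing.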

View File

@@ -1,576 +0,0 @@
# Minimum Inference Code for FLUX
import argparse
import datetime
import math
import os
import random
from typing import Callable, List, Optional
import einops
import numpy as np
import torch
from tqdm import tqdm
from PIL import Image
import accelerate
from transformers import CLIPTextModel
from safetensors.torch import load_file
from library import device_utils
from library.device_utils import init_ipex, get_preferred_device
from networks import oft_flux
init_ipex()
from library.utils import setup_logging, str_to_dtype
setup_logging()
import logging
logger = logging.getLogger(__name__)
import networks.lora_flux as lora_flux
from library import flux_models, flux_utils, sd3_utils, strategy_flux
def time_shift(mu: float, sigma: float, t: torch.Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# estimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def denoise(
model: flux_models.Flux,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
vec: torch.Tensor,
timesteps: list[float],
guidance: float = 4.0,
t5_attn_mask: Optional[torch.Tensor] = None,
neg_txt: Optional[torch.Tensor] = None,
neg_vec: Optional[torch.Tensor] = None,
neg_t5_attn_mask: Optional[torch.Tensor] = None,
cfg_scale: Optional[float] = None,
):
# this is ignored for schnell
logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
# prepare classifier free guidance
if neg_txt is not None and neg_vec is not None:
b_img_ids = torch.cat([img_ids, img_ids], dim=0)
b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
b_txt = torch.cat([neg_txt, txt], dim=0)
b_vec = torch.cat([neg_vec, vec], dim=0)
if t5_attn_mask is not None and neg_t5_attn_mask is not None:
b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
else:
b_t5_attn_mask = None
else:
b_img_ids = img_ids
b_txt_ids = txt_ids
b_txt = txt
b_vec = vec
b_t5_attn_mask = t5_attn_mask
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)
# classifier free guidance
if neg_txt is not None and neg_vec is not None:
b_img = torch.cat([img, img], dim=0)
else:
b_img = img
pred = model(
img=b_img,
img_ids=b_img_ids,
txt=b_txt,
txt_ids=b_txt_ids,
y=b_vec,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=b_t5_attn_mask,
)
# classifier free guidance
if neg_txt is not None and neg_vec is not None:
pred_uncond, pred = torch.chunk(pred, 2, dim=0)
pred = pred_uncond + cfg_scale * (pred - pred_uncond)
img = img + (t_prev - t_curr) * pred
return img
def do_sample(
accelerator: Optional[accelerate.Accelerator],
model: flux_models.Flux,
img: torch.Tensor,
img_ids: torch.Tensor,
l_pooled: torch.Tensor,
t5_out: torch.Tensor,
txt_ids: torch.Tensor,
num_steps: int,
guidance: float,
t5_attn_mask: Optional[torch.Tensor],
is_schnell: bool,
device: torch.device,
flux_dtype: torch.dtype,
neg_l_pooled: Optional[torch.Tensor] = None,
neg_t5_out: Optional[torch.Tensor] = None,
neg_t5_attn_mask: Optional[torch.Tensor] = None,
cfg_scale: Optional[float] = None,
):
logger.info(f"num_steps: {num_steps}")
timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
# denoise initial noise
if accelerator:
with accelerator.autocast(), torch.no_grad():
x = denoise(
model,
img,
img_ids,
t5_out,
txt_ids,
l_pooled,
timesteps,
guidance,
t5_attn_mask,
neg_t5_out,
neg_l_pooled,
neg_t5_attn_mask,
cfg_scale,
)
else:
with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
x = denoise(
model,
img,
img_ids,
t5_out,
txt_ids,
l_pooled,
timesteps,
guidance,
t5_attn_mask,
neg_t5_out,
neg_l_pooled,
neg_t5_attn_mask,
cfg_scale,
)
return x
def generate_image(
model,
clip_l: CLIPTextModel,
t5xxl,
ae,
prompt: str,
seed: Optional[int],
image_width: int,
image_height: int,
steps: Optional[int],
guidance: float,
negative_prompt: Optional[str],
cfg_scale: float,
):
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
logger.info(f"Seed: {seed}")
# make first noise with packed shape
# original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
noise_dtype = torch.float32 if is_fp8(dtype) else dtype
noise = torch.randn(
1,
packed_latent_height * packed_latent_width,
16 * 2 * 2,
device=device,
dtype=noise_dtype,
generator=torch.Generator(device=device).manual_seed(seed),
)
# prepare img and img ids
# this is needed only for img2img
# img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
# if img.shape[0] == 1 and bs > 1:
# img = repeat(img, "1 ... -> bs ...", bs=bs)
# txt2img only needs img_ids
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
# prepare fp8 models
if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
clip_l.to(clip_l_dtype) # fp8
clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
clip_l.fp8_prepared = True
if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared):
logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
def prepare_fp8(text_encoder, target_dtype):
def forward_hook(module):
def forward(hidden_states):
hidden_gelu = module.act(module.wi_0(hidden_states))
hidden_linear = module.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = module.dropout(hidden_states)
hidden_states = module.wo(hidden_states)
return hidden_states
return forward
for module in text_encoder.modules():
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
# print("set", module.__class__.__name__, "to", target_dtype)
module.to(target_dtype)
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
# print("set", module.__class__.__name__, "hooks")
module.forward = forward_hook(module)
t5xxl.to(t5xxl_dtype)
prepare_fp8(t5xxl.encoder, torch.bfloat16)
t5xxl.fp8_prepared = True
# prepare embeddings
logger.info("Encoding prompts...")
clip_l = clip_l.to(device)
t5xxl = t5xxl.to(device)
def encode(prpt: str):
tokens_and_masks = tokenize_strategy.tokenize(prpt)
with torch.no_grad():
if is_fp8(clip_l_dtype):
with accelerator.autocast():
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
else:
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
if is_fp8(t5xxl_dtype):
with accelerator.autocast():
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
)
else:
with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
)
return l_pooled, t5_out, txt_ids, t5_attn_mask
l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt)
if negative_prompt:
neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt)
else:
neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None
# NaN check
if torch.isnan(l_pooled).any():
raise ValueError("NaN in l_pooled")
if torch.isnan(t5_out).any():
raise ValueError("NaN in t5_out")
if args.offload:
clip_l = clip_l.cpu()
t5xxl = t5xxl.cpu()
# del clip_l, t5xxl
device_utils.clean_memory()
# generate image
logger.info("Generating image...")
model = model.to(device)
if steps is None:
steps = 4 if is_schnell else 50
img_ids = img_ids.to(device)
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
x = do_sample(
accelerator,
model,
noise,
img_ids,
l_pooled,
t5_out,
txt_ids,
steps,
guidance,
t5_attn_mask,
is_schnell,
device,
flux_dtype,
neg_l_pooled,
neg_t5_out,
neg_t5_attn_mask,
cfg_scale,
)
if args.offload:
model = model.cpu()
# del model
device_utils.clean_memory()
# unpack
x = x.float()
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
# decode
logger.info("Decoding image...")
ae = ae.to(device)
with torch.no_grad():
if is_fp8(ae_dtype):
with accelerator.autocast():
x = ae.decode(x)
else:
with torch.autocast(device_type=device.type, dtype=ae_dtype):
x = ae.decode(x)
if args.offload:
ae = ae.cpu()
x = x.clamp(-1, 1)
x = x.permute(0, 2, 3, 1)
img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
# save image
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
img.save(output_path)
logger.info(f"Saved image to {output_path}")
if __name__ == "__main__":
target_height = 768 # 1024
target_width = 1360 # 1024
# steps = 50 # 28 # 50
# guidance_scale = 5
# seed = 1 # None # 1
device = get_preferred_device()
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--clip_l", type=str, required=False)
parser.add_argument("--t5xxl", type=str, required=False)
parser.add_argument("--ae", type=str, required=False)
parser.add_argument("--apply_t5_attn_mask", action="store_true")
parser.add_argument("--prompt", type=str, default="A photo of a cat")
parser.add_argument("--output_dir", type=str, default=".")
parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l")
parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
parser.add_argument("--guidance", type=float, default=3.5)
parser.add_argument("--negative_prompt", type=str, default=None)
parser.add_argument("--cfg_scale", type=float, default=1.0)
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
parser.add_argument(
"--lora_weights",
type=str,
nargs="*",
default=[],
help="LoRA weights, only supports networks.lora_flux and networks.oft_flux, each argument is a `path;multiplier` (semicolon separated)",
)
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
parser.add_argument("--width", type=int, default=target_width)
parser.add_argument("--height", type=int, default=target_height)
parser.add_argument("--interactive", action="store_true")
args = parser.parse_args()
seed = args.seed
steps = args.steps
guidance_scale = args.guidance
def is_fp8(dt):
return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
dtype = str_to_dtype(args.dtype)
clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype)
t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
ae_dtype = str_to_dtype(args.ae_dtype, dtype)
flux_dtype = str_to_dtype(args.flux_dtype, dtype)
logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")
loading_device = "cpu" if args.offload else device
use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]]
if any(use_fp8):
accelerator = accelerate.Accelerator(mixed_precision="bf16")
else:
accelerator = None
# load clip_l
logger.info(f"Loading clip_l from {args.clip_l}...")
clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
clip_l.eval()
logger.info(f"Loading t5xxl from {args.t5xxl}...")
t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
t5xxl.eval()
# if is_fp8(clip_l_dtype):
# clip_l = accelerator.prepare(clip_l)
# if is_fp8(t5xxl_dtype):
# t5xxl = accelerator.prepare(t5xxl)
# DiT
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
model.eval()
logger.info(f"Casting model to {flux_dtype}")
model.to(flux_dtype) # make sure model is dtype
# if is_fp8(flux_dtype):
# model = accelerator.prepare(model)
# if args.offload:
# model = model.to("cpu")
t5xxl_max_length = 256 if is_schnell else 512
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
# AE
ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
ae.eval()
# if is_fp8(ae_dtype):
# ae = accelerator.prepare(ae)
# LoRA
lora_models: List[lora_flux.LoRANetwork] = []
for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")
multiplier = float(multiplier)
else:
multiplier = 1.0
weights_sd = load_file(weights_file)
is_lora = is_oft = False
for key in weights_sd.keys():
if key.startswith("lora"):
is_lora = True
if key.startswith("oft"):
is_oft = True
if is_lora or is_oft:
break
module = lora_flux if is_lora else oft_flux
lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
if args.merge_lora_weights:
lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
else:
lora_model.apply_to([clip_l, t5xxl], model)
info = lora_model.load_state_dict(weights_sd, strict=True)
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
lora_model.eval()
lora_model.to(device)
lora_models.append(lora_model)
if not args.interactive:
generate_image(
model,
clip_l,
t5xxl,
ae,
args.prompt,
args.seed,
args.width,
args.height,
args.steps,
args.guidance,
args.negative_prompt,
args.cfg_scale,
)
else:
# loop for interactive
width = target_width
height = target_height
steps = None
guidance = args.guidance
cfg_scale = args.cfg_scale
while True:
print(
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
" --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
)
prompt = input()
if prompt == "":
break
# parse options
options = prompt.split("--")
prompt = options[0].strip()
seed = None
negative_prompt = None
for opt in options[1:]:
try:
opt = opt.strip()
if opt.startswith("w"):
width = int(opt[1:].strip())
elif opt.startswith("h"):
height = int(opt[1:].strip())
elif opt.startswith("s"):
steps = int(opt[1:].strip())
elif opt.startswith("d"):
seed = int(opt[1:].strip())
elif opt.startswith("g"):
guidance = float(opt[1:].strip())
elif opt.startswith("m"):
multipliers = opt[1:].strip().split(",")
if len(multipliers) != len(lora_models):
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
continue
for i, lora_model in enumerate(lora_models):
lora_model.set_multiplier(float(multipliers[i]))
elif opt.startswith("n"):
negative_prompt = opt[1:].strip()
if negative_prompt == "-":
negative_prompt = ""
elif opt.startswith("c"):
cfg_scale = float(opt[1:].strip())
except ValueError as e:
logger.error(f"Invalid option: {opt}, {e}")
generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale)
logger.info("Done!")
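The script above reads two small text formats: the `path;multiplier` form of `--lora_weights`, and the `--w`, `--h`, `--s`, ... options typed after the prompt in interactive mode. The following is a minimal standalone sketch of that parsing, not the script's own code; the helper names `parse_lora_spec` and `parse_interactive_line` are hypothetical and exist only for illustration.

```python
from typing import Dict, Tuple


def parse_lora_spec(spec: str) -> Tuple[str, float]:
    """Split a `path;multiplier` argument; the multiplier defaults to 1.0."""
    if ";" in spec:
        path, multiplier = spec.split(";", 1)
        return path, float(multiplier)
    return spec, 1.0


def parse_interactive_line(line: str) -> Tuple[str, Dict[str, str]]:
    """Split `prompt --w 1024 --n -` into the prompt and a {flag: value} dict.

    Values are kept as strings; the caller converts them (int, float, ...).
    """
    parts = line.split("--")
    prompt = parts[0].strip()
    options: Dict[str, str] = {}
    for part in parts[1:]:
        part = part.strip()
        if not part:
            continue
        # The first character is the flag, the remainder is its value.
        options[part[0]] = part[1:].strip()
    return prompt, options


if __name__ == "__main__":
    print(parse_lora_spec("style.safetensors;0.8"))
    print(parse_interactive_line("a cat on a sofa --w 1024 --h 768 --d 42 --n -"))
```

Because the line is split on every `--`, a prompt that itself contains `--` would be cut short; this sketch shares that limitation with the loop above.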

View File

@@ -1,860 +0,0 @@
# training with captions
# Swap blocks between CPU and GPU:
# This implementation is inspired by and based on the work of 2kpr.
# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading.
# The original idea has been adapted and extended to fit the current project's needs.
# Key features:
# - CPU offloading during forward and backward passes
# - Use of fused optimizer and grad_hook for efficient gradient processing
# - Per-block fused optimizer instances
import argparse
from concurrent.futures import ThreadPoolExecutor
import copy
import math
import os
from multiprocessing import Value
import time
from typing import List, Optional, Tuple, Union
import toml
from tqdm import tqdm
import torch
import torch.nn as nn
from library import utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
import library.train_util as train_util
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.config_util as config_util
# import library.sdxl_train_util as sdxl_train_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
# sdxl_train_util.verify_sdxl_training_args(args)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
# temporary: backward compatibility for deprecated options. remove in the future
if not args.skip_cache_check:
args.skip_cache_check = args.skip_latents_validity_check
# assert (
# not args.weighted_captions
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
)
args.cache_text_encoder_outputs = True
if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
logger.warning(
"cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります"
)
args.gradient_checkpointing = True
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, (
"blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
)
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed) # initialize the random number sequence
# prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy for initialization.
if args.cache_latents:
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# prepare the dataset
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
logger.info("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(16) # TODO confirm whether this value is correct
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
if args.debug_dataset:
t5xxl_max_token_length = (
args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
)
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
t5xxl_max_token_length,
args.apply_t5_attn_mask,
False,
)
)
strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))
train_dataset_group.set_current_strategies()
train_util.debug_dataset(train_dataset_group, True)
return
if len(train_dataset_group) == 0:
logger.error(
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# prepare the accelerator
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# prepare dtypes that match the mixed precision setting and cast as needed
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# load the models
# load VAE for caching latents
ae = None
if cache_latents:
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
ae.to(accelerator.device, dtype=weight_dtype)
ae.requires_grad_(False)
ae.eval()
train_dataset_group.new_cache_latents(ae, accelerator, args.force_cache_precision)
ae.to("cpu") # if no sampling, vae can be deleted
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
# prepare tokenize strategy
if args.t5xxl_max_token_length is None:
if is_schnell:
t5xxl_max_token_length = 256
else:
t5xxl_max_token_length = 512
else:
t5xxl_max_token_length = args.t5xxl_max_token_length
flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)
strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy)
# load clip_l, t5xxl for caching text encoder outputs
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
clip_l.eval()
t5xxl.eval()
clip_l.requires_grad_(False)
t5xxl.requires_grad_(False)
text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# cache text encoder outputs
sample_prompts_te_outputs = None
if args.cache_text_encoder_outputs:
# Text encoders are in eval mode and have no grad here
clip_l.to(accelerator.device)
t5xxl.to(accelerator.device)
text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
t5xxl_max_token_length,
args.apply_t5_attn_mask,
False,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator)
# cache sample prompt's embeddings to free text encoder's memory
if args.sample_prompts is not None:
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
if p not in sample_prompts_te_outputs:
logger.info(f"cache Text Encoder outputs for prompt: {p}")
tokens_and_masks = flux_tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
)
accelerator.wait_for_everyone()
# now we can delete Text Encoders to free memory
clip_l = None
t5xxl = None
clean_memory_on_device(accelerator.device)
# load FLUX
_, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
)
if args.gradient_checkpointing:
flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
flux.requires_grad_(True)
# block swap
# backward compatibility
if args.blocks_to_swap is None:
blocks_to_swap = args.double_blocks_to_swap or 0
if args.single_blocks_to_swap is not None:
blocks_to_swap += args.single_blocks_to_swap // 2
if blocks_to_swap > 0:
logger.warning(
"double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead."
" / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。"
)
logger.info(
f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}."
)
args.blocks_to_swap = blocks_to_swap
del blocks_to_swap
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
# This idea is based on 2kpr's great work. Thank you!
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
if not cache_latents:
# load VAE here if not cached
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu")
ae.requires_grad_(False)
ae.eval()
ae.to(accelerator.device, dtype=weight_dtype)
training_models = []
params_to_optimize = []
training_models.append(flux)
name_and_params = list(flux.named_parameters())
# single param group for now
params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate})
param_names = [[n for n, _ in name_and_params]]
# calculate number of trainable parameters
n_params = 0
for group in params_to_optimize:
for p in group["params"]:
n_params += p.numel()
accelerator.print(f"number of trainable parameters: {n_params}")
# prepare the classes needed for training
accelerator.print("prepare optimizer, data loader etc.")
if args.blockwise_fused_optimizers:
# fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
# Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
# This balances memory usage and management complexity.
# split params into groups. currently different learning rates are not supported
grouped_params = []
params_by_block = {}
for group in params_to_optimize:
named_parameters = list(flux.named_parameters())
assert len(named_parameters) == len(group["params"]), "number of parameters does not match"
for p, np in zip(group["params"], named_parameters):
# determine target layer and block index for each parameter
block_type = "other" # double, single or other
if np[0].startswith("double_blocks"):
block_index = int(np[0].split(".")[1])
block_type = "double"
elif np[0].startswith("single_blocks"):
block_index = int(np[0].split(".")[1])
block_type = "single"
else:
block_index = -1
param_group_key = (block_type, block_index)
if param_group_key not in params_by_block:
params_by_block[param_group_key] = []
params_by_block[param_group_key].append(p)
block_types_and_indices = []
# use a separate loop variable so the dict being iterated is not rebound
for param_group_key, block_params in params_by_block.items():
block_types_and_indices.append(param_group_key)
grouped_params.append({"params": block_params, "lr": args.learning_rate})
num_params = 0
for p in block_params:
num_params += p.numel()
accelerator.print(f"block {param_group_key}: {num_params} parameters")
# prepare optimizers for each group
optimizers = []
for group in grouped_params:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
optimizers.append(optimizer)
optimizer = optimizers[0] # avoid error in the following code
logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")
if train_util.is_schedulefree_optimizer(optimizers[0], args):
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
optimizer_train_fn = lambda: None # dummy function
optimizer_eval_fn = lambda: None # dummy function
else:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# note: persistent_workers cannot be used when the number of DataLoader workers is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# calculate the number of training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# also pass the number of training steps to the dataset
train_dataset_group.set_max_train_steps(args.max_train_steps)
# prepare the lr scheduler
if args.blockwise_fused_optimizers:
# prepare lr schedulers for each optimizer
lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
lr_scheduler = lr_schedulers[0] # avoid error in the following code
else:
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# experimental feature: train in fp16/bf16 including gradients; the whole model is cast to fp16/bf16
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
flux.to(weight_dtype)
if clip_l is not None:
clip_l.to(weight_dtype)
t5xxl.to(weight_dtype) # TODO check works with fp16 or not
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
flux.to(weight_dtype)
if clip_l is not None:
clip_l.to(weight_dtype)
t5xxl.to(weight_dtype)
# if we don't cache text encoder outputs, move them to device
if not args.cache_text_encoder_outputs:
clip_l.to(accelerator.device)
t5xxl.to(accelerator.device)
clean_memory_on_device(accelerator.device)
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=flux)
# most ZeRO stages use optimizer partitioning, so we have to prepare the optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
# accelerator does some magic
# if we don't swap blocks, we can move the model to device
flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks])
if is_swapping_blocks:
accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
# experimental feature: train in fp16 including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
# During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine does.
# -> But we think it's ok to patch accelerator even if deepspeed is enabled.
train_util.patch_accelerator_for_fp16_training(accelerator)
# resume from a saved state if specified
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
for parameter, param_name in zip(param_group["params"], param_name_group):
if parameter.requires_grad:
def create_grad_hook(p_name, p_group):
def grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, p_group)
tensor.grad = None
return grad_hook
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
elif args.blockwise_fused_optimizers:
# prepare for additional optimizers and lr schedulers
for i in range(1, len(optimizers)):
optimizers[i] = accelerator.prepare(optimizers[i])
lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
# counters are used to determine when to step the optimizer
global optimizer_hooked_count
global num_parameters_per_group
global parameter_optimizer_map
optimizer_hooked_count = {}
num_parameters_per_group = [0] * len(optimizers)
parameter_optimizer_map = {}
for opt_idx, optimizer in enumerate(optimizers):
for param_group in optimizer.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
def grad_hook(parameter: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
i = parameter_optimizer_map[parameter]
optimizer_hooked_count[i] += 1
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
optimizers[i].step()
optimizers[i].zero_grad(set_to_none=True)
parameter.register_post_accumulate_grad_hook(grad_hook)
parameter_optimizer_map[parameter] = opt_idx
num_parameters_per_group[opt_idx] += 1
# calculate the number of epochs
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# train
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
accelerator.print("running training / 学習開始")
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# accelerator.print(
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
# )
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
if is_swapping_blocks:
accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
# For --sample_at_first
optimizer_eval_fn()
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
optimizer_train_fn()
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
loss_recorder = train_util.LossRecorder()
epoch = 0 # avoid error when max_train_steps is 0
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
for m in training_models:
m.train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
if args.blockwise_fused_optimizers:
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
else:
with torch.no_grad():
# encode images to latents. images are [-1, 1]
latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype)
# if the latents contain NaN, print a warning and replace them with zeros
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list
else:
# not cached or training, so get from text encoders
tokens_and_masks = batch["input_ids_list"]
with torch.no_grad():
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
text_encoder_conds = text_encoding_strategy.encode_tokens(
flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
)
if args.full_fp16:
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
# TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
)
# pack latents and get img_ids
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
# get guidance: ensure args.guidance_scale is float
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
# call model
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None
with accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = flux(
img=packed_noisy_model_input,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# unpack latents
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
# apply model prediction type
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
# flow matching loss: this is different from SD3
target = noise - latents
# calculate loss
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # weight for each sample
loss = loss * loss_weights
loss = loss.mean()
# backward
accelerator.backward(loss)
if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
if args.blockwise_fused_optimizers:
for i in range(1, len(optimizers)):
lr_schedulers[i].step()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
optimizer_eval_fn()
flux_train_utils.sample_images(
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
)
# save the model every specified number of steps
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
args,
False,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(flux),
)
optimizer_train_fn()
current_loss = loss.detach().item() # this is a mean, so batch size should not matter
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
optimizer_eval_fn()
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
args,
True,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(flux),
)
flux_train_utils.sample_images(
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
)
optimizer_train_fn()
is_main_process = accelerator.is_main_process
# if is_main_process:
flux = accelerator.unwrap_model(flux)
accelerator.end_training()
optimizer_eval_fn()
if args.save_state or args.save_state_on_train_end:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # delete this because memory is needed afterwards
if is_main_process:
flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux)
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    add_logging_arguments(parser)
    train_util.add_sd_models_arguments(parser)  # TODO split this
    train_util.add_dataset_arguments(parser, True, True, True)
    train_util.add_training_arguments(parser, False)
    train_util.add_masked_loss_arguments(parser)
    deepspeed_utils.add_deepspeed_arguments(parser)
    train_util.add_sd_saving_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    add_custom_train_arguments(parser)  # TODO remove this from here
    train_util.add_dit_training_arguments(parser)
    flux_train_utils.add_flux_train_arguments(parser)
    parser.add_argument(
        "--mem_eff_save",
        action="store_true",
        help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う",
    )
    parser.add_argument(
        "--fused_optimizer_groups",
        type=int,
        default=None,
        help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます",
    )
    parser.add_argument(
        "--blockwise_fused_optimizers",
        action="store_true",
        help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
    )
    parser.add_argument(
        "--skip_latents_validity_check",
        action="store_true",
        help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
    )
    parser.add_argument(
        "--double_blocks_to_swap",
        type=int,
        default=None,
        help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください",
    )
    parser.add_argument(
        "--single_blocks_to_swap",
        type=int,
        default=None,
        help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください",
    )
    parser.add_argument(
        "--cpu_offload_checkpointing",
        action="store_true",
        help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
    )
    return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)
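
Note on the deprecated options registered by `setup_parser()` above: `--skip_latents_validity_check` is kept only for backward compatibility, and the ControlNet training script later in this diff folds it into `skip_cache_check` when the newer flag is not set. The following is a minimal sketch of that fallback, not the project's implementation. It assumes `--skip_cache_check` is a plain `store_true` flag (in the real scripts it is presumably registered by one of the shared helper functions), and `build_parser` and `apply_deprecated_fallback` are hypothetical names used only for illustration.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical stand-in parser holding only the two flags of interest.
    parser = argparse.ArgumentParser()
    # Assumed definition of the newer flag (not shown in this diff).
    parser.add_argument("--skip_cache_check", action="store_true")
    parser.add_argument(
        "--skip_latents_validity_check",
        action="store_true",
        help="[Deprecated] use 'skip_cache_check' instead",
    )
    return parser


def apply_deprecated_fallback(args: argparse.Namespace) -> argparse.Namespace:
    # Same rule as in train(): the deprecated flag only matters
    # when the newer flag was not given.
    if not args.skip_cache_check:
        args.skip_cache_check = args.skip_latents_validity_check
    return args


old_style = apply_deprecated_fallback(build_parser().parse_args(["--skip_latents_validity_check"]))
new_style = apply_deprecated_fallback(build_parser().parse_args(["--skip_cache_check"]))
neither = apply_deprecated_fallback(build_parser().parse_args([]))
print(old_style.skip_cache_check, new_style.skip_cache_check, neither.skip_cache_check)
```

Either spelling enables the check skip, and passing neither leaves it disabled, so existing command lines keep working while the old flag is phased out.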

View File

@@ -1,878 +0,0 @@
# training with captions
# Swap blocks between CPU and GPU:
# This implementation is inspired by and based on the work of 2kpr.
# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading.
# The original idea has been adapted and extended to fit the current project's needs.
# Key features:
# - CPU offloading during forward and backward passes
# - Use of fused optimizer and grad_hook for efficient gradient processing
# - Per-block fused optimizer instances
import argparse
import copy
import math
import os
import time
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Value
from typing import List, Optional, Tuple, Union
import toml
import torch
import torch.nn as nn
from tqdm import tqdm
from library import utils
from library.device_utils import clean_memory_on_device, init_ipex
init_ipex()
from accelerate.utils import set_seed
import library.train_util as train_util
from library import (
deepspeed_utils,
flux_train_utils,
flux_utils,
strategy_base,
strategy_flux,
)
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
from library.utils import add_logging_arguments, setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.config_util as config_util
# import library.sdxl_train_util as sdxl_train_util
from library.config_util import (
BlueprintGenerator,
ConfigSanitizer,
)
from library.custom_train_functions import add_custom_train_arguments, apply_masked_loss
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
# sdxl_train_util.verify_sdxl_training_args(args)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
# temporary: backward compatibility for deprecated options. remove in the future
if not args.skip_cache_check:
args.skip_cache_check = args.skip_latents_validity_check
# assert (
# not args.weighted_captions
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
)
args.cache_text_encoder_outputs = True
if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
logger.warning(
"cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります"
)
args.gradient_checkpointing = True
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, (
"blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
)
cache_latents = args.cache_latents
if args.seed is not None:
set_seed(args.seed) # initialize the random number sequence
# prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy for initialization.
if args.cache_latents:
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# prepare the dataset
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
user_config = {
"datasets": [
{
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
args.train_data_dir, args.conditioning_data_dir, args.caption_extension
)
}
]
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(16) # TODO check whether this value is right
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
if args.debug_dataset:
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
)
)
t5xxl_max_token_length = (
args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
)
strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))
train_dataset_group.set_current_strategies()
train_util.debug_dataset(train_dataset_group, True)
return
if len(train_dataset_group) == 0:
logger.error(
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# prepare the accelerator
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# prepare dtypes for mixed precision and cast as needed
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# load the models
# load VAE for caching latents
ae = None
if cache_latents:
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
ae.to(accelerator.device, dtype=weight_dtype)
ae.requires_grad_(False)
ae.eval()
train_dataset_group.new_cache_latents(ae, accelerator)
ae.to("cpu") # if no sampling, vae can be deleted
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
# prepare tokenize strategy
if args.t5xxl_max_token_length is None:
if is_schnell:
t5xxl_max_token_length = 256
else:
t5xxl_max_token_length = 512
else:
t5xxl_max_token_length = args.t5xxl_max_token_length
flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)
strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy)
# load clip_l, t5xxl for caching text encoder outputs
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
clip_l.eval()
t5xxl.eval()
clip_l.requires_grad_(False)
t5xxl.requires_grad_(False)
text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# cache text encoder outputs
sample_prompts_te_outputs = None
if args.cache_text_encoder_outputs:
# Text Encoders are in eval mode with no grad here
clip_l.to(accelerator.device)
t5xxl.to(accelerator.device)
text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator)
# cache sample prompt's embeddings to free text encoder's memory
if args.sample_prompts is not None:
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
if p not in sample_prompts_te_outputs:
logger.info(f"cache Text Encoder outputs for prompt: {p}")
tokens_and_masks = flux_tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
)
accelerator.wait_for_everyone()
# now we can delete Text Encoders to free memory
clip_l = None
t5xxl = None
clean_memory_on_device(accelerator.device)
# load FLUX
is_schnell, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
)
flux.requires_grad_(False)
# load controlnet
controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype
controlnet = flux_utils.load_controlnet(
args.controlnet_model_name_or_path, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors
)
controlnet.train()
if args.gradient_checkpointing:
if not args.deepspeed:
flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
# block swap
# backward compatibility
if args.blocks_to_swap is None:
blocks_to_swap = args.double_blocks_to_swap or 0
if args.single_blocks_to_swap is not None:
blocks_to_swap += args.single_blocks_to_swap // 2
if blocks_to_swap > 0:
logger.warning(
"double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead."
" / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。"
)
logger.info(
f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}."
)
args.blocks_to_swap = blocks_to_swap
del blocks_to_swap
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
# This idea is based on 2kpr's great work. Thank you!
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
# ControlNet only has two blocks, so we can keep it on GPU
# controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device)
else:
flux.to(accelerator.device)
if not cache_latents:
# load VAE here if not cached
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu")
ae.requires_grad_(False)
ae.eval()
ae.to(accelerator.device, dtype=weight_dtype)
training_models = []
params_to_optimize = []
training_models.append(controlnet)
name_and_params = list(controlnet.named_parameters())
# single param group for now
params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate})
param_names = [[n for n, _ in name_and_params]]
# calculate number of trainable parameters
n_params = 0
for group in params_to_optimize:
for p in group["params"]:
n_params += p.numel()
accelerator.print(f"number of trainable parameters: {n_params}")
# prepare the classes needed for training
accelerator.print("prepare optimizer, data loader etc.")
if args.blockwise_fused_optimizers:
# fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
# Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
# This balances memory usage and management complexity.
# split params into groups. currently different learning rates are not supported
grouped_params = []
param_group = {}
for group in params_to_optimize:
named_parameters = list(controlnet.named_parameters())
assert len(named_parameters) == len(group["params"]), "number of parameters does not match"
for p, np in zip(group["params"], named_parameters):
# determine target layer and block index for each parameter
block_type = "other" # double, single or other
if np[0].startswith("double_blocks"):
block_index = int(np[0].split(".")[1])
block_type = "double"
elif np[0].startswith("single_blocks"):
block_index = int(np[0].split(".")[1])
block_type = "single"
else:
block_index = -1
param_group_key = (block_type, block_index)
if param_group_key not in param_group:
param_group[param_group_key] = []
param_group[param_group_key].append(p)
block_types_and_indices = []
for param_group_key, block_params in param_group.items():
block_types_and_indices.append(param_group_key)
grouped_params.append({"params": block_params, "lr": args.learning_rate})
num_params = 0
for p in block_params:
num_params += p.numel()
accelerator.print(f"block {param_group_key}: {num_params} parameters")
# prepare optimizers for each group
optimizers = []
for group in grouped_params:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
optimizers.append(optimizer)
optimizer = optimizers[0] # avoid error in the following code
logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")
if train_util.is_schedulefree_optimizer(optimizers[0], args):
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
optimizer_train_fn = lambda: None # dummy function
optimizer_eval_fn = lambda: None # dummy function
else:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# note: persistent_workers cannot be used when the number of DataLoader worker processes is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# calculate the number of training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# also send the number of training steps to the dataset
train_dataset_group.set_max_train_steps(args.max_train_steps)
# prepare the lr scheduler
if args.blockwise_fused_optimizers:
# prepare lr schedulers for each optimizer
lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
lr_scheduler = lr_schedulers[0] # avoid error in the following code
else:
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# experimental feature: train in fp16/bf16 including gradients; cast the whole model to fp16/bf16
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
flux.to(weight_dtype)
controlnet.to(weight_dtype)
if clip_l is not None:
clip_l.to(weight_dtype)
t5xxl.to(weight_dtype) # TODO check works with fp16 or not
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
flux.to(weight_dtype)
controlnet.to(weight_dtype)
if clip_l is not None:
clip_l.to(weight_dtype)
t5xxl.to(weight_dtype)
# if we don't cache text encoder outputs, move them to device
if not args.cache_text_encoder_outputs:
clip_l.to(accelerator.device)
t5xxl.to(accelerator.device)
clean_memory_on_device(accelerator.device)
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet)
# most ZeRO stages use optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
# accelerator does some magic
# if we don't swap blocks, we can move the model to device
controlnet = accelerator.prepare(controlnet) # , device_placement=[not is_swapping_blocks])
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
# experimental feature: train in fp16 including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
# During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via a scaler; the DeepSpeed engine does.
# -> But we think it's ok to patch accelerator even if deepspeed is enabled.
train_util.patch_accelerator_for_fp16_training(accelerator)
# resume from a saved state if specified
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
for parameter, param_name in zip(param_group["params"], param_name_group):
if parameter.requires_grad:
def create_grad_hook(p_name, p_group):
def grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, p_group)
tensor.grad = None
return grad_hook
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
elif args.blockwise_fused_optimizers:
# prepare for additional optimizers and lr schedulers
for i in range(1, len(optimizers)):
optimizers[i] = accelerator.prepare(optimizers[i])
lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
# counters are used to determine when to step the optimizer
global optimizer_hooked_count
global num_parameters_per_group
global parameter_optimizer_map
optimizer_hooked_count = {}
num_parameters_per_group = [0] * len(optimizers)
parameter_optimizer_map = {}
for opt_idx, optimizer in enumerate(optimizers):
for param_group in optimizer.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
def grad_hook(parameter: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
i = parameter_optimizer_map[parameter]
optimizer_hooked_count[i] += 1
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
optimizers[i].step()
optimizers[i].zero_grad(set_to_none=True)
parameter.register_post_accumulate_grad_hook(grad_hook)
parameter_optimizer_map[parameter] = opt_idx
num_parameters_per_group[opt_idx] += 1
# calculate the number of epochs
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# run training
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
accelerator.print("running training / 学習開始")
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# accelerator.print(
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
# )
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
if is_swapping_blocks:
flux.prepare_block_swap_before_forward()
# For --sample_at_first
optimizer_eval_fn()
flux_train_utils.sample_images(
accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
)
optimizer_train_fn()
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
loss_recorder = train_util.LossRecorder()
epoch = 0 # avoid error when max_train_steps is 0
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
for m in training_models:
m.train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
if args.blockwise_fused_optimizers:
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
else:
with torch.no_grad():
# encode images to latents. images are [-1, 1]
latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype)
# if NaNs are present, show a warning and replace them with zeros
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list
else:
# not cached or training, so get from text encoders
tokens_and_masks = batch["input_ids_list"]
with torch.no_grad():
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
text_encoder_conds = text_encoding_strategy.encode_tokens(
flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
)
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
# TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
)
# pack latents and get img_ids
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
img_ids = (
flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width)
.to(device=accelerator.device)
.to(weight_dtype)
)
# get guidance: ensure args.guidance_scale is float
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device, dtype=weight_dtype)
# call model
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None
with accelerator.autocast():
block_samples, block_single_samples = controlnet(
img=packed_noisy_model_input,
img_ids=img_ids,
controlnet_cond=batch["conditioning_images"].to(accelerator.device).to(weight_dtype),
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = flux(
img=packed_noisy_model_input,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
block_controlnet_hidden_states=block_samples,
block_controlnet_single_hidden_states=block_single_samples,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# unpack latents
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
# apply model prediction type
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
# flow matching loss: this is different from SD3
target = noise - latents
# calculate loss
loss = train_util.conditional_loss(
model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None
)
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # per-sample weight
loss = loss * loss_weights
loss = loss.mean()
# backward
accelerator.backward(loss)
if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
if args.blockwise_fused_optimizers:
for i in range(1, len(optimizers)):
lr_schedulers[i].step()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
optimizer_eval_fn()
flux_train_utils.sample_images(
accelerator,
args,
None,
global_step,
flux,
ae,
[clip_l, t5xxl],
sample_prompts_te_outputs,
controlnet=controlnet,
)
# save the model every specified number of steps
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
args,
False,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(controlnet),
)
optimizer_train_fn()
current_loss = loss.detach().item() # this is a mean, so batch size should not matter
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
optimizer_eval_fn()
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
args,
True,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(controlnet),
)
flux_train_utils.sample_images(
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
)
optimizer_train_fn()
is_main_process = accelerator.is_main_process
# if is_main_process:
controlnet = accelerator.unwrap_model(controlnet)
accelerator.end_training()
optimizer_eval_fn()
if args.save_state or args.save_state_on_train_end:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # delete this because memory is needed afterwards
if is_main_process:
flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, controlnet)
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser) # TODO split this
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
add_custom_train_arguments(parser) # TODO remove this from here
train_util.add_dit_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)
parser.add_argument(
"--mem_eff_save",
action="store_true",
help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う",
)
parser.add_argument(
"--fused_optimizer_groups",
type=int,
default=None,
help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます",
)
parser.add_argument(
"--blockwise_fused_optimizers",
action="store_true",
help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
)
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",
help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
)
parser.add_argument(
"--double_blocks_to_swap",
type=int,
default=None,
help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください",
)
parser.add_argument(
"--single_blocks_to_swap",
type=int,
default=None,
help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください",
)
parser.add_argument(
"--cpu_offload_checkpointing",
action="store_true",
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -1,585 +0,0 @@
import argparse
import copy
import math
import random
from typing import Any, Optional, Union
import torch
from accelerate import Accelerator
from library.device_utils import clean_memory_on_device, init_ipex
init_ipex()
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
import train_network
class FluxNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)
if args.fp8_base_unet:
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
)
args.cache_text_encoder_outputs = True
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# prepare CLIP-L/T5XXL training flags
self.train_clip_l = not args.network_train_unet_only
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
if args.max_token_length is not None:
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
# deprecated split_mode option
if args.split_mode:
if args.blocks_to_swap is not None:
logger.warning(
"split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored."
" / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。"
)
else:
logger.warning(
"split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set."
" / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。"
)
args.blocks_to_swap = 18 # 18 is safe for most cases
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
loading_dtype = None if args.fp8_base else weight_dtype
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
self.is_schnell, model = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
)
if args.fp8_base:
# check dtype of model
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
elif model.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 FLUX model")
else:
logger.info(
"Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
" / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
)
model.to(torch.float8_e4m3fn)
# if args.split_mode:
# model = self.prepare_split_model(model, weight_dtype, accelerator)
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if self.is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
clip_l.eval()
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
if args.fp8_base and not args.fp8_base_unet:
loading_dtype = None # as is
else:
loading_dtype = weight_dtype
# loading t5xxl to cpu takes a long time, so we should load to gpu in future
t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
t5xxl.eval()
if args.fp8_base and not args.fp8_base_unet:
# check dtype of model
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
elif t5xxl.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 T5XXL model")
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
def get_tokenize_strategy(self, args):
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
if args.t5xxl_max_token_length is None:
if is_schnell:
t5xxl_max_token_length = 256
else:
t5xxl_max_token_length = 512
else:
t5xxl_max_token_length = args.t5xxl_max_token_length
logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
return latents_caching_strategy
def get_text_encoding_strategy(self, args):
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
def post_process_network(self, args, accelerator, network, text_encoders, unet):
# check t5xxl is trained or not
self.train_t5xxl = network.train_t5xxl
if self.train_t5xxl and args.cache_text_encoder_outputs:
raise ValueError(
"T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
)
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
if args.cache_text_encoder_outputs:
if self.train_clip_l and not self.train_t5xxl:
return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
else:
return None # no text encoders are needed for encoding because both are cached
else:
return text_encoders # both CLIP-L and T5XXL are needed for encoding
def get_text_encoders_train_flags(self, args, text_encoders):
return [self.train_clip_l, self.train_t5xxl]
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
fluxTokenizeStrategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
t5xxl_max_token_length = fluxTokenizeStrategy.t5xxl_max_length
# if any text encoder is trained, we need tokenization, so is_partial is True
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
t5xxl_max_token_length,
args.apply_t5_attn_mask,
is_partial=self.train_clip_l or self.train_t5xxl,
)
else:
return None
def cache_text_encoder_outputs_if_needed(
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
):
if args.cache_text_encoder_outputs:
if not args.lowram:
# reduce memory consumption
logger.info("move vae and unet to cpu to save memory")
org_vae_device = vae.device
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
clean_memory_on_device(accelerator.device)
# When the TE is not trained, it is not prepared by the accelerator, so we need to use explicit autocast
logger.info("move text encoders to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
text_encoders[1].to(accelerator.device)
if text_encoders[1].dtype == torch.float8_e4m3fn:
# if we load fp8 weights, the model is already fp8, so we use it as is
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
else:
# otherwise, we need to convert it to target dtype
text_encoders[1].to(weight_dtype)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
# cache sample prompts
if args.sample_prompts is not None:
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
if p not in sample_prompts_te_outputs:
logger.info(f"cache Text Encoder outputs for prompt: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
accelerator.wait_for_everyone()
# move back to cpu
if not self.is_train_text_encoder(args):
logger.info("move CLIP-L back to cpu")
text_encoders[0].to("cpu")
logger.info("move t5XXL back to cpu")
text_encoders[1].to("cpu")
clean_memory_on_device(accelerator.device)
if not args.lowram:
logger.info("move vae and unet back to original device")
vae.to(org_vae_device)
unet.to(org_unet_device)
else:
# Text Encoder outputs are computed at every step, so keep the encoders on the GPU
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device)
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
# # get size embeddings
# orig_size = batch["original_sizes_hw"]
# crop_size = batch["crop_top_lefts"]
# target_size = batch["target_sizes_hw"]
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
# # concat embeddings
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
# return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
text_encoders = text_encoder # for compatibility
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
)
# return
"""
class FluxUpperLowerWrapper(torch.nn.Module):
def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
super().__init__()
self.flux_upper = flux_upper
self.flux_lower = flux_lower
self.target_device = device
def prepare_block_swap_before_forward(self):
pass
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
self.flux_lower.to("cpu")
clean_memory_on_device(self.target_device)
self.flux_upper.to(self.target_device)
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
self.flux_upper.to("cpu")
clean_memory_on_device(self.target_device)
self.flux_lower.to(self.target_device)
return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
clean_memory_on_device(accelerator.device)
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
)
clean_memory_on_device(accelerator.device)
"""
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
return noise_scheduler
def encode_images_to_latents(self, args, accelerator, vae, images):
return vae.encode(images)
def shift_scale_latents(self, args, latents):
return latents
def get_noise_pred_and_target(
self,
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet: flux_models.Flux,
network,
weight_dtype,
train_unet,
is_train=True
):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
)
# pack latents and get img_ids
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
# get guidance
# ensure guidance_scale in args is float
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
# ensure the hidden state will require grad
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)
for t in text_encoder_conds:
if t is not None and t.dtype.is_floating_point:
t.requires_grad_(True)
img_ids.requires_grad_(True)
guidance_vec.requires_grad_(True)
# Predict the noise residual
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
# if not args.split_mode:
# normal forward
with torch.set_grad_enabled(is_train), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet(
img=img,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
"""
else:
# split forward to reduce memory usage
assert network.train_blocks == "single", "train_blocks must be single for split mode"
with accelerator.autocast():
# move flux lower to cpu, and then move flux upper to gpu
unet.to("cpu")
clean_memory_on_device(accelerator.device)
self.flux_upper.to(accelerator.device)
# upper model does not require grad
with torch.no_grad():
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
img=packed_noisy_model_input,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# move flux upper back to cpu, and then move flux lower to gpu
self.flux_upper.to("cpu")
clean_memory_on_device(accelerator.device)
unet.to(accelerator.device)
# lower model requires grad
intermediate_img.requires_grad_(True)
intermediate_txt.requires_grad_(True)
vec.requires_grad_(True)
pe.requires_grad_(True)
with torch.set_grad_enabled(is_train and train_unet):
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
"""
return model_pred
model_pred = call_dit(
img=packed_noisy_model_input,
img_ids=img_ids,
t5_out=t5_out,
txt_ids=txt_ids,
l_pooled=l_pooled,
timesteps=timesteps,
guidance_vec=guidance_vec,
t5_attn_mask=t5_attn_mask,
)
# unpack latents
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
# apply model prediction type
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
# flow matching loss: this is different from SD3
target = noise - latents
# differential output preservation
if "custom_attributes" in batch:
diff_output_pr_indices = []
for i, custom_attributes in enumerate(batch["custom_attributes"]):
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
diff_output_pr_indices.append(i)
if len(diff_output_pr_indices) > 0:
network.set_multiplier(0.0)
unet.prepare_block_swap_before_forward()
with torch.no_grad():
model_pred_prior = call_dit(
img=packed_noisy_model_input[diff_output_pr_indices],
img_ids=img_ids[diff_output_pr_indices],
t5_out=t5_out[diff_output_pr_indices],
txt_ids=txt_ids[diff_output_pr_indices],
l_pooled=l_pooled[diff_output_pr_indices],
timesteps=timesteps[diff_output_pr_indices],
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
)
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
args,
model_pred_prior,
noisy_model_input[diff_output_pr_indices],
sigmas[diff_output_pr_indices] if sigmas is not None else None,
)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
def get_sai_model_spec(self, args):
return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
def update_metadata(self, metadata, args):
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
metadata["ss_mode_scale"] = args.mode_scale
metadata["ss_guidance_scale"] = args.guidance_scale
metadata["ss_timestep_sampling"] = args.timestep_sampling
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
metadata["ss_model_prediction_type"] = args.model_prediction_type
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
if index == 0: # CLIP-L
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
else: # T5XXL
text_encoder.encoder.embed_tokens.requires_grad_(True)
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
if index == 0: # CLIP-L
logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
text_encoder.to(te_weight_dtype) # fp8
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
else: # T5XXL
def prepare_fp8(text_encoder, target_dtype):
def forward_hook(module):
def forward(hidden_states):
hidden_gelu = module.act(module.wi_0(hidden_states))
hidden_linear = module.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = module.dropout(hidden_states)
hidden_states = module.wo(hidden_states)
return hidden_states
return forward
for module in text_encoder.modules():
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
# print("set", module.__class__.__name__, "to", target_dtype)
module.to(target_dtype)
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
# print("set", module.__class__.__name__, "hooks")
module.forward = forward_hook(module)
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
logger.info("T5XXL already prepared for fp8")
else:
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype)
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
if not self.is_swapping_blocks:
return super().prepare_unet_with_accelerator(args, accelerator, unet)
# if we don't swap blocks, we can move the model to device
flux: flux_models.Flux = unet
flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
return flux
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)
parser.add_argument(
"--split_mode",
action="store_true",
# help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
# + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
" / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = FluxNetworkTrainer()
trainer.train(args)
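
The text-encoder selection in `get_models_for_text_encoding` above can be read as a small decision table. The following is a minimal standalone sketch of that branching, not the trainer's actual code path: `select_text_encoders` is a hypothetical helper name, and plain strings stand in for the CLIP-L and T5XXL models.

```python
from typing import List, Optional, Sequence


def select_text_encoders(
    text_encoders: Sequence[str],
    cache_text_encoder_outputs: bool,
    train_clip_l: bool,
    train_t5xxl: bool,
) -> Optional[List[str]]:
    """Return the encoders that must run during a training step.

    Hypothetical helper: mirrors the branching of get_models_for_text_encoding,
    with strings standing in for the real models.
    """
    if not cache_text_encoder_outputs:
        # Nothing is cached, so both CLIP-L and T5XXL are needed.
        return list(text_encoders)
    if train_clip_l and not train_t5xxl:
        # T5XXL outputs come from the cache; only CLIP-L has to run.
        return list(text_encoders[0:1])
    # Both outputs come from the cache, so no encoder is needed.
    return None


encoders = ["clip_l", "t5xxl"]
print(select_text_encoders(encoders, False, True, False))  # both encoders
print(select_text_encoders(encoders, True, True, False))   # CLIP-L only
print(select_text_encoders(encoders, True, False, False))  # None
```

In the trainer itself, caching combined with T5XXL training does not reach this branch, because `post_process_network` raises a `ValueError` for that combination.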

View File

@@ -43,8 +43,8 @@ from diffusers import (
)
from einops import rearrange
from tqdm import tqdm
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor
from accelerate import init_empty_weights
import PIL
from PIL import Image
from PIL.PngImagePlugin import PngInfo
@@ -58,7 +58,6 @@ import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo
from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from library.sdxl_original_control_net import SdxlControlNet
from library.original_unet import FlashAttentionFunction
from networks.control_net_lllite import ControlNetLLLite
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
@@ -353,8 +352,8 @@ class PipelineLike:
self.token_replacements_list.append({})
# ControlNet
self.control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = []
self.control_net_lllites: List[Tuple[ControlNetLLLite, float]] = []
self.control_nets: List[ControlNetInfo] = [] # only for SD 1.5
self.control_net_lllites: List[ControlNetLLLite] = []
self.control_net_enabled = True # if control_nets is empty, ControlNet does not run whether this is True or False
self.gradual_latent: GradualLatent = None
@@ -543,7 +542,7 @@ class PipelineLike:
else:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
if self.control_net_lllites or (self.control_nets and self.is_sdxl):
if self.control_net_lllites:
# reuse the guide image as the ControlNet hint. For ControlNet itself, this is handled on the ControlNet side
if isinstance(clip_guide_images, PIL.Image.Image):
clip_guide_images = [clip_guide_images]
@@ -732,12 +731,7 @@ class PipelineLike:
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
if self.control_nets:
if not self.is_sdxl:
guided_hints = original_control_net.get_guided_hints(
self.control_nets, num_latent_input, batch_size, clip_guide_images
)
else:
clip_guide_images = clip_guide_images * 0.5 + 0.5 # [-1, 1] => [0, 1]
guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets)
if self.control_net_lllites:
@@ -799,7 +793,7 @@ class PipelineLike:
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# disable ControlNet-LLLite or SDXL ControlNet if ratio is set. ControlNet is disabled in ControlNetInfo
# disable ControlNet-LLLite if ratio is set. ControlNet is disabled in ControlNetInfo
if self.control_net_lllites:
for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_net_lllites, each_control_net_enabled)):
if not enabled or ratio >= 1.0:
@@ -808,16 +802,9 @@ class PipelineLike:
logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
control_net.set_cond_image(None)
each_control_net_enabled[j] = False
if self.control_nets and self.is_sdxl:
for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)):
if not enabled or ratio >= 1.0:
continue
if ratio < i / len(timesteps):
logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
each_control_net_enabled[j] = False
# predict the noise residual
if self.control_nets and self.control_net_enabled and not self.is_sdxl:
if self.control_nets and self.control_net_enabled:
if regional_network:
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt
@@ -836,31 +823,6 @@ class PipelineLike:
text_embeddings,
text_emb_last,
).sample
elif self.control_nets:
input_resi_add_list = []
mid_add_list = []
for (control_net, _), enbld in zip(self.control_nets, each_control_net_enabled):
if not enbld:
continue
input_resi_add, mid_add = control_net(
latent_model_input, t, text_embeddings, vector_embeddings, clip_guide_images
)
input_resi_add_list.append(input_resi_add)
mid_add_list.append(mid_add)
if len(input_resi_add_list) == 0:
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
else:
if len(input_resi_add_list) > 1:
# get mean of input_resi_add_list and mid_add_list
input_resi_add_mean = []
for i in range(len(input_resi_add_list[0])):
input_resi_add_mean.append(
torch.mean(torch.stack([input_resi_add_list[j][i] for j in range(len(input_resi_add_list))], dim=0))
)
input_resi_add = input_resi_add_mean
mid_add = torch.mean(torch.stack(mid_add_list), dim=0)
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings, input_resi_add, mid_add)
elif self.is_sdxl:
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
else:
@@ -1863,37 +1825,16 @@ def main(args):
upscaler.to(dtype).to(device)
# ControlNet processing
control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = []
control_nets: List[ControlNetInfo] = []
if args.control_net_models:
if not is_sdxl:
for i, model in enumerate(args.control_net_models):
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
for i, model in enumerate(args.control_net_models):
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
prep = original_control_net.load_preprocess(prep_type)
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
else:
for i, model_file in enumerate(args.control_net_models):
multiplier = (
1.0
if not args.control_net_multipliers or len(args.control_net_multipliers) <= i
else args.control_net_multipliers[i]
)
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
logger.info(f"loading SDXL ControlNet: {model_file}")
from safetensors.torch import load_file
state_dict = load_file(model_file)
logger.info(f"Initializing SDXL ControlNet with multiplier: {multiplier}")
with init_empty_weights():
control_net = SdxlControlNet(multiplier=multiplier)
control_net.load_state_dict(state_dict)
control_net.to(dtype).to(device)
control_nets.append((control_net, ratio))
ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
prep = original_control_net.load_preprocess(prep_type)
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
control_net_lllites: List[Tuple[ControlNetLLLite, float]] = []
if args.control_net_lllite_models:
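
The ratio-based disabling applied to ControlNet-LLLite in the denoising loop above can be sketched in isolation. This is a minimal illustration under assumptions: `update_enabled` is a hypothetical helper, and plain floats stand in for the `(control_net, ratio)` pairs. The real loop also clears the conditioning image when a net is switched off.

```python
from typing import List


def update_enabled(ratios: List[float], enabled: List[bool], step: int, total_steps: int) -> List[bool]:
    """Disable every net whose ratio is below the fraction of steps already done.

    Hypothetical helper: a ratio >= 1.0 keeps the net active for the whole run,
    and a net that has been disabled is never re-enabled.
    """
    for j, (ratio, is_on) in enumerate(zip(ratios, enabled)):
        if not is_on or ratio >= 1.0:
            continue
        if ratio < step / total_steps:
            enabled[j] = False
    return enabled


ratios = [1.0, 0.5, 0.0]
enabled = [True] * len(ratios)
total_steps = 10
history = []
for step in range(total_steps):
    update_enabled(ratios, enabled, step, total_steps)
    history.append(list(enabled))

print(history[0])   # state after the first step
print(history[-1])  # state after the last step
```

Because the comparison is strict, a net with ratio 0.5 over 10 steps stays active through step index 5 and is first disabled at step index 6.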

View File

@@ -2,32 +2,6 @@ import math
import torch
from transformers import Adafactor
# stochastic rounding for bfloat16
# The implementation was provided by 2kpr. Thank you very much!
def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
"""
copies source into target using stochastic rounding
Args:
target: the target tensor with dtype=bfloat16
source: the source tensor with dtype=float32
"""
# create a random 16 bit integer
result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
# add the random number to the lower 16 bit of the mantissa
result.add_(source.view(dtype=torch.int32))
# mask off the lower 16 bit of the mantissa
result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
# copy the higher 16 bit into the target tensor
target.copy_(result.view(dtype=torch.float32))
del result
@torch.no_grad()
def adafactor_step_param(self, p, group):
if p.grad is None:
@@ -74,7 +48,7 @@ def adafactor_step_param(self, p, group):
lr = Adafactor._get_lr(group, state)
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
update = (grad**2) + group["eps"][0]
update = (grad ** 2) + group["eps"][0]
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
@@ -104,12 +78,7 @@ def adafactor_step_param(self, p, group):
p_data_fp32.add_(-update)
# if p.dtype in {torch.float16, torch.bfloat16}:
# p.copy_(p_data_fp32)
if p.dtype == torch.bfloat16:
copy_stochastic_(p, p_data_fp32)
elif p.dtype == torch.float16:
if p.dtype in {torch.float16, torch.bfloat16}:
p.copy_(p_data_fp32)
@@ -132,7 +101,6 @@ def adafactor_step(self, closure=None):
return loss
def patch_adafactor_fused(optimizer: Adafactor):
optimizer.step_param = adafactor_step_param.__get__(optimizer)
optimizer.step = adafactor_step.__get__(optimizer)
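The `copy_stochastic_` helper in this diff adds a random 16-bit integer to the float32 bit pattern and then masks off the low 16 bits, so a value is rounded up with probability proportional to the bits that would otherwise be discarded. The following is a minimal pure-Python sketch of the same bit trick on a single scalar, using `struct` instead of tensors. The function names are illustrative and do not come from the repository, and special values (inf, NaN, values near the float32 maximum) are not handled.

```python
import random
import struct


def f32_to_bits(x: float) -> int:
    """Return the IEEE-754 float32 bit pattern of x as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(bits: int) -> float:
    """Interpret an unsigned 32-bit pattern as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def truncate_to_bf16(x: float) -> float:
    """Drop the low 16 bits (always rounds toward zero)."""
    return bits_to_f32(f32_to_bits(x) & 0xFFFF0000)


def stochastic_round_to_bf16(x: float, rng: random.Random) -> float:
    """Add a random 16-bit integer to the low bits, then mask them off.

    The carry into the upper 16 bits happens with probability
    (low 16 bits) / 65536, so the result equals x on average.
    """
    bits = f32_to_bits(x) + rng.randrange(1 << 16)
    return bits_to_f32(bits & 0xFFFF0000)


if __name__ == "__main__":
    rng = random.Random(0)
    x = 1.0 + 2.0 ** -10  # lies between the bfloat16 neighbours 1.0 and 1.0078125
    samples = [stochastic_round_to_bf16(x, rng) for _ in range(20000)]
    print("truncated:", truncate_to_bf16(x))
    print("mean of stochastic samples:", sum(samples) / len(samples))
```

Plain truncation always returns the lower neighbour, so small updates to a bfloat16 parameter can vanish entirely. The stochastic variant preserves them in expectation, which is presumably why the fused Adafactor step uses it for bfloat16 parameters.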

View File

@@ -10,7 +10,13 @@ import json
from pathlib import Path
# from toolz import curry
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import (
List,
Optional,
Sequence,
Tuple,
Union,
)
import toml
import voluptuous
@@ -72,9 +78,6 @@ class BaseSubsetParams:
caption_tag_dropout_rate: float = 0.0
token_warmup_min: int = 1
token_warmup_step: float = 0
custom_attributes: Optional[Dict[str, Any]] = None
validation_seed: int = 0
validation_split: float = 0.0
@dataclass
@@ -101,11 +104,11 @@ class ControlNetSubsetParams(BaseSubsetParams):
@dataclass
class BaseDatasetParams:
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
max_token_length: int = None
resolution: Optional[Tuple[int, int]] = None
network_multiplier: float = 1.0
debug_dataset: bool = False
validation_seed: Optional[int] = None
validation_split: float = 0.0
@dataclass
@@ -117,7 +120,8 @@ class DreamBoothDatasetParams(BaseDatasetParams):
bucket_reso_steps: int = 64
bucket_no_upscale: bool = False
prior_loss_weight: float = 1.0
@dataclass
class FineTuningDatasetParams(BaseDatasetParams):
batch_size: int = 1
@@ -195,7 +199,6 @@ class ConfigSanitizer:
"token_warmup_step": Any(float, int),
"caption_prefix": str,
"caption_suffix": str,
"custom_attributes": dict,
}
# DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -237,8 +240,6 @@ class ConfigSanitizer:
"enable_bucket": bool,
"max_bucket_reso": int,
"min_bucket_reso": int,
"validation_seed": int,
"validation_split": float,
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
}
@@ -467,136 +468,118 @@ class BlueprintGenerator:
return default_value
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint) -> Tuple[DatasetGroup, Optional[DatasetGroup]]:
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
for dataset_blueprint in dataset_group_blueprint.datasets:
extra_dataset_params = {}
if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
# DreamBooth datasets support splitting training and validation datasets
extra_dataset_params = {"is_training_dataset": True}
else:
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params)
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
datasets.append(dataset)
val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
for dataset_blueprint in dataset_group_blueprint.datasets:
if dataset_blueprint.params.validation_split < 0.0 or dataset_blueprint.params.validation_split > 1.0:
logging.warning(f"Dataset param `validation_split` ({dataset_blueprint.params.validation_split}) is not a valid number between 0.0 and 1.0, skipping validation split...")
continue
# print info
info = ""
for i, dataset in enumerate(datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(
f"""\
[Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
network_multiplier: {dataset.network_multiplier}
"""
)
# if the dataset isn't setting a validation split, there is no current validation dataset
if dataset_blueprint.params.validation_split == 0.0:
continue
extra_dataset_params = {}
if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
# DreamBooth datasets support splitting training and validation datasets
extra_dataset_params = {"is_training_dataset": False}
if dataset.enable_bucket:
info += indent(
dedent(
f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""
),
" ",
)
else:
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset
info += "\n"
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params)
val_datasets.append(dataset)
for j, subset in enumerate(dataset.subsets):
info += indent(
dedent(
f"""\
[Subset {j} of Dataset {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
keep_tokens_separator: {subset.keep_tokens_separator}
caption_separator: {subset.caption_separator}
secondary_separator: {subset.secondary_separator}
enable_wildcard: {subset.enable_wildcard}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
color_aug: {subset.color_aug}
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
alpha_mask: {subset.alpha_mask},
"""
),
" ",
)
def print_info(_datasets, dataset_type: str):
info = ""
for i, dataset in enumerate(_datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(f"""\
[{dataset_type} {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
""")
if is_dreambooth:
info += indent(
dedent(
f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""
),
" ",
)
elif not is_controlnet:
info += indent(
dedent(
f"""\
metadata_file: {subset.metadata_file}
\n"""
),
" ",
)
if dataset.enable_bucket:
info += indent(dedent(f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""), " ")
else:
info += "\n"
for j, subset in enumerate(dataset.subsets):
info += indent(dedent(f"""\
[Subset {j} of {dataset_type} {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
color_aug: {subset.color_aug}
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
alpha_mask: {subset.alpha_mask}
custom_attributes: {subset.custom_attributes}
"""), " ")
if is_dreambooth:
info += indent(dedent(f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""), " ")
elif not is_controlnet:
info += indent(dedent(f"""\
metadata_file: {subset.metadata_file}
\n"""), " ")
logger.info(info)
print_info(datasets, "Dataset")
if len(val_datasets) > 0:
print_info(val_datasets, "Validation Dataset")
logger.info(f"{info}")
# make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
for i, dataset in enumerate(datasets):
logger.info(f"[Prepare dataset {i}]")
logger.info(f"[Dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)
for i, dataset in enumerate(val_datasets):
logger.info(f"[Prepare validation dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)
return (
DatasetGroup(datasets),
DatasetGroup(val_datasets) if val_datasets else None
)
return DatasetGroup(datasets)
def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):

View File

@@ -1,227 +0,0 @@
from concurrent.futures import ThreadPoolExecutor
import time
from typing import Optional
import torch
import torch.nn as nn
from library.device_utils import clean_memory_on_device
def synchronize_device(device: torch.device):
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "xpu":
torch.xpu.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
# print(module_to_cpu.__class__, module_to_cuda.__class__)
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
# weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
else:
if module_to_cuda.weight.data.device.type != device.type:
# print(
# f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
# )
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
stream.synchronize()
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
stream.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
"""
not tested
"""
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
# device to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
synchronize_device(device)
# cpu to device
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
synchronize_device(device)
def weighs_to_device(layer: nn.Module, device: torch.device):
for module in layer.modules():
if hasattr(module, "weight") and module.weight is not None:
module.weight.data = module.weight.data.to(device, non_blocking=True)
class Offloader:
"""
common offloading class
"""
def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
self.num_blocks = num_blocks
self.blocks_to_swap = blocks_to_swap
self.device = device
self.debug = debug
self.thread_pool = ThreadPoolExecutor(max_workers=1)
self.futures = {}
self.cuda_available = device.type == "cuda"
def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
if self.cuda_available:
swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
else:
swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)
def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
if self.debug:
start_time = time.perf_counter()
print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")
self.swap_weight_devices(block_to_cpu, block_to_cuda)
if self.debug:
print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
return bidx_to_cpu, bidx_to_cuda # , event
block_to_cpu = blocks[block_idx_to_cpu]
block_to_cuda = blocks[block_idx_to_cuda]
self.futures[block_idx_to_cuda] = self.thread_pool.submit(
move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
)
def _wait_blocks_move(self, block_idx):
if block_idx not in self.futures:
return
if self.debug:
print(f"Wait for block {block_idx}")
start_time = time.perf_counter()
future = self.futures.pop(block_idx)
_, bidx_to_cuda = future.result()
assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
if self.debug:
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
class ModelOffloader(Offloader):
"""
supports forward offloading
"""
def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(num_blocks, blocks_to_swap, device, debug)
# register backward hooks
self.remove_handles = []
for i, block in enumerate(blocks):
hook = self.create_backward_hook(blocks, i)
if hook is not None:
handle = block.register_full_backward_hook(hook)
self.remove_handles.append(handle)
def __del__(self):
for handle in self.remove_handles:
handle.remove()
def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
# -1 for 0-based index
num_blocks_propagated = self.num_blocks - block_index - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
waiting = block_index > 0 and block_index <= self.blocks_to_swap
if not swapping and not waiting:
return None
# create hook
block_idx_to_cpu = self.num_blocks - num_blocks_propagated
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_index - 1
def backward_hook(module, grad_input, grad_output):
if self.debug:
print(f"Backward hook for block {block_index}")
if swapping:
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
if waiting:
self._wait_blocks_move(block_idx_to_wait)
return None
return backward_hook
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if self.debug:
print("Prepare block devices before forward")
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
b.to(self.device)
weighs_to_device(b, self.device) # make sure weights are on device
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
b.to(self.device) # move block to device first
weighs_to_device(b, "cpu") # make sure weights are on cpu
synchronize_device(self.device)
clean_memory_on_device(self.device)
def wait_for_block(self, block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
self._wait_blocks_move(block_idx)
def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if block_idx >= self.blocks_to_swap:
return
block_idx_to_cpu = block_idx
block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
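The index arithmetic in `submit_move_blocks` and `create_backward_hook` is easier to follow when tabulated. The sketch below reproduces only that arithmetic in plain Python, without any tensors, threads, or hooks. The function names `forward_swap_plan` and `backward_hook_plan` are hypothetical helpers written for illustration.

```python
from typing import Dict, List, Optional, Tuple


def forward_swap_plan(num_blocks: int, blocks_to_swap: int) -> List[Tuple[int, int]]:
    """(block_idx_to_cpu, block_idx_to_cuda) pairs submitted during the forward pass."""
    plan = []
    for block_idx in range(num_blocks):
        if not blocks_to_swap or block_idx >= blocks_to_swap:
            continue  # same early returns as submit_move_blocks
        plan.append((block_idx, num_blocks - blocks_to_swap + block_idx))
    return plan


def backward_hook_plan(num_blocks: int, blocks_to_swap: int) -> Dict[int, Dict[str, Optional[object]]]:
    """For each block that gets a backward hook: which swap it submits and which block it waits for."""
    plan = {}
    for block_index in range(num_blocks):
        propagated = num_blocks - block_index - 1  # blocks already back-propagated
        swapping = 0 < propagated <= blocks_to_swap
        waiting = 0 < block_index <= blocks_to_swap
        if not swapping and not waiting:
            continue  # create_backward_hook returns None here
        plan[block_index] = {
            "swap": (num_blocks - propagated, blocks_to_swap - propagated) if swapping else None,
            "wait": block_index - 1 if waiting else None,
        }
    return plan


if __name__ == "__main__":
    print(forward_swap_plan(6, 2))
    for idx, action in sorted(backward_hook_plan(6, 2).items(), reverse=True):
        print(idx, action)
```

With 6 blocks and 2 swapped blocks, the forward pass offloads blocks 0 and 1 while loading blocks 4 and 5. The backward pass reverses this: the hooks on blocks 4 and 3 send blocks 5 and 4 back to the CPU while reloading blocks 1 and 0, and the hooks on blocks 2 and 1 wait for those reloads to finish.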

View File

@@ -1,9 +1,7 @@
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
import torch
import argparse
import random
import re
from torch.types import Number
from typing import List, Optional, Union
from .utils import setup_logging
@@ -65,7 +63,7 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
noise_scheduler.alphas_cumprod = alphas_cumprod
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False):
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
if v_prediction:
@@ -76,13 +74,13 @@ def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_sched
return loss
def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
scale = get_snr_scale(timesteps, noise_scheduler)
loss = loss * scale
return loss
def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
def get_snr_scale(timesteps, noise_scheduler):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
scale = snr_t / (snr_t + 1)
@@ -91,14 +89,14 @@ def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
return scale
def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor):
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
scale = get_snr_scale(timesteps, noise_scheduler)
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
loss = loss + loss / scale * v_pred_like_loss
return loss
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False):
def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
if v_prediction:
@@ -455,7 +453,7 @@ def get_weighted_text_embeddings(
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
def pyramid_noise_like(noise, device, iterations=6, discount=0.4) -> torch.FloatTensor:
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
for i in range(iterations):
@@ -468,7 +466,7 @@ def pyramid_noise_like(noise, device, iterations=6, discount=0.4) -> torch.Float
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> torch.FloatTensor:
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
if noise_offset is None:
return noise
if adaptive_noise_scale is not None:
@@ -484,7 +482,7 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> to
return noise
def apply_masked_loss(loss, batch) -> torch.FloatTensor:
def apply_masked_loss(loss, batch):
if "conditioning_images" in batch:
# conditioning image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
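The `get_snr_scale` and `add_v_prediction_like_loss` functions in this diff clamp the SNR to 1000 and then use `snr / (snr + 1)` as a scale. A scalar sketch of that arithmetic, assuming plain floats in place of tensors (the function names below are illustrative only):

```python
import math


def snr_scale(snr: float, max_snr: float = 1000.0) -> float:
    """Clamp the SNR (it is inf at timestep 0) and return snr / (snr + 1)."""
    clamped = min(snr, max_snr)
    return clamped / (clamped + 1.0)


def add_v_prediction_like_loss_scalar(loss: float, snr: float, v_pred_like_loss: float) -> float:
    """Scalar form of: loss + loss / scale * v_pred_like_loss."""
    return loss + loss / snr_scale(snr) * v_pred_like_loss


if __name__ == "__main__":
    for snr in (0.1, 1.0, 10.0, math.inf):
        print(snr, snr_scale(snr))
```

The scale approaches 1 for high-SNR (low-noise) timesteps and 0 for low-SNR ones, so dividing by it in `add_v_prediction_like_loss` boosts the extra term most at noisy timesteps.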

View File

@@ -1,58 +0,0 @@
import os
import json
from typing import Any, Optional
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
METADATA_VERSION = [1, 0, 0]
VERSION_STRING = ".".join(str(v) for v in METADATA_VERSION)
ARCHIVE_PATH_SEPARATOR = "////"
def load_metadata(metadata_file: str, create_new: bool = False) -> Optional[dict[str, Any]]:
if os.path.exists(metadata_file):
logger.info(f"loading metadata file: {metadata_file}")
with open(metadata_file, "rt", encoding="utf-8") as f:
metadata = json.load(f)
# version check
major, minor, patch = metadata.get("format_version", "0.0.0").split(".")
major, minor, patch = int(major), int(minor), int(patch)
if major > METADATA_VERSION[0] or (major == METADATA_VERSION[0] and minor > METADATA_VERSION[1]):
logger.warning(
f"metadata format version {major}.{minor}.{patch} is higher than supported version {VERSION_STRING}. Some features may not work."
)
if "images" not in metadata:
metadata["images"] = {}
else:
if not create_new:
return None
logger.info(f"metadata file not found: {metadata_file}, creating new metadata")
metadata = {"format_version": VERSION_STRING, "images": {}}
return metadata
def is_archive_path(archive_and_image_path: str) -> bool:
return archive_and_image_path.count(ARCHIVE_PATH_SEPARATOR) == 1
def get_inner_path(archive_and_image_path: str) -> str:
return archive_and_image_path.split(ARCHIVE_PATH_SEPARATOR, 1)[1]
def get_archive_digest(archive_and_image_path: str) -> str:
"""
Calculate an 8-digit hex digest of the archive path to avoid collisions between different archives with the same name.
"""
archive_path = archive_and_image_path.split(ARCHIVE_PATH_SEPARATOR, 1)[0]
return f"{hash(archive_path) & 0xFFFFFFFF:08x}"
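One caveat with `get_archive_digest` as shown: Python's built-in `hash()` for strings is salted per process unless `PYTHONHASHSEED` is fixed, so the digest can differ between runs. If a digest that is stable across runs is wanted, one option is a truncated cryptographic hash. The sketch below is an alternative written for illustration, not the repository's implementation, and `stable_archive_digest` is a hypothetical name.

```python
import hashlib

ARCHIVE_PATH_SEPARATOR = "////"  # same separator as in the diff above


def stable_archive_digest(archive_and_image_path: str) -> str:
    """Return an 8-digit hex digest of the archive part of the path.

    Uses SHA-256 instead of the built-in hash(), so the result is
    identical across interpreter runs and machines.
    """
    archive_path = archive_and_image_path.split(ARCHIVE_PATH_SEPARATOR, 1)[0]
    return hashlib.sha256(archive_path.encode("utf-8")).hexdigest()[:8]


if __name__ == "__main__":
    print(stable_archive_digest("data/a.zip////img/001.png"))
    print(stable_archive_digest("data/a.zip////img/002.png"))  # same archive, same digest
```

Truncating to 8 hex digits keeps the same output width as the original function, at the cost of a small collision probability across many archives.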

File diff suppressed because it is too large

View File

@@ -1,619 +0,0 @@
import argparse
import math
import os
import numpy as np
import toml
import json
import time
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from accelerate import Accelerator, PartialState
from transformers import CLIPTextModel
from tqdm import tqdm
from PIL import Image
from safetensors.torch import save_file
from library import flux_models, flux_utils, strategy_base, train_util
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from .utils import setup_logging, mem_eff_save_file
setup_logging()
import logging
logger = logging.getLogger(__name__)
# region sample images
def sample_images(
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
flux,
ae,
text_encoders,
sample_prompts_te_outputs,
prompt_replacement=None,
controlnet=None
):
if steps == 0:
if not args.sample_at_first:
return
else:
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps is ignored
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return
logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
# unwrap unet and text_encoder(s)
flux = accelerator.unwrap_model(flux)
if text_encoders is not None:
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
if controlnet is not None:
controlnet = accelerator.unwrap_model(controlnet)
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
prompts = train_util.load_prompts(args.sample_prompts)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
# save random state to restore later
rng_state = torch.get_rng_state()
cuda_rng_state = None
try:
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
except Exception:
pass
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
sample_image_inference(
accelerator,
args,
flux,
text_encoders,
ae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
per_process_prompts = [] # list of lists
for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
accelerator,
args,
flux,
text_encoders,
ae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
)
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
clean_memory_on_device(accelerator.device)
def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
flux: flux_models.Flux,
text_encoders: Optional[List[CLIPTextModel]],
ae: flux_models.AutoEncoder,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
):
assert isinstance(prompt_dict, dict)
# negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 20)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 3.5)
seed = prompt_dict.get("seed")
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
# if negative_prompt is not None:
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
else:
# True random sample image generation
torch.seed()
torch.cuda.seed()
# if negative_prompt is None:
# negative_prompt = ""
height = max(64, height - height % 16) # round to divisible by 16
width = max(64, width - width % 16) # round to divisible by 16
logger.info(f"prompt: {prompt}")
# logger.info(f"negative_prompt: {negative_prompt}")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"scale: {scale}")
# logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
# encode prompts
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_conds = []
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
text_encoder_conds = sample_prompts_te_outputs[prompt]
print(f"Using cached text encoder outputs for prompt: {prompt}")
if text_encoders is not None:
print(f"Encoding prompt: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt)
# strategy has apply_t5_attn_mask option
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
# sample image
weight_dtype = ae.dtype # TODO give dtype as argument
packed_latent_height = height // 16
packed_latent_width = width // 16
noise = torch.randn(
1,
packed_latent_height * packed_latent_width,
16 * 2 * 2,
device=accelerator.device,
dtype=weight_dtype,
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
)
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
if controlnet_image is not None:
controlnet_image = Image.open(controlnet_image).convert("RGB")
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
with accelerator.autocast(), torch.no_grad():
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
# latent to image
clean_memory_on_device(accelerator.device)
org_vae_device = ae.device # will be on cpu
ae.to(accelerator.device) # distributed_state.device is same as accelerator.device
with accelerator.autocast(), torch.no_grad():
x = ae.decode(x)
ae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
x = x.clamp(-1, 1)
x = x.permute(0, 2, 3, 1)
image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
# but adding 'enum' to the filename should be enough
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i: int = prompt_dict["enum"]
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# send images to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
def time_shift(mu: float, sigma: float, t: torch.Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# estimate mu by linear interpolation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
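# Illustrative note (a sketch, not part of this repository): the shifted schedule built by
# get_schedule can be reproduced on plain floats. The names time_shift_py, lin_function_py and
# schedule_py are hypothetical stand-ins, and t = 0 is special-cased because pure Python raises
# on 1 / 0 where torch produces inf.

```python
import math


def time_shift_py(mu: float, sigma: float, t: float) -> float:
    # Same formula as time_shift above; t == 0 is the limit value (torch reaches it via 1/0 = inf).
    if t <= 0.0:
        return 0.0
    return math.exp(mu) / (math.exp(mu) + (1.0 / t - 1.0) ** sigma)


def lin_function_py(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
    # Straight line through (x1, y1) and (x2, y2): maps image sequence length to mu.
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def schedule_py(num_steps: int, image_seq_len: int, shift: bool = True) -> list:
    # num_steps + 1 evenly spaced values from 1 down to 0 (the extra one is the final zero).
    ts = [1.0 - i / num_steps for i in range(num_steps + 1)]
    if shift:
        mu = lin_function_py()(image_seq_len)
        ts = [time_shift_py(mu, 1.0, t) for t in ts]
    return ts


plain = schedule_py(4, 1024, shift=False)
shifted = schedule_py(4, 1024, shift=True)
```

# With a positive mu, every interior timestep moves towards 1, so more of the sampling steps are
# spent at high noise levels; the endpoints 1 and 0 stay fixed.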
def denoise(
model: flux_models.Flux,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
vec: torch.Tensor,
timesteps: list[float],
guidance: float = 4.0,
t5_attn_mask: Optional[torch.Tensor] = None,
controlnet: Optional[flux_models.ControlNetFlux] = None,
controlnet_img: Optional[torch.Tensor] = None,
):
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
model.prepare_block_swap_before_forward()
if controlnet is not None:
block_samples, block_single_samples = controlnet(
img=img,
img_ids=img_ids,
controlnet_cond=controlnet_img,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
else:
block_samples = None
block_single_samples = None
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
block_controlnet_hidden_states=block_samples,
block_controlnet_single_hidden_states=block_single_samples,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
img = img + (t_prev - t_curr) * pred
model.prepare_block_swap_before_forward()
return img
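# Illustrative note (a sketch, not part of this repository): the update in denoise,
# img = img + (t_prev - t_curr) * pred, is an explicit Euler step along the predicted velocity.
# The toy below uses scalars and a hypothetical exact velocity function instead of the Flux
# model, assuming the interpolation x_t = (1 - t) * x0 + t * noise used elsewhere in this file.

```python
def euler_denoise(x, timesteps, velocity_fn):
    # Walk consecutive timestep pairs from t = 1 (pure noise) down to t = 0 (clean sample).
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        x = x + (t_prev - t_curr) * velocity_fn(x, t_curr)
    return x


x0, noise = 0.25, -1.5
timesteps = [1.0, 0.75, 0.5, 0.25, 0.0]

# For x_t = (1 - t) * x0 + t * noise the velocity dx/dt is constant: noise - x0.
result = euler_denoise(noise, timesteps, lambda x, t: noise - x0)
```

# Because the velocity is constant in this toy case, the Euler steps land on x0 regardless of
# how many steps are taken; a learned model only approximates this.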
# endregion
# region train
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def compute_density_for_timestep_sampling(
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
"""Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
u = torch.rand(size=(batch_size,), device="cpu")
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(batch_size,), device="cpu")
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)
return weighting
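# Illustrative note (a sketch, not part of this repository): the two non-trivial weighting
# schemes above evaluated on a single float. loss_weight is a hypothetical scalar stand-in for
# compute_loss_weighting_for_sd3.

```python
import math


def loss_weight(weighting_scheme: str, sigma: float) -> float:
    if weighting_scheme == "sigma_sqrt":
        return sigma ** -2.0
    if weighting_scheme == "cosmap":
        bot = 1 - 2 * sigma + 2 * sigma ** 2
        return 2 / (math.pi * bot)
    return 1.0  # any other scheme: uniform weighting


w_sigma_sqrt = loss_weight("sigma_sqrt", 0.5)  # 0.5 ** -2 = 4
w_cosmap = loss_weight("cosmap", 0.5)          # bot = 0.5, so 2 / (pi * 0.5) = 4 / pi
```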
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
sigmas = None
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
t = torch.rand((bsz,), device=device)
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * latents + t * noise
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
elif args.timestep_sampling == "flux_shift":
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
timesteps = time_shift(mu, 1.0, timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
# Add noise according to flow matching.
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
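# Illustrative note (a sketch, not part of this repository): the "shift" branch above on plain
# floats. shift_t, sample_t and add_noise are hypothetical helper names; random.Random stands in
# for torch.randn.

```python
import math
import random


def shift_t(t: float, shift: float) -> float:
    # Same remapping as the "shift" branch: shift > 1 pushes t towards 1 (more noise).
    return (t * shift) / (1.0 + (shift - 1.0) * t)


def sample_t(shift: float, sigmoid_scale: float, rng: random.Random) -> float:
    z = rng.gauss(0.0, 1.0) * sigmoid_scale  # a larger scale spreads t more evenly over (0, 1)
    return shift_t(1.0 / (1.0 + math.exp(-z)), shift)


def add_noise(latent: float, noise: float, t: float) -> float:
    # Linear interpolation between the clean latent (t = 0) and pure noise (t = 1).
    return (1.0 - t) * latent + t * noise


rng = random.Random(0)
samples = [sample_t(3.0, 1.0, rng) for _ in range(1000)]
```

# With shift = 3.0 (the default of --discrete_flow_shift), t = 0.5 is remapped to 0.75, while
# shift = 1.0 leaves t unchanged.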
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
weighting = None
if args.model_prediction_type == "raw":
pass
elif args.model_prediction_type == "additive":
# add the model_pred to the noisy_model_input
model_pred = model_pred + noisy_model_input
elif args.model_prediction_type == "sigma_scaled":
# apply sigma scaling
model_pred = model_pred * (-sigmas) + noisy_model_input
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
return model_pred, weighting
def save_models(
ckpt_path: str,
flux: flux_models.Flux,
sai_metadata: Optional[dict],
save_dtype: Optional[torch.dtype] = None,
use_mem_eff_save: bool = False,
):
state_dict = {}
def update_sd(prefix, sd):
for k, v in sd.items():
key = prefix + k
if save_dtype is not None and v.dtype != save_dtype:
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
update_sd("", flux.state_dict())
if not use_mem_eff_save:
save_file(state_dict, ckpt_path, metadata=sai_metadata)
else:
mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
def save_flux_model_on_train_end(
args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
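# Saving per epoch and per step is handled by one function, because the metadata contains epoch/step and the arguments are the same
# on_epoch_end: True when called at the end of an epoch, False when called after a number of steps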
def save_flux_model_on_epoch_end_or_stepwise(
args: argparse.Namespace,
on_epoch_end: bool,
accelerator,
save_dtype: torch.dtype,
epoch: int,
num_train_epochs: int,
global_step: int,
flux: flux_models.Flux,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,
on_epoch_end,
accelerator,
True,
True,
epoch,
num_train_epochs,
global_step,
sd_saver,
None,
)
# endregion
def add_flux_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--clip_l",
type=str,
help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス*.sftまたは*.safetensors、float16が前提",
)
parser.add_argument(
"--t5xxl",
type=str,
help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス*.sftまたは*.safetensors、float16が前提",
)
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス*.sftまたは*.safetensors")
parser.add_argument(
"--controlnet_model_name_or_path",
type=str,
default=None,
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス*.sftまたは*.safetensors"
)
parser.add_argument(
"--t5xxl_max_token_length",
type=int,
default=None,
help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev"
" / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
)
parser.add_argument(
"--apply_t5_attn_mask",
action="store_true",
help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
)
parser.add_argument(
"--guidance_scale",
type=float,
default=3.5,
help="the FLUX.1 dev variant is a guidance distilled model",
)
parser.add_argument(
"--timestep_sampling",
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
default="sigma",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
" / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
)
parser.add_argument(
"--sigmoid_scale",
type=float,
default=1.0,
help='Scale factor for sigmoid timestep sampling (used when timestep_sampling is "sigmoid", "shift" or "flux_shift"). / sigmoidタイムステップサンプリングの倍率（timestep_samplingが"sigmoid"、"shift"、"flux_shift"の場合に有効）。',
)
parser.add_argument(
"--model_prediction_type",
choices=["raw", "additive", "sigma_scaled"],
default="sigma_scaled",
help="How to interpret and process the model prediction: "
"raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
" / モデル予測の解釈と処理方法:"
"rawそのまま使用、additiveイズ入力に加算、sigma_scaledシグマスケーリングを適用",
)
parser.add_argument(
"--discrete_flow_shift",
type=float,
default=3.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
)


@@ -1,488 +0,0 @@
import json
import os
from dataclasses import replace
from typing import List, Optional, Tuple, Union
import einops
import torch
from accelerate import init_empty_weights
from safetensors import safe_open
from safetensors.torch import load_file
from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import flux_models
from library.utils import load_safetensors
MODEL_VERSION_FLUX_V1 = "flux1"
MODEL_NAME_DEV = "dev"
MODEL_NAME_SCHNELL = "schnell"
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
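"""
Analyze the checkpoint state: determine whether it is Diffusers or BFL, dev or schnell, and count the number of blocks.
Args:
ckpt_path (str): Path to the checkpoint file or directory.
Returns:
Tuple[bool, bool, Tuple[int, int], List[str]]:
- bool: Flag indicating whether the checkpoint is in Diffusers format.
- bool: Flag indicating whether the model is schnell.
- Tuple[int, int]: Number of double blocks and single blocks.
- List[str]: List of checkpoint file paths to load (the function returns ckpt_paths, not the keys).
"""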
# check the state dict: Diffusers or BFL, dev or schnell, number of blocks
logger.info("Checking the state dict: Diffusers or BFL, dev or schnell")
if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers
ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
if "00001-of-00003" in ckpt_path:
ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)]
else:
ckpt_paths = [ckpt_path]
keys = []
for ckpt_path in ckpt_paths:
with safe_open(ckpt_path, framework="pt") as f:
keys.extend(f.keys())
# if the key has annoying prefix, remove it
if keys[0].startswith("model.diffusion_model."):
keys = [key.replace("model.diffusion_model.", "") for key in keys]
is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
# check number of double and single blocks
if not is_diffusers:
max_double_block_index = max(
[int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")]
)
max_single_block_index = max(
[int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")]
)
else:
max_double_block_index = max(
[
int(key.split(".")[1])
for key in keys
if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias")
]
)
max_single_block_index = max(
[
int(key.split(".")[1])
for key in keys
if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias")
]
)
num_double_blocks = max_double_block_index + 1
num_single_blocks = max_single_block_index + 1
return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths
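# Illustrative note (a sketch, not part of this repository): how the block counts and the
# schnell flag follow from key names alone, shown for BFL-style keys. count_blocks is a
# hypothetical helper and the key list is made up for the example.

```python
def count_blocks(keys, prefix, suffix):
    # The block index is the second dot-separated field, e.g. "double_blocks.7.img_attn.proj.bias".
    indices = [int(k.split(".")[1]) for k in keys if k.startswith(prefix) and k.endswith(suffix)]
    return max(indices) + 1


bfl_keys = [
    "double_blocks.0.img_attn.proj.bias",
    "double_blocks.1.img_attn.proj.bias",
    "single_blocks.0.modulation.lin.bias",
    "single_blocks.1.modulation.lin.bias",
    "single_blocks.2.modulation.lin.bias",
    "guidance_in.in_layer.bias",
]

num_double = count_blocks(bfl_keys, "double_blocks.", ".img_attn.proj.bias")
num_single = count_blocks(bfl_keys, "single_blocks.", ".modulation.lin.bias")
# schnell has no guidance embedding, so the presence of this key indicates dev.
is_schnell = "guidance_in.in_layer.bias" not in bfl_keys
```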
def load_flow_model(
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
) -> Tuple[bool, flux_models.Flux]:
is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
# build model
logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
with torch.device("meta"):
params = flux_models.configs[name].params
# set the number of blocks
if params.depth != num_double_blocks:
logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
params = replace(params, depth=num_double_blocks)
if params.depth_single_blocks != num_single_blocks:
logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
params = replace(params, depth_single_blocks=num_single_blocks)
model = flux_models.Flux(params)
if dtype is not None:
model = model.to(dtype)
# load_sft doesn't support torch.device
logger.info(f"Loading state dict from {ckpt_path}")
sd = {}
for ckpt_path in ckpt_paths:
sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))
# convert Diffusers to BFL
if is_diffusers:
logger.info("Converting Diffusers to BFL")
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
logger.info("Converted Diffusers to BFL")
# if the key has annoying prefix, remove it
for key in list(sd.keys()):
new_key = key.replace("model.diffusion_model.", "")
if new_key == key:
break # the model doesn't have annoying prefix
sd[new_key] = sd.pop(key)
info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Flux: {info}")
return is_schnell, model
def load_ae(
ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
) -> flux_models.AutoEncoder:
logger.info("Building AutoEncoder")
with torch.device("meta"):
# dev and schnell have the same AE params
ae = flux_models.AutoEncoder(flux_models.configs[MODEL_NAME_DEV].ae_params).to(dtype)
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = ae.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded AE: {info}")
return ae
def load_controlnet(
ckpt_path: Optional[str], is_schnell: bool, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
):
logger.info("Building ControlNet")
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
with torch.device(device):
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype)
if ckpt_path is not None:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = controlnet.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded ControlNet: {info}")
return controlnet
def load_clip_l(
ckpt_path: Optional[str],
dtype: torch.dtype,
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[dict] = None,
) -> CLIPTextModel:
logger.info("Building CLIP-L")
CLIPL_CONFIG = {
"_name_or_path": "clip-vit-large-patch14/",
"architectures": ["CLIPModel"],
"initializer_factor": 1.0,
"logit_scale_init_value": 2.6592,
"model_type": "clip",
"projection_dim": 768,
# "text_config": {
"_name_or_path": "",
"add_cross_attention": False,
"architectures": None,
"attention_dropout": 0.0,
"bad_words_ids": None,
"bos_token_id": 0,
"chunk_size_feed_forward": 0,
"cross_attention_hidden_size": None,
"decoder_start_token_id": None,
"diversity_penalty": 0.0,
"do_sample": False,
"dropout": 0.0,
"early_stopping": False,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": 2,
"finetuning_task": None,
"forced_bos_token_id": None,
"forced_eos_token_id": None,
"hidden_act": "quick_gelu",
"hidden_size": 768,
"id2label": {"0": "LABEL_0", "1": "LABEL_1"},
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 3072,
"is_decoder": False,
"is_encoder_decoder": False,
"label2id": {"LABEL_0": 0, "LABEL_1": 1},
"layer_norm_eps": 1e-05,
"length_penalty": 1.0,
"max_length": 20,
"max_position_embeddings": 77,
"min_length": 0,
"model_type": "clip_text_model",
"no_repeat_ngram_size": 0,
"num_attention_heads": 12,
"num_beam_groups": 1,
"num_beams": 1,
"num_hidden_layers": 12,
"num_return_sequences": 1,
"output_attentions": False,
"output_hidden_states": False,
"output_scores": False,
"pad_token_id": 1,
"prefix": None,
"problem_type": None,
"projection_dim": 768,
"pruned_heads": {},
"remove_invalid_values": False,
"repetition_penalty": 1.0,
"return_dict": True,
"return_dict_in_generate": False,
"sep_token_id": None,
"task_specific_params": None,
"temperature": 1.0,
"tie_encoder_decoder": False,
"tie_word_embeddings": True,
"tokenizer_class": None,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": None,
"torchscript": False,
"transformers_version": "4.16.0.dev0",
"use_bfloat16": False,
"vocab_size": 49408,
"hidden_act": "gelu",
"hidden_size": 1280,
"intermediate_size": 5120,
"num_attention_heads": 20,
"num_hidden_layers": 32,
# },
# "text_config_dict": {
"hidden_size": 768,
"intermediate_size": 3072,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"projection_dim": 768,
# },
# "torch_dtype": "float32",
# "transformers_version": None,
}
config = CLIPConfig(**CLIPL_CONFIG)
with init_empty_weights():
clip = CLIPTextModel._from_config(config)
if state_dict is not None:
sd = state_dict
else:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = clip.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded CLIP-L: {info}")
return clip
def load_t5xxl(
ckpt_path: str,
dtype: Optional[torch.dtype],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[dict] = None,
) -> T5EncoderModel:
T5_CONFIG_JSON = """
{
"architectures": [
"T5EncoderModel"
],
"classifier_dropout": 0.0,
"d_ff": 10240,
"d_kv": 64,
"d_model": 4096,
"decoder_start_token_id": 0,
"dense_act_fn": "gelu_new",
"dropout_rate": 0.1,
"eos_token_id": 1,
"feed_forward_proj": "gated-gelu",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 24,
"num_heads": 64,
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_max_distance": 128,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.41.2",
"use_cache": true,
"vocab_size": 32128
}
"""
config = json.loads(T5_CONFIG_JSON)
config = T5Config(**config)
with init_empty_weights():
t5xxl = T5EncoderModel._from_config(config)
if state_dict is not None:
sd = state_dict
else:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = t5xxl.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded T5xxl: {info}")
return t5xxl
def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype:
# nn.Embedding is the first layer, but it could be cast to bfloat16 or float32
return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype
def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :]
img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
return img_ids
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
"""
x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
"""
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
return x
def pack_latents(x: torch.Tensor) -> torch.Tensor:
"""
x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2
"""
x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
return x
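# Illustrative note (a sketch, not part of this repository): the 2x2 packing done by the einops
# patterns above, written out on nested lists so the index arithmetic is visible. pack_2x2 and
# unpack_2x2 are hypothetical names; the real functions operate on tensors.

```python
def pack_2x2(x):
    """x[b][c][H][W] with even H and W -> tokens[b][(H/2)*(W/2)][c*4]."""
    out = []
    for sample in x:
        c, H, W = len(sample), len(sample[0]), len(sample[0][0])
        tokens = []
        for h in range(H // 2):
            for w in range(W // 2):
                # Channel is the slowest index inside a token, then patch row, then patch column.
                tokens.append(
                    [sample[ch][2 * h + ph][2 * w + pw] for ch in range(c) for ph in range(2) for pw in range(2)]
                )
        out.append(tokens)
    return out


def unpack_2x2(x, packed_h, packed_w):
    """Inverse of pack_2x2: tokens[b][packed_h*packed_w][c*4] -> x[b][c][2*packed_h][2*packed_w]."""
    out = []
    for tokens in x:
        c = len(tokens[0]) // 4
        sample = [[[None] * (2 * packed_w) for _ in range(2 * packed_h)] for _ in range(c)]
        for idx, token in enumerate(tokens):
            h, w = divmod(idx, packed_w)
            for j, value in enumerate(token):
                ch, rem = divmod(j, 4)
                ph, pw = divmod(rem, 2)
                sample[ch][2 * h + ph][2 * w + pw] = value
        out.append(sample)
    return out


# One sample, one channel, a 2x4 latent with values 0..7 in row-major order.
latent = [[[[r * 4 + col for col in range(4)] for r in range(2)]]]
packed = pack_2x2(latent)
restored = unpack_2x2(packed, 1, 2)
```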
# region Diffusers
NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38
BFL_TO_DIFFUSERS_MAP = {
"time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
"time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
"time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
"time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],
"vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
"vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
"vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
"vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],
"guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
"guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
"guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
"guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],
"txt_in.weight": ["context_embedder.weight"],
"txt_in.bias": ["context_embedder.bias"],
"img_in.weight": ["x_embedder.weight"],
"img_in.bias": ["x_embedder.bias"],
"double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
"double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
"double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
"double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],
"double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
"double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
"double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
"double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],
"double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
"double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
"double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
"double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],
"double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
"double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
"double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
"double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],
"double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
"double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
"double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
"double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],
"double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
"double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
"double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
"double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],
"single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
"single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
"single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
"single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
"single_blocks.().linear2.weight": ["proj_out.weight"],
"single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
"single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
"single_blocks.().linear2.weight": ["proj_out.weight"],
"single_blocks.().linear2.bias": ["proj_out.bias"],
"final_layer.linear.weight": ["proj_out.weight"],
"final_layer.linear.bias": ["proj_out.bias"],
"final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
"final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
}
def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]:
# make reverse map from diffusers map
diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key)
for b in range(num_double_blocks):
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if key.startswith("double_blocks."):
block_prefix = f"transformer_blocks.{b}."
for i, weight in enumerate(weights):
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
for b in range(num_single_blocks):
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if key.startswith("single_blocks."):
block_prefix = f"single_transformer_blocks.{b}."
for i, weight in enumerate(weights):
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")):
for i, weight in enumerate(weights):
diffusers_to_bfl_map[weight] = (i, key)
return diffusers_to_bfl_map
def convert_diffusers_sd_to_bfl(
diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS
) -> dict[str, torch.Tensor]:
diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks)
# group tensors by BFL key, remembering each tensor's position within the fused BFL tensor
flux_sd = {}
for diffusers_key, tensor in diffusers_sd.items():
if diffusers_key in diffusers_to_bfl_map:
index, bfl_key = diffusers_to_bfl_map[diffusers_key]
if bfl_key not in flux_sd:
flux_sd[bfl_key] = []
flux_sd[bfl_key].append((index, tensor))
else:
logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}")
raise KeyError(f"Key not found in diffusers_to_bfl_map: {diffusers_key}")
# concat tensors if multiple tensors are mapped to a single key, sort by index
for key, values in flux_sd.items():
if len(values) == 1:
flux_sd[key] = values[0][1]
else:
flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])])
# special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
if "final_layer.adaLN_modulation.1.weight" in flux_sd:
flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"])
if "final_layer.adaLN_modulation.1.bias" in flux_sd:
flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"])
return flux_sd
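# Illustrative note (a sketch, not part of this repository): the reverse-map and concatenation
# logic above on a two-entry subset of BFL_TO_DIFFUSERS_MAP, with Python lists standing in for
# tensors. SMALL_MAP, reverse_map and merge are hypothetical names used only for this example.

```python
SMALL_MAP = {
    "txt_in.weight": ["context_embedder.weight"],
    "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
}


def reverse_map(bfl_map, num_double_blocks):
    rev = {}  # diffusers key -> (position inside the fused tensor, BFL key)
    for b in range(num_double_blocks):
        for key, weights in bfl_map.items():
            if key.startswith("double_blocks."):
                for i, w in enumerate(weights):
                    rev[f"transformer_blocks.{b}.{w}"] = (i, key.replace("()", str(b)))
    for key, weights in bfl_map.items():
        if not key.startswith("double_blocks."):
            for i, w in enumerate(weights):
                rev[w] = (i, key)
    return rev


def merge(diffusers_sd, rev):
    grouped = {}
    for k, v in diffusers_sd.items():
        index, bfl_key = rev[k]
        grouped.setdefault(bfl_key, []).append((index, v))
    # Concatenate the parts in index order (q, k, v), whatever order they were read in.
    return {k: [x for _, part in sorted(parts, key=lambda p: p[0]) for x in part] for k, parts in grouped.items()}


rev = reverse_map(SMALL_MAP, num_double_blocks=2)
sd = {
    "transformer_blocks.0.attn.to_v.weight": [3],
    "transformer_blocks.0.attn.to_q.weight": [1],
    "transformer_blocks.0.attn.to_k.weight": [2],
    "context_embedder.weight": [9],
}
merged = merge(sd, rev)
```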
# endregion


@@ -6,10 +6,8 @@ import os
from typing import List, Optional, Tuple, Union
import safetensors
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
r"""
@@ -57,18 +55,12 @@ ARCH_SD_V1 = "stable-diffusion-v1"
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc.
# ARCH_SD3_UNKNOWN = "stable-diffusion-3"
ARCH_FLUX_1_DEV = "flux-1-dev"
ARCH_FLUX_1_UNKNOWN = "flux-1"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
IMPL_DIFFUSERS = "diffusers"
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@@ -121,12 +113,7 @@ def build_metadata(
merged_from: Optional[str] = None,
timesteps: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
sd3: Optional[str] = None,
flux: Optional[str] = None,
):
"""
sd3: only supports "m", flux: only supports "dev"
"""
# if state_dict is None, hash is not calculated
metadata = {}
@@ -139,13 +126,6 @@ def build_metadata(
if sdxl:
arch = ARCH_SD_XL_V1_BASE
elif sd3 is not None:
arch = ARCH_SD3_M + "-" + sd3
elif flux is not None:
if flux == "dev":
arch = ARCH_FLUX_1_DEV
else:
arch = ARCH_FLUX_1_UNKNOWN
elif v2:
if v_parameterization:
arch = ARCH_SD_V2_768_V
@@ -162,12 +142,9 @@ def build_metadata(
metadata["modelspec.architecture"] = arch
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
if flux is not None:
# Flux
impl = IMPL_FLUX
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI
else:
@@ -225,7 +202,7 @@ def build_metadata(
reso = (reso[0], reso[0])
else:
# resolution is defined in dataset, so use default
if sdxl or sd3 is not None or flux is not None:
if sdxl:
reso = 1024
elif v2 and v_parameterization:
reso = 768
@@ -236,9 +213,7 @@ def build_metadata(
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
if flux is not None:
del metadata["modelspec.prediction_type"]
elif v_parameterization:
if v_parameterization:
metadata["modelspec.prediction_type"] = PRED_TYPE_V
else:
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
@@ -261,7 +236,7 @@ def build_metadata(
# assert all([v is not None for v in metadata.values()]), metadata
if not all([v is not None for v in metadata.values()]):
logger.error(f"Internal error: some metadata values are None: {metadata}")
return metadata
@@ -275,7 +250,7 @@ def get_title(metadata: dict) -> Optional[str]:
def load_metadata_from_safetensors(model: str) -> dict:
if not model.endswith(".safetensors"):
return {}
with safetensors.safe_open(model, framework="pt") as f:
metadata = f.metadata()
if metadata is None:

File diff suppressed because it is too large


@@ -1,945 +0,0 @@
import argparse
import math
import os
import toml
import json
import time
from typing import Dict, List, Optional, Tuple, Union
import torch
from safetensors.torch import save_file
from accelerate import Accelerator, PartialState
from tqdm import tqdm
from PIL import Image
from transformers import CLIPTextModelWithProjection, T5EncoderModel
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
# from transformers import CLIPTokenizer
# from library import model_util
# , sdxl_model_util, train_util, sdxl_original_unet
# from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import sd3_models, sd3_utils, strategy_base, train_util
def save_models(
ckpt_path: str,
mmdit: Optional[sd3_models.MMDiT],
vae: Optional[sd3_models.SDVAE],
clip_l: Optional[CLIPTextModelWithProjection],
clip_g: Optional[CLIPTextModelWithProjection],
t5xxl: Optional[T5EncoderModel],
sai_metadata: Optional[dict],
save_dtype: Optional[torch.dtype] = None,
):
r"""
Save models to checkpoint file. Only supports unified checkpoint format.
"""
state_dict = {}
def update_sd(prefix, sd):
for k, v in sd.items():
key = prefix + k
if save_dtype is not None:
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
update_sd("model.diffusion_model.", mmdit.state_dict())
update_sd("first_stage_model.", vae.state_dict())
# do not support unified checkpoint format for now
# if clip_l is not None:
# update_sd("text_encoders.clip_l.", clip_l.state_dict())
# if clip_g is not None:
# update_sd("text_encoders.clip_g.", clip_g.state_dict())
# if t5xxl is not None:
# update_sd("text_encoders.t5xxl.", t5xxl.state_dict())
save_file(state_dict, ckpt_path, metadata=sai_metadata)
if clip_l is not None:
clip_l_path = ckpt_path.replace(".safetensors", "_clip_l.safetensors")
save_file(clip_l.state_dict(), clip_l_path)
if clip_g is not None:
clip_g_path = ckpt_path.replace(".safetensors", "_clip_g.safetensors")
save_file(clip_g.state_dict(), clip_g_path)
if t5xxl is not None:
t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors")
t5xxl_state_dict = t5xxl.state_dict()
# replace "shared.weight" with copy of it to avoid annoying shared tensor error on safetensors.save_file
shared_weight = t5xxl_state_dict["shared.weight"]
shared_weight_copy = shared_weight.detach().clone()
t5xxl_state_dict["shared.weight"] = shared_weight_copy
save_file(t5xxl_state_dict, t5xxl_path)
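The key layout that `save_models` produces can be illustrated without torch. This is a minimal sketch, assuming plain dicts stand in for state dicts; `build_unified_state_dict` and `sidecar_path` are hypothetical helper names used only for illustration and are not part of the library.

```python
from typing import Any, Dict


def build_unified_state_dict(parts: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Merge several state dicts into one, prepending each part's prefix to its keys."""
    merged: Dict[str, Any] = {}
    for prefix, state_dict in parts.items():
        for key, value in state_dict.items():
            merged[prefix + key] = value
    return merged


def sidecar_path(ckpt_path: str, name: str) -> str:
    """Derive the path of a separately saved text encoder from the checkpoint path."""
    return ckpt_path.replace(".safetensors", f"_{name}.safetensors")


# Integers stand in for tensors here (assumption for illustration only).
merged = build_unified_state_dict(
    {
        "model.diffusion_model.": {"x_embedder.proj.weight": 1},
        "first_stage_model.": {"decoder.conv_in.weight": 2},
    }
)
print(sorted(merged))
print(sidecar_path("out/model.safetensors", "clip_l"))
```

Note that `str.replace` leaves the path unchanged when it does not contain `.safetensors`, in which case a sidecar file would overwrite the main checkpoint path; a caller may want to guard against that.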
def save_sd3_model_on_train_end(
args: argparse.Namespace,
save_dtype: torch.dtype,
epoch: int,
global_step: int,
clip_l: Optional[CLIPTextModelWithProjection],
clip_g: Optional[CLIPTextModelWithProjection],
t5xxl: Optional[T5EncoderModel],
mmdit: sd3_models.MMDiT,
vae: sd3_models.SDVAE,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(
None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type
)
save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype)
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
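# Saving per epoch and per step is unified here, because the metadata contains epoch/step and the arguments are the same
# on_epoch_end: True when called at the end of an epoch, False when called after a number of steps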
def save_sd3_model_on_epoch_end_or_stepwise(
args: argparse.Namespace,
on_epoch_end: bool,
accelerator,
save_dtype: torch.dtype,
epoch: int,
num_train_epochs: int,
global_step: int,
clip_l: Optional[CLIPTextModelWithProjection],
clip_g: Optional[CLIPTextModelWithProjection],
t5xxl: Optional[T5EncoderModel],
mmdit: sd3_models.MMDiT,
vae: sd3_models.SDVAE,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(
None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type
)
save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype)
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,
on_epoch_end,
accelerator,
True,
True,
epoch,
num_train_epochs,
global_step,
sd_saver,
None,
)
def add_sd3_training_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--clip_l",
type=str,
required=False,
help="CLIP-L model path. if not specified, use ckpt's state_dict / CLIP-Lモデルのパス。指定しない場合はckptのstate_dictを使用",
)
parser.add_argument(
"--clip_g",
type=str,
required=False,
help="CLIP-G model path. if not specified, use ckpt's state_dict / CLIP-Gモデルのパス。指定しない場合はckptのstate_dictを使用",
)
parser.add_argument(
"--t5xxl",
type=str,
required=False,
help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用",
)
parser.add_argument(
"--save_clip",
action="store_true",
help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
)
parser.add_argument(
"--save_t5xxl",
action="store_true",
help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
)
parser.add_argument(
"--t5xxl_device",
type=str,
default=None,
help="[DOES NOT WORK] not supported yet. T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用",
)
parser.add_argument(
"--t5xxl_dtype",
type=str,
default=None,
help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtypemixed precisionからを使用",
)
parser.add_argument(
"--t5xxl_max_token_length",
type=int,
default=256,
help="maximum token length for T5-XXL. 256 is the default value / T5-XXLの最大トークン長。デフォルトは256",
)
parser.add_argument(
"--apply_lg_attn_mask",
action="store_true",
help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスクゼロ埋めを適用する",
)
parser.add_argument(
"--apply_t5_attn_mask",
action="store_true",
help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスクゼロ埋めを適用する",
)
parser.add_argument(
"--clip_l_dropout_rate",
type=float,
default=0.0,
help="Dropout rate for CLIP-L encoder, default is 0.0 / CLIP-Lエンコーダのドロップアウト率、デフォルトは0.0",
)
parser.add_argument(
"--clip_g_dropout_rate",
type=float,
default=0.0,
help="Dropout rate for CLIP-G encoder, default is 0.0 / CLIP-Gエンコーダのドロップアウト率、デフォルトは0.0",
)
parser.add_argument(
"--t5_dropout_rate",
type=float,
default=0.0,
help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0",
)
parser.add_argument(
"--pos_emb_random_crop_rate",
type=float,
default=0.0,
help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M"
" / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります",
)
parser.add_argument(
"--enable_scaled_pos_embed",
action="store_true",
help="Scale position embeddings for each resolution during multi-resolution training. Only for SD3.5M"
" / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります",
)
# The dependency on the Diffusers noise sampler has been removed for clarity in training
parser.add_argument(
"--training_shift",
type=float,
default=1.0,
help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。",
)
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
if args.v_parameterization:
logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
if args.clip_skip is not None:
logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
# if args.multires_noise_iterations:
# logger.info(
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
# )
# else:
# if args.noise_offset is None:
# args.noise_offset = DEFAULT_NOISE_OFFSET
# elif args.noise_offset != DEFAULT_NOISE_OFFSET:
# logger.info(
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
# )
# logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
assert (
not hasattr(args, "weighted_captions") or not args.weighted_captions
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
if supportTextEncoderCaching:
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
args.cache_text_encoder_outputs = True
logger.warning(
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
)
# temporarily copied from sd3_minimal_inference.py
def get_all_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
start = sampling.timestep(sampling.sigma_max)
end = sampling.timestep(sampling.sigma_min)
timesteps = torch.linspace(start, end, steps)
sigs = []
for x in range(len(timesteps)):
ts = timesteps[x]
sigs.append(sampling.sigma(ts))
sigs += [0.0]
return torch.FloatTensor(sigs)
def max_denoise(model_sampling, sigmas):
max_sigma = float(model_sampling.sigma_max)
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
def do_sample(
height: int,
width: int,
seed: int,
cond: Tuple[torch.Tensor, torch.Tensor],
neg_cond: Tuple[torch.Tensor, torch.Tensor],
mmdit: sd3_models.MMDiT,
steps: int,
guidance_scale: float,
dtype: torch.dtype,
device: str,
):
latent = torch.zeros(1, 16, height // 8, width // 8, device=device)
latent = latent.to(dtype).to(device)
# noise = get_noise(seed, latent).to(device)
if seed is not None:
generator = torch.manual_seed(seed)
else:
generator = None
noise = (
torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu")
.to(latent.dtype)
.to(device)
)
model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3
sigmas = get_all_sigmas(model_sampling, steps).to(device)
noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas))
c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype)
y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype)
x = noise_scaled.to(device).to(dtype)
# print(x.shape)
# with torch.no_grad():
for i in tqdm(range(len(sigmas) - 1)):
sigma_hat = sigmas[i]
timestep = model_sampling.timestep(sigma_hat).float()
timestep = torch.FloatTensor([timestep, timestep]).to(device)
x_c_nc = torch.cat([x, x], dim=0)
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
mmdit.prepare_block_swap_before_forward()
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
model_output = model_output.float()
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
pos_out, neg_out = batched.chunk(2)
denoised = neg_out + (pos_out - neg_out) * guidance_scale
# print(denoised.shape)
# d = to_d(x, sigma_hat, denoised)
dims_to_append = x.ndim - sigma_hat.ndim
sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append]
# print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape)
# Convert the denoiser output to a Karras ODE derivative.
d = (x - denoised) / sigma_hat_dims
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
x = x.to(dtype)
mmdit.prepare_block_swap_before_forward()
return x
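The two arithmetic steps inside the sampling loop of `do_sample`, classifier-free guidance and the Euler update, can be shown on scalars. This is a sketch only: `cfg_combine` and `euler_step` are hypothetical names, and floats stand in for the latent tensors.

```python
def cfg_combine(pos: float, neg: float, guidance_scale: float) -> float:
    """Extrapolate from the negative prediction toward the positive one."""
    return neg + (pos - neg) * guidance_scale


def euler_step(x: float, denoised: float, sigma: float, sigma_next: float) -> float:
    """One Euler step of the ODE whose derivative is (x - denoised) / sigma."""
    d = (x - denoised) / sigma
    return x + d * (sigma_next - sigma)


# With guidance_scale == 1 the result is the positive prediction itself.
print(cfg_combine(2.0, 1.0, 1.0))
# Stepping all the way to sigma_next == 0 lands on the denoised estimate.
print(euler_step(x=3.0, denoised=1.0, sigma=0.5, sigma_next=0.0))
```

A guidance scale above 1 pushes the estimate past the positive prediction, away from the negative one, which is why the default of 7.5 in `sample_image_inference` amplifies the prompt's influence.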
def sample_images(
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
mmdit,
vae,
text_encoders,
sample_prompts_te_outputs,
prompt_replacement=None,
):
if steps == 0:
if not args.sample_at_first:
return
else:
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps is ignored
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return
logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
# unwrap unet and text_encoder(s)
mmdit = accelerator.unwrap_model(mmdit)
text_encoders = None if text_encoders is None else [accelerator.unwrap_model(te) for te in text_encoders]
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
prompts = train_util.load_prompts(args.sample_prompts)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
# save random state to restore later
rng_state = torch.get_rng_state()
cuda_rng_state = None
try:
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
except Exception:
pass
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
sample_image_inference(
accelerator,
args,
mmdit,
text_encoders,
vae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
per_process_prompts = [] # list of lists
for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
accelerator,
args,
mmdit,
text_encoders,
vae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
)
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
clean_memory_on_device(accelerator.device)
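The multi-process branch of `sample_images` assigns prompts to processes with the slice `prompts[i :: num_processes]`. A minimal sketch of that round-robin split, using a hypothetical helper name `shard_round_robin` and integers in place of prompt dicts:

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_round_robin(items: Sequence[T], num_processes: int) -> List[List[T]]:
    """Give process i every num_processes-th item, starting at index i."""
    return [list(items[i::num_processes]) for i in range(num_processes)]


shards = shard_round_robin(list(range(5)), 2)
print(shards)
```

With more processes than prompts, the trailing shards are simply empty, so those processes have nothing to generate.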
def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
mmdit: sd3_models.MMDiT,
text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
vae: sd3_models.SDVAE,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
):
assert isinstance(prompt_dict, dict)
negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 30)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 7.5)
seed = prompt_dict.get("seed")
# controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
else:
# True random sample image generation
torch.seed()
torch.cuda.seed()
if negative_prompt is None:
negative_prompt = ""
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
logger.info(f"prompt: {prompt}")
logger.info(f"negative_prompt: {negative_prompt}")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"scale: {scale}")
# logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
# encode prompts
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
def encode_prompt(prpt):
text_encoder_conds = []
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
text_encoder_conds = sample_prompts_te_outputs[prpt]
print(f"Using cached text encoder outputs for prompt: {prpt}")
if text_encoders is not None:
print(f"Encoding prompt: {prpt}")
tokens_and_masks = tokenize_strategy.tokenize(prpt)
# strategy has apply_t5_attn_mask option
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
return text_encoder_conds
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(prompt)
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
# encode negative prompts
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(negative_prompt)
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
# sample image
clean_memory_on_device(accelerator.device)
with accelerator.autocast(), torch.no_grad():
# mmdit may be fp8, so we need weight_dtype here. vae is always in that dtype.
latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, vae.dtype, accelerator.device)
# latent to image
clean_memory_on_device(accelerator.device)
org_vae_device = vae.device # will be on cpu
vae.to(accelerator.device)
latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype))
image = vae.decode(latents)
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
image = image.float()
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
decoded_np = decoded_np.astype(np.uint8)
image = Image.fromarray(decoded_np)
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
# but adding 'enum' to the filename should be enough
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i: int = prompt_dict["enum"]
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# send images to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
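Two small pieces of `sample_image_inference` are easy to check in isolation: rounding the image size down to a multiple of 8 with a floor of 64, and building the output filename. The sketch below mirrors those expressions; `round_size` and `sample_filename` are hypothetical names introduced for illustration.

```python
from typing import Optional


def round_size(value: int, multiple: int = 8, minimum: int = 64) -> int:
    """Round down to a multiple of `multiple`, but never below `minimum`."""
    return max(minimum, value - value % multiple)


def sample_filename(
    output_name: Optional[str],
    epoch: Optional[int],
    steps: int,
    index: int,
    ts_str: str,
    seed: Optional[int],
) -> str:
    """Build the sample image filename from the same parts the trainer uses."""
    prefix = "" if output_name is None else output_name + "_"
    num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
    seed_suffix = "" if seed is None else f"_{seed}"
    return f"{prefix}{num_suffix}_{index:02d}_{ts_str}{seed_suffix}.png"


print(round_size(513), round_size(50))
print(sample_filename("lora", 3, 1200, 1, "20250101000000", 42))
```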
# region Diffusers
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import BaseOutput
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Euler scheduler.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
shift (`float`, defaults to 1.0):
The shift value for the timestep schedule.
"""
_compatibles = []
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
shift: float = 1.0,
):
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
sigmas = timesteps / num_train_timesteps
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
self.timesteps = sigmas * num_train_timesteps
self._step_index = None
self._begin_index = None
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
@property
def step_index(self):
"""
The index counter for the current timestep. It increases by 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_noise(
self,
sample: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
Forward process in flow-matching
Args:
sample (`torch.FloatTensor`):
The input sample.
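timestep (`float` or `torch.FloatTensor`):
The current timestep in the diffusion chain.
noise (`torch.FloatTensor`):
The noise to mix into the sample. The signature defaults it to `None`, but it is required in practice.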
Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
sample = sigma * noise + (1.0 - sigma) * sample
return sample
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps)
sigmas = timesteps / self.config.num_train_timesteps
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
timesteps = sigmas * self.config.num_train_timesteps
self.timesteps = timesteps.to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self._step_index = None
self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
s_churn (`float`):
s_tmin (`float`):
s_tmax (`float`):
s_noise (`float`, defaults to 1.0):
Scaling factor for noise added to the sample.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`FlowMatchEulerDiscreteSchedulerOutput`] or
tuple.
Returns:
[`FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`FlowMatchEulerDiscreteSchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if self.step_index is None:
self._init_step_index(timestep)
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
sigma = self.sigmas[self.step_index]
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator)
eps = noise * s_noise
sigma_hat = sigma * (gamma + 1)
if gamma > 0:
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
# NOTE: "original_sample" should not be an expected prediction_type but is left in for
# backwards compatibility
# if self.config.prediction_type == "vector_field":
denoised = sample - model_output * sigma
# 2. Convert to an ODE derivative
derivative = (sample - denoised) / sigma_hat
dt = self.sigmas[self.step_index + 1] - sigma_hat
prev_sample = sample + derivative * dt
# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
def __len__(self):
return self.config.num_train_timesteps
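The scheduler above combines a shifted sigma schedule with an Euler update. With `gamma == 0`, `step` reduces to `prev_sample = sample + model_output * dt`, because `(sample - denoised) / sigma_hat` equals `model_output`. The pure-Python sketch below illustrates this on scalars. It is a simplification under stated assumptions: `shifted_sigmas` and `euler_sample` are hypothetical names, the schedule applies the shift once to linearly spaced values, and the "model" is an ideal velocity `noise - x0`.

```python
from typing import Callable, List


def shifted_sigmas(num_steps: int, shift: float = 1.0) -> List[float]:
    """Return num_steps shifted sigmas from 1.0 downward, followed by a final 0.0."""
    raw = [(num_steps - i) / num_steps for i in range(num_steps)]
    shifted = [shift * s / (1 + (shift - 1) * s) for s in raw]
    return shifted + [0.0]


def euler_sample(noise: float, velocity_fn: Callable[[float, float], float], sigmas: List[float]) -> float:
    """Integrate from pure noise (sigma = 1) to sigma = 0 with Euler steps."""
    x = noise
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        x = x + velocity_fn(x, sigma) * (sigma_next - sigma)
    return x


x0, noise = 2.0, -1.0
sigmas = shifted_sigmas(num_steps=8, shift=3.0)
# An ideal flow-matching model predicts the constant velocity noise - x0.
result = euler_sample(noise, lambda x, sigma: noise - x0, sigmas)
print(sigmas[0], sigmas[-1], result)
```

Because the ideal velocity is constant, the step sizes telescope and the result matches `x0` up to floating-point error regardless of the shift. The shift only changes where along the trajectory the steps are spent.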
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def compute_density_for_timestep_sampling(
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
"""Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
u = torch.rand(size=(batch_size,), device="cpu")
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(batch_size,), device="cpu")
return u
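The three sampling branches of `compute_density_for_timestep_sampling` can be reproduced for a single value with the standard library. This sketch assumes `random.Random` in place of torch's generators; `sample_u` and `mode_transform` are hypothetical names, and the default `mode_scale=1.29` is an illustrative value rather than something taken from this file.

```python
import math
import random


def mode_transform(u: float, mode_scale: float) -> float:
    """Map a uniform sample with the 'mode' scheme formula."""
    return 1 - u - mode_scale * (math.cos(math.pi * u / 2) ** 2 - 1 + u)


def sample_u(
    weighting_scheme: str,
    rng: random.Random,
    logit_mean: float = 0.0,
    logit_std: float = 1.0,
    mode_scale: float = 1.29,
) -> float:
    """Draw one value in [0, 1] according to the chosen weighting scheme."""
    if weighting_scheme == "logit_normal":
        z = rng.gauss(logit_mean, logit_std)
        return 1.0 / (1.0 + math.exp(-z))  # sigmoid of a normal sample
    if weighting_scheme == "mode":
        return mode_transform(rng.random(), mode_scale)
    return rng.random()  # any other scheme falls back to uniform sampling


rng = random.Random(0)
print([round(sample_u("logit_normal", rng), 3) for _ in range(3)])
```

The logit-normal branch concentrates samples around `sigmoid(logit_mean)`, so mid-range noise levels are visited more often than the extremes.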
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)
return weighting
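For a single sigma, the weighting schemes in `compute_loss_weighting_for_sd3` reduce to the scalar formulas below. The helper name `loss_weight` is hypothetical, and a float stands in for the sigma tensor.

```python
import math


def loss_weight(weighting_scheme: str, sigma: float) -> float:
    """Scalar version of the SD3 loss weighting for one sigma."""
    if weighting_scheme == "sigma_sqrt":
        return sigma ** -2.0
    if weighting_scheme == "cosmap":
        bot = 1 - 2 * sigma + 2 * sigma**2
        return 2 / (math.pi * bot)
    return 1.0  # any other scheme weights all timesteps equally


for scheme in ("sigma_sqrt", "cosmap", "uniform"):
    print(scheme, loss_weight(scheme, 0.5))
```

Note that `sigma_sqrt` grows without bound as sigma approaches 0, so very low noise levels dominate the loss under that scheme.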
# endregion
def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz = latents.shape[0]
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
t_min = args.min_timestep if args.min_timestep is not None else 0
t_max = args.max_timestep if args.max_timestep is not None else 1000
shift = args.training_shift
# weighting shift: a value > 1 shifts the distribution toward the noisy side (more focus on overall structure), a value < 1 shifts it toward the less noisy side (more focus on details)
u = (u * shift) / (1 + (shift - 1) * u)
indices = (u * (t_max - t_min) + t_min).long()
timesteps = indices.to(device=device, dtype=dtype)
# sigmas according to flowmatching
sigmas = timesteps / 1000
sigmas = sigmas.view(-1, 1, 1, 1)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
return noisy_model_input, timesteps, sigmas
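The arithmetic in `get_noisy_model_input_and_timesteps` can be traced on scalars: shift the sampled `u`, truncate it to an integer timestep, then interpolate between latent and noise. The sketch below uses hypothetical helper names (`shift_u`, `to_timestep`, `add_noise`) and floats in place of tensors.

```python
def shift_u(u: float, shift: float) -> float:
    """Apply the discrete flow shift to a sample in [0, 1]."""
    return (u * shift) / (1 + (shift - 1) * u)


def to_timestep(u: float, t_min: int = 0, t_max: int = 1000) -> int:
    """Scale u into [t_min, t_max] and truncate toward zero, like `.long()`."""
    return int(u * (t_max - t_min) + t_min)


def add_noise(latent: float, noise: float, timestep: int) -> float:
    """Flow-matching interpolation between latent and noise at sigma = t / 1000."""
    sigma = timestep / 1000
    return sigma * noise + (1.0 - sigma) * latent


u = shift_u(0.5, shift=3.0)
t = to_timestep(u)
print(u, t, add_noise(latent=0.0, noise=1.0, timestep=t))
```

With `shift == 1.0` the transform is the identity, so the default `--training_shift` leaves the sampled distribution unchanged.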

View File

@@ -1,302 +0,0 @@
from dataclasses import dataclass
import math
import re
from typing import Dict, List, Optional, Union
import torch
import safetensors
from safetensors.torch import load_file
from accelerate import init_empty_weights
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPConfig, CLIPTextConfig
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import sd3_models
# TODO move some of functions to model_util.py
from library import sdxl_model_util
# region models
# TODO remove dependency on flux_utils
from library.utils import load_safetensors
from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl
def analyze_state_dict_state(state_dict: Dict, prefix: str = ""):
logger.info("Analyzing state dict state...")
# analyze configs
patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
pos_embed_max_size = round(math.sqrt(num_patches))
adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys() else None
# x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1]))
x_block_self_attn_layers = []
re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight")
for key in list(state_dict.keys()):
m = re_attn.search(key)
if m:
x_block_self_attn_layers.append(int(m.group(1)))
context_embedder_in_features = context_shape[1]
context_embedder_out_features = context_shape[0]
# only supports 3-5-large, 3-5-medium, or 3-medium
if qk_norm is not None:
if len(x_block_self_attn_layers) == 0:
model_type = "3-5-large"
else:
model_type = "3-5-medium"
else:
model_type = "3-medium"
params = sd3_models.SD3Params(
patch_size=patch_size,
depth=depth,
num_patches=num_patches,
pos_embed_max_size=pos_embed_max_size,
adm_in_channels=adm_in_channels,
qk_norm=qk_norm,
x_block_self_attn_layers=x_block_self_attn_layers,
context_embedder_in_features=context_embedder_in_features,
context_embedder_out_features=context_embedder_out_features,
model_type=model_type,
)
logger.info(f"Analyzed state dict state: {params}")
return params
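The model-type decision in `analyze_state_dict_state` depends only on which keys are present, so it can be exercised without loading any tensors. This is a sketch under that assumption; `detect_model_type` is a hypothetical name and the key lists are made up for illustration.

```python
import re
from typing import Iterable, List, Tuple

# Same pattern as in the function above: captures the block index of attn2 layers.
RE_ATTN2 = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight")


def detect_model_type(keys: Iterable[str], prefix: str = "") -> Tuple[str, List[int]]:
    """Return the inferred model type and the indices of x_block self-attention layers."""
    key_set = set(keys)
    has_qk_norm = f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in key_set
    layers = []
    for key in sorted(key_set):
        match = RE_ATTN2.search(key)
        if match:
            layers.append(int(match.group(1)))
    if not has_qk_norm:
        return "3-medium", layers
    return ("3-5-large" if not layers else "3-5-medium"), layers


print(detect_model_type(["joint_blocks.0.context_block.attn.ln_k.weight"]))
```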
def load_mmdit(
state_dict: Dict, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], attn_mode: str = "torch"
) -> sd3_models.MMDiT:
mmdit_sd = {}
mmdit_prefix = "model.diffusion_model."
for k in list(state_dict.keys()):
if k.startswith(mmdit_prefix):
mmdit_sd[k[len(mmdit_prefix) :]] = state_dict.pop(k)
# load MMDiT
logger.info("Building MMDiT")
params = analyze_state_dict_state(mmdit_sd)
with init_empty_weights():
mmdit = sd3_models.create_sd3_mmdit(params, attn_mode)
logger.info("Loading state dict...")
info = mmdit.load_state_dict(mmdit_sd, strict=False, assign=True)
logger.info(f"Loaded MMDiT: {info}")
return mmdit
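`load_mmdit` and the text encoder and VAE loaders below all use the same idiom: move every entry with a given key prefix out of the combined checkpoint into a fresh dict, dropping the prefix. A small sketch of that idiom on its own, with `pop_prefixed` as a hypothetical helper name:

```python
def pop_prefixed(state_dict, prefix):
    """Move entries whose key starts with `prefix` into a new dict, stripping the prefix.

    The source dict is mutated: extracted entries are removed from it, so the
    same tensor is never referenced from both dicts.
    """
    extracted = {}
    for key in list(state_dict.keys()):  # copy the keys, since we pop while iterating
        if key.startswith(prefix):
            extracted[key[len(prefix):]] = state_dict.pop(key)
    return extracted
```

Popping rather than copying means the combined checkpoint shrinks as each sub-model is extracted, which is presumably why the loaders take this approach for large checkpoints.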
def load_clip_l(
clip_l_path: Optional[str],
dtype: Optional[Union[str, torch.dtype]],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
clip_l_sd = None
if state_dict is not None:
if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_l: remove prefix "text_encoders.clip_l."
logger.info("clip_l is included in the checkpoint")
clip_l_sd = {}
prefix = "text_encoders.clip_l."
for k in list(state_dict.keys()):
if k.startswith(prefix):
clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
elif clip_l_path is None:
logger.info("clip_l is not included in the checkpoint and clip_l_path is not provided")
return None
# load clip_l
logger.info("Building CLIP-L")
config = CLIPTextConfig(
vocab_size=49408,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
max_position_embeddings=77,
hidden_act="quick_gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=768,
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
with init_empty_weights():
clip = CLIPTextModelWithProjection(config)
if clip_l_sd is None:
logger.info(f"Loading state dict from {clip_l_path}")
clip_l_sd = load_safetensors(clip_l_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
if "text_projection.weight" not in clip_l_sd:
logger.info("Adding text_projection.weight to clip_l_sd")
clip_l_sd["text_projection.weight"] = torch.eye(768, dtype=dtype, device=device)
info = clip.load_state_dict(clip_l_sd, strict=False, assign=True)
logger.info(f"Loaded CLIP-L: {info}")
return clip
def load_clip_g(
clip_g_path: Optional[str],
dtype: Optional[Union[str, torch.dtype]],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
clip_g_sd = None
if state_dict is not None:
if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_g: remove prefix "text_encoders.clip_g."
logger.info("clip_g is included in the checkpoint")
clip_g_sd = {}
prefix = "text_encoders.clip_g."
for k in list(state_dict.keys()):
if k.startswith(prefix):
clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
elif clip_g_path is None:
logger.info("clip_g is not included in the checkpoint and clip_g_path is not provided")
return None
# load clip_g
logger.info("Building CLIP-G")
config = CLIPTextConfig(
vocab_size=49408,
hidden_size=1280,
intermediate_size=5120,
num_hidden_layers=32,
num_attention_heads=20,
max_position_embeddings=77,
hidden_act="gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=1280,
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
with init_empty_weights():
clip = CLIPTextModelWithProjection(config)
if clip_g_sd is None:
logger.info(f"Loading state dict from {clip_g_path}")
clip_g_sd = load_safetensors(clip_g_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = clip.load_state_dict(clip_g_sd, strict=False, assign=True)
logger.info(f"Loaded CLIP-G: {info}")
return clip
def load_t5xxl(
t5xxl_path: Optional[str],
dtype: Optional[Union[str, torch.dtype]],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
t5xxl_sd = None
if state_dict is not None:
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
# found t5xxl: remove prefix "text_encoders.t5xxl."
logger.info("t5xxl is included in the checkpoint")
t5xxl_sd = {}
prefix = "text_encoders.t5xxl."
for k in list(state_dict.keys()):
if k.startswith(prefix):
t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
elif t5xxl_path is None:
logger.info("t5xxl is not included in the checkpoint and t5xxl_path is not provided")
return None
return flux_utils_load_t5xxl(t5xxl_path, dtype, device, disable_mmap, state_dict=t5xxl_sd)
def load_vae(
vae_path: Optional[str],
vae_dtype: Optional[Union[str, torch.dtype]],
device: Optional[Union[str, torch.device]],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
vae_sd = {}
if vae_path:
logger.info(f"Loading VAE from {vae_path}...")
vae_sd = load_safetensors(vae_path, device, disable_mmap)
else:
# remove prefix "first_stage_model."
vae_sd = {}
vae_prefix = "first_stage_model."
for k in list(state_dict.keys()):
if k.startswith(vae_prefix):
vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
logger.info("Building VAE")
vae = sd3_models.SDVAE(vae_dtype, device)
logger.info("Loading state dict...")
info = vae.load_state_dict(vae_sd)
logger.info(f"Loaded VAE: {info}")
vae.to(device=device, dtype=vae_dtype) # make sure it's in the right device and dtype
return vae
# endregion
class ModelSamplingDiscreteFlow:
"""Helper for sampler scheduling (i.e. timestep/sigma calculations) for Discrete Flow models"""
def __init__(self, shift=1.0):
self.shift = shift
timesteps = 1000
self.sigmas = self.sigma(torch.arange(1, timesteps + 1, 1))
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return sigma * 1000
def sigma(self, timestep: torch.Tensor):
timestep = timestep / 1000.0
if self.shift == 1.0:
return timestep
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
# assert max_denoise is False, "max_denoise not implemented"
# max_denoise is always True, I'm not sure why it's there
return sigma * noise + (1.0 - sigma) * latent_image
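The arithmetic in `ModelSamplingDiscreteFlow` can be checked with plain floats. This is a sketch of the same formulas without `torch`; the function names are illustrative, and the claim that the model output equals `noise - latent` is an assumption used only to show that `calculate_denoised` inverts `noise_scaling`.

```python
def shifted_sigma(timestep, shift=1.0, num_timesteps=1000):
    """Map a timestep in [1, num_timesteps] to sigma, applying the shift."""
    t = timestep / num_timesteps
    if shift == 1.0:
        return t
    return shift * t / (1 + (shift - 1) * t)


def noise_scaling(sigma, noise, latent):
    """Linear interpolation between the clean latent (sigma=0) and pure noise (sigma=1)."""
    return sigma * noise + (1.0 - sigma) * latent


def calculate_denoised(sigma, model_output, model_input):
    """Recover the clean latent from the noisy input and the model prediction."""
    return model_input - model_output * sigma


# Worked example: shift=3 pushes the midpoint timestep towards higher noise.
sigma = shifted_sigma(500, shift=3.0)                 # 3 * 0.5 / (1 + 2 * 0.5) = 0.75
noisy = noise_scaling(sigma, noise=4.0, latent=8.0)   # 0.75 * 4 + 0.25 * 8 = 5.0
# Assumption: an ideal model predicts (noise - latent), here 4.0 - 8.0 = -4.0.
clean = calculate_denoised(sigma, -4.0, noisy)        # 5.0 + 3.0 = 8.0
```

With `shift == 1.0` the mapping is the identity on `t`, and for any shift the last timestep still maps to sigma 1.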

View File

@@ -13,20 +13,12 @@ from tqdm import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers import SchedulerMixin, StableDiffusionPipeline
from diffusers.models import AutoencoderKL
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.utils import logging
from PIL import Image
from library import (
sdxl_model_util,
sdxl_train_util,
strategy_base,
strategy_sdxl,
train_util,
sdxl_original_unet,
sdxl_original_control_net,
)
from library import sdxl_model_util, sdxl_train_util, train_util
try:
@@ -545,7 +537,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
vae: AutoencoderKL,
text_encoder: List[CLIPTextModel],
tokenizer: List[CLIPTokenizer],
unet: Union[sdxl_original_unet.SdxlUNet2DConditionModel, sdxl_original_control_net.SdxlControlledUNet],
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
# clip_skip: int,
safety_checker: StableDiffusionSafetyChecker,
@@ -602,6 +594,74 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
max_embeddings_multiples,
is_sdxl_text_encoder2,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is not greater than `1`).
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
if negative_prompt is None:
negative_prompt = [""] * batch_size
elif isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
is_sdxl_text_encoder2=is_sdxl_text_encoder2,
)
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) # ??
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if text_pool is not None:
text_pool = text_pool.repeat(1, num_images_per_prompt)
text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1)
if do_classifier_free_guidance:
bs_embed, seq_len, _ = uncond_embeddings.shape
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if uncond_pool is not None:
uncond_pool = uncond_pool.repeat(1, num_images_per_prompt)
uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1)
return text_embeddings, text_pool, uncond_embeddings, uncond_pool
return text_embeddings, text_pool, None, None
def check_inputs(self, prompt, height, width, strength, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@@ -732,7 +792,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
max_embeddings_multiples: Optional[int] = 3,
output_type: Optional[str] = "pil",
return_dict: bool = True,
controlnet: sdxl_original_control_net.SdxlControlNet = None,
controlnet=None,
controlnet_image=None,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
@@ -836,24 +896,32 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy: strategy_sdxl.SdxlTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
# To simplify the implementation, switch the tokenizer/text encoder and call it twice
text_embeddings_list = []
text_pool = None
uncond_embeddings_list = []
uncond_pool = None
for i in range(len(self.tokenizers)):
self.tokenizer = self.tokenizers[i]
self.text_encoder = self.text_encoders[i]
text_input_ids, text_weights = tokenize_strategy.tokenize_with_weights(prompt)
hidden_states_1, hidden_states_2, text_pool = encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, self.text_encoders, text_input_ids, text_weights
)
text_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1)
if do_classifier_free_guidance:
input_ids, weights = tokenize_strategy.tokenize_with_weights(negative_prompt or "")
hidden_states_1, hidden_states_2, uncond_pool = encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, self.text_encoders, input_ids, weights
text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
max_embeddings_multiples,
is_sdxl_text_encoder2=i == 1,
)
uncond_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1)
else:
uncond_embeddings = None
uncond_pool = None
text_embeddings_list.append(text_embeddings)
uncond_embeddings_list.append(uncond_embeddings)
if tp1 is not None:
text_pool = tp1
if up1 is not None:
uncond_pool = up1
unet_dtype = self.unet.dtype
dtype = unet_dtype
@@ -902,23 +970,23 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# create size embs and concat embeddings for SDXL
orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(device, dtype)
orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype)
crop_size = torch.zeros_like(orig_size)
target_size = orig_size
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(device, dtype)
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype)
# make conditionings
text_pool = text_pool.to(device, dtype)
if do_classifier_free_guidance:
text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(device, dtype)
text_embeddings = torch.cat(text_embeddings_list, dim=2)
uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2)
text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype)
uncond_pool = uncond_pool.to(device, dtype)
cond_vector = torch.cat([text_pool, embs], dim=1).to(dtype)
uncond_vector = torch.cat([uncond_pool, embs], dim=1).to(dtype)
vector_embedding = torch.cat([uncond_vector, cond_vector])
cond_vector = torch.cat([text_pool, embs], dim=1)
uncond_vector = torch.cat([uncond_pool, embs], dim=1)
vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype)
else:
text_embedding = text_embeddings.to(device, dtype)
vector_embedding = torch.cat([text_pool, embs], dim=1)
text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype)
vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype)
# 8. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
@@ -926,14 +994,22 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# FIXME SD1 ControlNet is not working
unet_additional_args = {}
if controlnet is not None:
down_block_res_samples, mid_block_res_sample = controlnet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
controlnet_cond=controlnet_image,
conditioning_scale=1.0,
guess_mode=False,
return_dict=False,
)
unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
# predict the noise residual
if controlnet is not None:
input_resi_add, mid_add = controlnet(latent_model_input, t, text_embedding, vector_embedding, controlnet_image)
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding, input_resi_add, mid_add)
else:
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training
# perform guidance

View File

@@ -8,7 +8,7 @@ from typing import List
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from library import model_util
from library import sdxl_original_unet
from library.utils import setup_logging
from .utils import setup_logging
setup_logging()
import logging

View File

@@ -1,272 +0,0 @@
# some parts are modified from Diffusers library (Apache License 2.0)
import math
from types import SimpleNamespace
from typing import Any, Optional
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import sdxl_original_unet
from library.sdxl_model_util import convert_sdxl_unet_state_dict_to_diffusers, convert_diffusers_unet_state_dict_to_sdxl
class ControlNetConditioningEmbedding(nn.Module):
def __init__(self):
super().__init__()
dims = [16, 32, 96, 256]
self.conv_in = nn.Conv2d(3, dims[0], kernel_size=3, padding=1)
self.blocks = nn.ModuleList([])
for i in range(len(dims) - 1):
channel_in = dims[i]
channel_out = dims[i + 1]
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
self.conv_out = nn.Conv2d(dims[-1], 320, kernel_size=3, padding=1)
nn.init.zeros_(self.conv_out.weight) # zero module weight
nn.init.zeros_(self.conv_out.bias) # zero module bias
def forward(self, x):
x = self.conv_in(x)
x = F.silu(x)
for block in self.blocks:
x = block(x)
x = F.silu(x)
x = self.conv_out(x)
return x
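The conditioning embedding has to turn a full-resolution control image into a tensor that can be added to the first U-Net feature map. A rough shape calculation, using the standard convolution output-size formula and the layer layout shown above (the helper names are hypothetical), shows that three stride-2 convolutions give the same 8x reduction as the latent space:

```python
def conv_out_size(size, kernel=3, stride=1, padding=1):
    """Output length of a 1-D slice through a Conv2d with the given settings."""
    return (size + 2 * padding - kernel) // stride + 1


def embedding_output_shape(height, width, dims=(16, 32, 96, 256), out_channels=320):
    """Return (channels, height, width) produced for one conditioning image."""
    h, w = conv_out_size(height), conv_out_size(width)  # conv_in, stride 1
    for _ in range(len(dims) - 1):
        h, w = conv_out_size(h), conv_out_size(w)  # channel-preserving conv, stride 1
        h, w = conv_out_size(h, stride=2), conv_out_size(w, stride=2)  # downsampling conv
    h, w = conv_out_size(h), conv_out_size(w)  # conv_out, stride 1
    return out_channels, h, w


# A 1024x1024 control image ends up at 128x128, matching a 1024x1024 latent.
shape = embedding_output_shape(1024, 1024)  # (320, 128, 128)
```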
class SdxlControlNet(sdxl_original_unet.SdxlUNet2DConditionModel):
def __init__(self, multiplier: Optional[float] = None, **kwargs):
super().__init__(**kwargs)
self.multiplier = multiplier
# remove unet layers
self.output_blocks = nn.ModuleList([])
del self.out
self.controlnet_cond_embedding = ControlNetConditioningEmbedding()
dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280]
self.controlnet_down_blocks = nn.ModuleList([])
for dim in dims:
self.controlnet_down_blocks.append(nn.Conv2d(dim, dim, kernel_size=1))
nn.init.zeros_(self.controlnet_down_blocks[-1].weight) # zero module weight
nn.init.zeros_(self.controlnet_down_blocks[-1].bias) # zero module bias
self.controlnet_mid_block = nn.Conv2d(1280, 1280, kernel_size=1)
nn.init.zeros_(self.controlnet_mid_block.weight) # zero module weight
nn.init.zeros_(self.controlnet_mid_block.bias) # zero module bias
def init_from_unet(self, unet: sdxl_original_unet.SdxlUNet2DConditionModel):
unet_sd = unet.state_dict()
unet_sd = {k: v for k, v in unet_sd.items() if not k.startswith("out")}
sd = super().state_dict()
sd.update(unet_sd)
info = super().load_state_dict(sd, strict=True, assign=True)
return info
def load_state_dict(self, state_dict: dict, strict: bool = True, assign: bool = True) -> Any:
# convert state_dict to SAI format
unet_sd = {}
for k in list(state_dict.keys()):
if not k.startswith("controlnet_"):
unet_sd[k] = state_dict.pop(k)
unet_sd = convert_diffusers_unet_state_dict_to_sdxl(unet_sd)
state_dict.update(unet_sd)
super().load_state_dict(state_dict, strict=strict, assign=assign)
def state_dict(self, destination=None, prefix="", keep_vars=False):
# convert state_dict to Diffusers format
state_dict = super().state_dict(destination, prefix, keep_vars)
control_net_sd = {}
for k in list(state_dict.keys()):
if k.startswith("controlnet_"):
control_net_sd[k] = state_dict.pop(k)
state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict)
state_dict.update(control_net_sd)
return state_dict
def forward(
self,
x: torch.Tensor,
timesteps: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
cond_image: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
# broadcast timesteps to batch dimension
timesteps = timesteps.expand(x.shape[0])
t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)
t_emb = t_emb.to(x.dtype)
emb = self.time_embed(t_emb)
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
emb = emb + self.label_emb(y)
def call_module(module, h, emb, context):
x = h
for layer in module:
if isinstance(layer, sdxl_original_unet.ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, sdxl_original_unet.Transformer2DModel):
x = layer(x, context)
else:
x = layer(x)
return x
h = x
multiplier = self.multiplier if self.multiplier is not None else 1.0
hs = []
for i, module in enumerate(self.input_blocks):
h = call_module(module, h, emb, context)
if i == 0:
h = self.controlnet_cond_embedding(cond_image) + h
hs.append(self.controlnet_down_blocks[i](h) * multiplier)
h = call_module(self.middle_block, h, emb, context)
h = self.controlnet_mid_block(h) * multiplier
return hs, h
class SdxlControlledUNet(sdxl_original_unet.SdxlUNet2DConditionModel):
"""
This class is for training purposes only.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs):
# broadcast timesteps to batch dimension
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)
t_emb = t_emb.to(x.dtype)
emb = self.time_embed(t_emb)
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
emb = emb + self.label_emb(y)
def call_module(module, h, emb, context):
x = h
for layer in module:
if isinstance(layer, sdxl_original_unet.ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, sdxl_original_unet.Transformer2DModel):
x = layer(x, context)
else:
x = layer(x)
return x
h = x
for module in self.input_blocks:
h = call_module(module, h, emb, context)
hs.append(h)
h = call_module(self.middle_block, h, emb, context)
h = h + mid_add
for module in self.output_blocks:
resi = hs.pop() + input_resi_add.pop()
h = torch.cat([h, resi], dim=1)
h = call_module(module, h, emb, context)
h = h.type(x.dtype)
h = call_module(self.out, h, emb, context)
return h
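In the output-block loop above, both `hs` and `input_resi_add` are consumed with `pop()`, so skip connections and ControlNet residuals are paired from the deepest block outwards. Because `input_resi_add.pop()` operates on the list object passed in, the caller's list is emptied as a side effect. A toy illustration with strings in place of tensors (`pair_residuals` is a hypothetical name, and it copies its inputs, unlike the forward pass):

```python
def pair_residuals(skips, control_residuals):
    """Return (skip, residual) pairs in the order the output blocks consume them."""
    skips = list(skips)  # copy so the caller's lists are left intact
    control_residuals = list(control_residuals)
    pairs = []
    while skips:
        # Last-in, first-out: the deepest input block feeds the first output block.
        pairs.append((skips.pop(), control_residuals.pop()))
    return pairs
```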
if __name__ == "__main__":
import time
logger.info("create unet")
unet = SdxlControlledUNet()
unet.to("cuda", torch.bfloat16)
unet.set_use_sdpa(True)
unet.set_gradient_checkpointing(True)
unet.train()
logger.info("create control_net")
control_net = SdxlControlNet()
control_net.to("cuda")
control_net.set_use_sdpa(True)
control_net.set_gradient_checkpointing(True)
control_net.train()
logger.info("Initialize control_net from unet")
control_net.init_from_unet(unet)
unet.requires_grad_(False)
control_net.requires_grad_(True)
# pseudo training loop for checking memory usage
logger.info("preparing optimizer")
# optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
import bitsandbytes
optimizer = bitsandbytes.adam.Adam8bit(control_net.parameters(), lr=1e-3) # not working
# optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
# optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
# import transformers
# optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2
scaler = torch.cuda.amp.GradScaler(enabled=True)
logger.info("start training")
steps = 10
batch_size = 1
for step in range(steps):
logger.info(f"step {step}")
if step == 1:
time_start = time.perf_counter()
x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024
t = torch.randint(low=0, high=1000, size=(batch_size,), device="cuda")
txt = torch.randn(batch_size, 77, 2048).cuda()
vector = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
cond_img = torch.rand(batch_size, 3, 1024, 1024).cuda()
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
input_resi_add, mid_add = control_net(x, t, txt, vector, cond_img)
output = unet(x, t, txt, vector, input_resi_add, mid_add)
target = torch.randn_like(output)
loss = torch.nn.functional.mse_loss(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
time_end = time.perf_counter()
logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
logger.info("finish training")
sd = control_net.state_dict()
from safetensors.torch import save_file
save_file(sd, r"E:\Work\SD\Tmp\sdxl\ctrl\control_net.safetensors")

View File

@@ -30,7 +30,7 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from library.utils import setup_logging
from .utils import setup_logging
setup_logging()
import logging
@@ -1156,9 +1156,9 @@ class InferSdxlUNet2DConditionModel:
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio
def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs):
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
r"""
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink and ControlNet.
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
"""
_self = self.delegate
@@ -1209,8 +1209,6 @@ class InferSdxlUNet2DConditionModel:
hs.append(h)
h = call_module(_self.middle_block, h, emb, context)
if mid_add is not None:
h = h + mid_add
for module in _self.output_blocks:
# Deep Shrink
@@ -1219,11 +1217,7 @@ class InferSdxlUNet2DConditionModel:
# print("upsample", h.shape, hs[-1].shape)
h = resize_like(h, hs[-1])
resi = hs.pop()
if input_resi_add is not None:
resi = resi + input_resi_add.pop()
h = torch.cat([h, resi], dim=1)
h = torch.cat([h, hs.pop()], dim=1)
h = call_module(module, h, emb, context)
# Deep Shrink: in case of depth 0

View File

@@ -12,6 +12,7 @@ from accelerate import init_empty_weights
from tqdm import tqdm
from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
from .utils import setup_logging
setup_logging()
@@ -326,7 +327,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
)
def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True):
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
)
@@ -363,9 +364,9 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
# )
# logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
# assert (
# not hasattr(args, "weighted_captions") or not args.weighted_captions
# ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
assert (
not hasattr(args, "weighted_captions") or not args.weighted_captions
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
if supportTextEncoderCaching:
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
@@ -377,6 +378,4 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
def sample_images(*args, **kwargs):
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)

View File

@@ -1,911 +0,0 @@
# base class for platform strategies. this file defines the interface for strategies
import os
import re
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from safetensors.torch import safe_open, save_file
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import dataset_metadata_utils, utils
def get_compatible_dtypes(dtype: Optional[Union[str, torch.dtype]]) -> List[torch.dtype]:
if dtype is None:
# all dtypes are acceptable
return get_available_dtypes()
dtype = utils.str_to_dtype(dtype) if isinstance(dtype, str) else dtype
compatible_dtypes = [torch.float32]
if dtype.itemsize == 1: # fp8
compatible_dtypes.append(torch.bfloat16)
compatible_dtypes.append(torch.float16)
compatible_dtypes.append(dtype) # add the specified: bf16, fp16, one of fp8
return compatible_dtypes
def get_available_dtypes() -> List[torch.dtype]:
"""
Returns the list of available dtypes for latents caching. Higher precision is preferred.
"""
return [torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn, torch.float8_e5m2]
def remove_lower_precision_values(tensor_dict: Dict[str, torch.Tensor], keys_without_dtype: list[str]) -> None:
"""
Removes lower precision values from tensor_dict.
"""
available_dtypes = get_available_dtypes()
available_dtype_suffixes = [f"_{utils.dtype_to_normalized_str(dtype)}" for dtype in available_dtypes]
for key_without_dtype in keys_without_dtype:
available_itemsize = None
for dtype, dtype_suffix in zip(available_dtypes, available_dtype_suffixes):
key = key_without_dtype + dtype_suffix
if key in tensor_dict:
if available_itemsize is None:
available_itemsize = dtype.itemsize
elif available_itemsize > dtype.itemsize:
# if higher precision latents are already cached, remove lower precision latents
del tensor_dict[key]
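The pruning rule above can be illustrated without `torch` by replacing dtypes with a table of item sizes. In this sketch the suffix strings are assumptions, since the output of `utils.dtype_to_normalized_str` is not shown here, and `remove_lower_precision` is a hypothetical stand-in for the real function.

```python
# Preference order matches get_available_dtypes(): higher precision first.
# Suffix names are illustrative only.
DTYPE_ITEMSIZE = {"fp32": 4, "bf16": 2, "fp16": 2, "fp8_e4m3fn": 1, "fp8_e5m2": 1}


def remove_lower_precision(tensor_dict, keys_without_dtype):
    """Delete entries stored at a smaller item size than the best one present."""
    for base in keys_without_dtype:
        best_itemsize = None
        for suffix, itemsize in DTYPE_ITEMSIZE.items():
            key = f"{base}_{suffix}"
            if key in tensor_dict:
                if best_itemsize is None:
                    best_itemsize = itemsize  # first hit is the highest precision
                elif best_itemsize > itemsize:
                    del tensor_dict[key]  # strictly smaller item size: drop it
```

Note that the comparison is on item size, so `bf16` and `fp16` entries are treated as equal precision and both survive if both are present.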
def get_compatible_dtype_keys(
dict_keys: set[str], keys_without_dtype: list[str], dtype: Optional[Union[str, torch.dtype]]
) -> list[Optional[str]]:
"""
Returns the list of keys with the specified dtype or higher precision dtype. If the specified dtype is None, any dtype is acceptable.
If the key is not found, it returns None.
If the key in dict_keys doesn't have a dtype suffix, it is acceptable, because it is a long tensor.
:param dict_keys: set of keys in the dictionary
:param keys_without_dtype: list of keys without dtype suffix to check
:param dtype: dtype to check, or None for any dtype
:return: list of keys with the specified dtype or higher precision dtype. If the key is not found, it returns None for that key.
"""
compatible_dtypes = get_compatible_dtypes(dtype)
dtype_suffixes = [f"_{utils.dtype_to_normalized_str(dt)}" for dt in compatible_dtypes]
available_keys = []
for key_without_dtype in keys_without_dtype:
available_key = None
if key_without_dtype in dict_keys:
available_key = key_without_dtype
else:
for dtype_suffix in dtype_suffixes:
key = key_without_dtype + dtype_suffix
if key in dict_keys:
available_key = key
break
available_keys.append(available_key)
return available_keys
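The key lookup combines two rules: a key with no dtype suffix is accepted as-is, and otherwise the first suffixed key in the compatible-dtype order wins. A string-only sketch of both functions follows; the suffix names and the `compatible_*` helpers are assumptions for illustration, not the library's API.

```python
ITEMSIZE = {"fp32": 4, "bf16": 2, "fp16": 2, "fp8_e4m3fn": 1, "fp8_e5m2": 1}


def compatible_suffixes(dtype):
    """Suffixes acceptable when `dtype` is requested (None accepts everything)."""
    if dtype is None:
        return list(ITEMSIZE)
    result = ["fp32"]
    if ITEMSIZE[dtype] == 1:  # fp8 request: 16-bit caches are also good enough
        result += ["bf16", "fp16"]
    result.append(dtype)  # finally the requested dtype itself
    return result


def compatible_keys(dict_keys, keys_without_dtype, dtype):
    """For each base key, return the matching stored key, or None if there is none."""
    suffixes = compatible_suffixes(dtype)
    found = []
    for base in keys_without_dtype:
        match = base if base in dict_keys else None  # unsuffixed keys are accepted
        if match is None:
            for suffix in suffixes:
                if f"{base}_{suffix}" in dict_keys:
                    match = f"{base}_{suffix}"
                    break
        found.append(match)
    return found
```

Under these rules a `bf16` request does not accept an `fp16` cache, while an fp8 request accepts either 16-bit format.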
class TokenizeStrategy:
_strategy = None # strategy instance: actual strategy class
_re_attention = re.compile(
r"""\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
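To see what the attention pattern above produces, it can be run on its own. The sketch below copies the same alternatives into a standalone pattern (laid out on fewer lines, which `re.X` permits) and lists each match together with its captured weight; `split_prompt` is a hypothetical helper for inspection only.

```python
import re

# Same alternatives, in the same order, as TokenizeStrategy._re_attention.
RE_ATTENTION = re.compile(
    r"""
    \\\(|\\\)|\\\[|\\]|\\\\|\\|
    \(|\[|
    :([+-]?[.\d]+)\)|
    \)|]|
    [^\\()\[\]:]+|
    :
    """,
    re.X,
)


def split_prompt(text):
    """Return (matched_text, weight_or_None) for every token the pattern finds."""
    return [(m.group(0), m.group(1)) for m in RE_ATTENTION.finditer(text)]


tokens = split_prompt("an (important:1.2) word")
# [('an ', None), ('(', None), ('important', None), (':1.2)', '1.2'), (' word', None)]
```

Escaped brackets come out as two-character tokens starting with a backslash, which is why the parser checks `text.startswith("\\")` before handling brackets.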
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["TokenizeStrategy"]:
return cls._strategy
def _load_tokenizer(
self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None
) -> Any:
tokenizer = None
if tokenizer_cache_dir:
local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_"))
if os.path.exists(local_tokenizer_path):
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
tokenizer = model_class.from_pretrained(local_tokenizer_path) # same for v1 and v2
if tokenizer is None:
tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder)
if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
tokenizer.save_pretrained(local_tokenizer_path)
return tokenizer
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
raise NotImplementedError
def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""
returns: [tokens1, tokens2, ...], [weights1, weights2, ...]
"""
raise NotImplementedError
def _get_weighted_input_ids(
self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
max_length includes starting and ending tokens.
"""
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in TokenizeStrategy._re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
def get_prompts_with_weights(text: str, max_length: int):
r"""
Tokenize a single prompt and return its tokens together with the weight of each token. max_length does not include the starting and ending tokens.
No padding, starting or ending token is included.
"""
truncated = False
texts_and_weights = parse_prompt_attention(text)
tokens = []
weights = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = tokenizer(word).input_ids[1:-1]
tokens += token
# copy the weight by length of token
weights += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(tokens) > max_length:
truncated = True
break
# truncate
if len(tokens) > max_length:
truncated = True
tokens = tokens[:max_length]
weights = weights[:max_length]
if truncated:
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
return tokens, weights
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
tokens = [bos] + tokens + [eos] + [pad] * (max_length - 2 - len(tokens))
weights = [1.0] + weights + [1.0] * (max_length - 1 - len(weights))
return tokens, weights
if max_length is None:
max_length = tokenizer.model_max_length
tokens, weights = get_prompts_with_weights(text, max_length - 2)
tokens, weights = pad_tokens_and_weights(
tokens, weights, max_length, tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id
)
return torch.tensor(tokens).unsqueeze(0), torch.tensor(weights).unsqueeze(0)
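The weighting syntax handled by `parse_prompt_attention` can be tried without a tokenizer. The following is a minimal standalone sketch that reuses the same regular-expression alternatives and bracket bookkeeping; the names `RE_ATTENTION` and `parse_weights` are introduced here for illustration only and are not part of the library.

```python
import re

# Same alternatives as TokenizeStrategy._re_attention, grouped onto fewer lines.
# re.X (verbose mode) ignores the whitespace between the alternatives.
RE_ATTENTION = re.compile(
    r"""
    \\\(|\\\)|\\\[|\\]|\\\\|\\|
    \(|\[|
    :([+-]?[.\d]+)\)|
    \)|]|
    [^\\()\[\]:]+|
    :
    """,
    re.X,
)


def parse_weights(text, round_mult=1.1):
    """Return [[text, weight], ...] for a prompt using (), [] and (x:1.2) syntax."""
    square_mult = 1 / round_mult
    res, round_brackets, square_brackets = [], [], []

    def multiply_range(start, mult):
        # Scale every fragment opened since the matching bracket.
        for p in range(start, len(res)):
            res[p][1] *= mult

    for m in RE_ATTENTION.finditer(text):
        tok, weight = m.group(0), m.group(1)
        if tok.startswith("\\"):
            res.append([tok[1:], 1.0])  # escaped literal such as \( or \]
        elif tok == "(":
            round_brackets.append(len(res))
        elif tok == "[":
            square_brackets.append(len(res))
        elif weight is not None and round_brackets:
            multiply_range(round_brackets.pop(), float(weight))
        elif tok == ")" and round_brackets:
            multiply_range(round_brackets.pop(), round_mult)
        elif tok == "]" and square_brackets:
            multiply_range(square_brackets.pop(), square_mult)
        else:
            res.append([tok, 1.0])

    # Unbalanced opening brackets still apply their multiplier.
    for pos in round_brackets:
        multiply_range(pos, round_mult)
    for pos in square_brackets:
        multiply_range(pos, square_mult)
    if not res:
        res = [["", 1.0]]

    # Merge neighbouring fragments that ended up with the same weight.
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1
    return res


print(parse_weights("an (important) word"))
# [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
```

In the library code each fragment is then tokenized separately and its weight is repeated once per resulting token.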
def _get_input_ids(
self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None, weighted: bool = False
) -> torch.Tensor:
"""
for SD1.5/2.0/SDXL
TODO support batch input
"""
if max_length is None:
max_length = tokenizer.model_max_length - 2
if weighted:
input_ids, weights = self._get_weighted_input_ids(tokenizer, text, max_length)
else:
input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids
if max_length > tokenizer.model_max_length:
input_ids = input_ids.squeeze(0)
iids_list = []
if tokenizer.pad_token_id == tokenizer.eos_token_id:
# v1
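# when longer than 77 tokens, the input is "<BOS> .... <EOS> <EOS> <EOS>" with a total length such as 227, so convert it into three "<BOS>...<EOS>" chunks
# AUTOMATIC1111's implementation apparently splits at commas and so on, but keep it simple for now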
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): # (1, 152, 75)
ids_chunk = (
input_ids[0].unsqueeze(0),
input_ids[i : i + tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0),
)
ids_chunk = torch.cat(ids_chunk)
iids_list.append(ids_chunk)
else:
# v2 or SDXL
# when longer than 77 tokens, the input is "<BOS> .... <EOS> <PAD> <PAD>..." with a total length such as 227, so convert it into three "<BOS>...<EOS> <PAD> <PAD> ..." chunks
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
ids_chunk = (
input_ids[0].unsqueeze(0), # BOS
input_ids[i : i + tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0),
) # PAD or EOS
ids_chunk = torch.cat(ids_chunk)
# if the chunk ends with <EOS> <PAD> or <PAD> <PAD>, nothing needs to be done
# if the chunk ends with x <PAD/EOS>, change the last token to <EOS> (no effective change if it is already x <EOS>)
if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id:
ids_chunk[-1] = tokenizer.eos_token_id
# if the chunk starts with <BOS> <PAD> ..., change it to <BOS> <EOS> <PAD> ...
if ids_chunk[1] == tokenizer.pad_token_id:
ids_chunk[1] = tokenizer.eos_token_id
iids_list.append(ids_chunk)
input_ids = torch.stack(iids_list) # 3,77
if weighted:
weights = weights.squeeze(0)
new_weights = torch.ones(input_ids.shape)
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
b = i // (tokenizer.model_max_length - 2)
new_weights[b, 1 : 1 + tokenizer.model_max_length - 2] = weights[i : i + tokenizer.model_max_length - 2]
weights = new_weights
if weighted:
return input_ids, weights
return input_ids
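The chunking loop in `_get_input_ids` is easier to follow on plain lists. The sketch below covers only the v1 branch, where padding and EOS share one token id; the v2/SDXL branch additionally patches EOS/PAD at the chunk edges and is left out. The function name `split_into_chunks` and the token ids used in the example are illustrative assumptions, not library API.

```python
def split_into_chunks(input_ids, model_max_length=77):
    """Split '<BOS> body... <EOS>' into windows of model_max_length tokens.

    Each window keeps the first and last token of the full sequence and takes
    the next (model_max_length - 2) body tokens in between.
    """
    body = model_max_length - 2  # 75 body tokens per chunk for CLIP
    total = len(input_ids)
    chunks = []
    for i in range(1, total - model_max_length + 2, body):
        chunks.append([input_ids[0]] + input_ids[i : i + body] + [input_ids[-1]])
    return chunks


# Example ids chosen for illustration: one BOS, 40 prompt tokens, EOS used as padding.
BOS, EOS = 49406, 49407
ids = [BOS] + list(range(100, 140)) + [EOS] * 186  # total length 227 = 75 * 3 + 2
chunks = split_into_chunks(ids)
print(len(chunks), [len(c) for c in chunks])
```

With a total length of 227 the loop start indices are 1, 76 and 151, which matches the `(1, 152, 75)` note in the source.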
class TextEncodingStrategy:
_strategy = None # strategy instance: actual strategy class
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["TextEncodingStrategy"]:
return cls._strategy
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
"""
Encode tokens into embeddings and outputs.
:param tokens: list of token tensors for each TextModel
:return: list of output embeddings for each architecture
"""
raise NotImplementedError
def encode_tokens_with_weights(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor]
) -> List[torch.Tensor]:
"""
Encode tokens into embeddings and outputs.
:param tokens: list of token tensors for each TextModel
:param weights: list of weight tensors for each TextModel
:return: list of output embeddings for each architecture
"""
raise NotImplementedError
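All strategy base classes here share the same class-level registry: `set_strategy` stores one instance and refuses a second one, and `get_strategy` returns it. A minimal sketch of that pattern follows; `DemoStrategy` is a hypothetical stand-in class, not one of the library's strategies.

```python
from typing import Optional


class DemoStrategy:
    # Holds the single registered instance for this class.
    _strategy: Optional["DemoStrategy"] = None

    @classmethod
    def set_strategy(cls, strategy: "DemoStrategy") -> None:
        if cls._strategy is not None:
            raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["DemoStrategy"]:
        return cls._strategy


first = DemoStrategy()
DemoStrategy.set_strategy(first)
assert DemoStrategy.get_strategy() is first

try:
    DemoStrategy.set_strategy(DemoStrategy())
except RuntimeError as e:
    print(e)  # the second registration is rejected
```

Because the instance lives on the class, any module can fetch the active strategy without it being passed around, at the cost of allowing only one configuration per process.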
class TextEncoderOutputsCachingStrategy:
_strategy = None # strategy instance: actual strategy class
def __init__(
self,
architecture: str,
cache_to_disk: bool,
batch_size: Optional[int],
skip_disk_cache_validity_check: bool,
max_token_length: int,
masked: bool = False,
is_partial: bool = False,
is_weighted: bool = False,
) -> None:
"""
max_token_length: maximum token length for the model. Including/excluding starting and ending tokens depends on the model.
"""
self._architecture = architecture
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
self._max_token_length = max_token_length
self._masked = masked
self._is_partial = is_partial
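self._is_weighted = is_weighted # enable weighting by `()` or `[]` in the prompt
# version warning flags; load_from_disk_for_keys reads load_version_warning_printed
self.load_version_warning_printed = False
self.save_version_warning_printed = False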
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
return cls._strategy
@property
def architecture(self):
return self._architecture
@property
def max_token_length(self):
return self._max_token_length
@property
def masked(self):
return self._masked
@property
def cache_to_disk(self):
return self._cache_to_disk
@property
def batch_size(self):
return self._batch_size
@property
def cache_suffix(self):
suffix_masked = "_m" if self.masked else ""
return f"_{self.architecture.lower()}_{self.max_token_length}{suffix_masked}_te.safetensors"
@property
def is_partial(self):
return self._is_partial
@property
def is_weighted(self):
return self._is_weighted
def get_cache_path(self, absolute_path: str) -> str:
return os.path.splitext(absolute_path)[0] + self.cache_suffix
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
raise NotImplementedError
def load_from_disk_for_keys(self, cache_path: str, caption_index: int, base_keys: list[str]) -> list[Optional[torch.Tensor]]:
"""
Get the tensors for base_keys at caption_index, ignoring the dtype suffix. If a key is not found, None is returned for it.
Tensors of any cached dtype are returned, because cache validation is done in advance.
"""
with safe_open(cache_path, framework="pt") as f:
metadata = f.metadata()
version = metadata.get("format_version", "0.0.0")
major, minor, patch = map(int, version.split("."))
if major > 1: # or (major == 1 and minor > 0):
if not self.load_version_warning_printed:
self.load_version_warning_printed = True
logger.warning(
f"Existing text encoder outputs cache file has a higher version {version} for {cache_path}. This may cause issues."
)
dict_keys = f.keys()
results = []
compatible_keys = self.get_compatible_output_keys(dict_keys, caption_index, base_keys, None)
for key in compatible_keys:
results.append(f.get_tensor(key) if key is not None else None)
return results
def is_disk_cached_outputs_expected(
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
) -> bool:
raise NotImplementedError
def get_key_suffix(self, prompt_id: int, dtype: Optional[Union[str, torch.dtype]] = None) -> str:
"""
prompt_id: index of the caption. The dtype suffix is appended only for floating-point tensors.
"""
key_suffix = f"_{prompt_id}"
if dtype is not None and dtype.is_floating_point: # float tensor only
key_suffix += "_" + utils.dtype_to_normalized_str(dtype)
return key_suffix
def get_compatible_output_keys(
self, dict_keys: set[str], caption_index: int, base_keys: list[str], dtype: Optional[Union[str, torch.dtype]]
) -> list[Optional[str]]:
"""
returns the list of keys with the specified dtype or higher precision dtype. If the specified dtype is None, any dtype is acceptable.
"""
key_suffix = self.get_key_suffix(caption_index, None)
keys_without_dtype = [k + key_suffix for k in base_keys]
return get_compatible_dtype_keys(dict_keys, keys_without_dtype, dtype)
def _default_is_disk_cached_outputs_expected(
self,
cache_path: str,
captions: list[str],
base_keys: list[str],
preferred_dtype: Optional[Union[str, torch.dtype]],
):
if not self.cache_to_disk:
return False
if not os.path.exists(cache_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
with utils.MemoryEfficientSafeOpen(cache_path) as f:
keys = f.keys()
metadata = f.metadata()
# check captions in metadata
for i, caption in enumerate(captions):
if metadata.get(f"caption{i+1}") != caption:
return False
compatible_keys = self.get_compatible_output_keys(keys, i, base_keys, preferred_dtype)
if any(key is None for key in compatible_keys):
return False
except Exception as e:
logger.error(f"Error loading file: {cache_path}")
raise e
return True
def cache_batch_outputs(
self,
tokenize_strategy: TokenizeStrategy,
models: list[Any],
text_encoding_strategy: TextEncodingStrategy,
batch: list[tuple[utils.ImageInfo, int, str]],
):
raise NotImplementedError
def save_outputs_to_disk(self, cache_path: str, caption_index: int, caption: str, keys: list[str], outputs: list[torch.Tensor]):
tensor_dict = {}
overwrite = False
if os.path.exists(cache_path):
# load existing safetensors and update it
overwrite = True
with utils.MemoryEfficientSafeOpen(cache_path) as f:
metadata = f.metadata()
# do not overwrite the `keys` argument: it is zipped with `outputs` below
existing_keys = f.keys()
for key in existing_keys:
tensor_dict[key] = f.get_tensor(key)
assert metadata["architecture"] == self.architecture
file_version = metadata.get("format_version", "0.0.0")
major, minor, patch = map(int, file_version.split("."))
if major > 1 or (major == 1 and minor > 0):
self.save_version_warning_printed = True
logger.warning(
f"Existing text encoder outputs cache file has a higher version {file_version} for {cache_path}. This may cause issues."
)
else:
metadata = {}
metadata["architecture"] = self.architecture
metadata["format_version"] = "1.0.0"
metadata[f"caption{caption_index+1}"] = caption
for key, output in zip(keys, outputs):
dtype = output.dtype # long or one of float
key_suffix = self.get_key_suffix(caption_index, dtype)
tensor_dict[key + key_suffix] = output
# remove lower precision latents if higher precision latents are already cached
if overwrite:
suffix_without_dtype = self.get_key_suffix(caption_index, None)
remove_lower_precision_values(tensor_dict, [k + suffix_without_dtype for k in keys])
save_file(tensor_dict, cache_path, metadata=metadata)
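The naming scheme used by `TextEncoderOutputsCachingStrategy` can be summarized in a few lines. This is a sketch using standalone helper functions (`te_cache_suffix`, `te_cache_path`, `tensor_key` are illustrative names, not library functions). It assumes that `utils.dtype_to_normalized_str` yields strings such as `float16`, as the `"_32x64_float16"` comment elsewhere in this file suggests.

```python
import os


def te_cache_suffix(architecture, max_token_length, masked=False):
    # Mirrors the cache_suffix property: "_<arch>_<max_len>[_m]_te.safetensors"
    suffix_masked = "_m" if masked else ""
    return f"_{architecture.lower()}_{max_token_length}{suffix_masked}_te.safetensors"


def te_cache_path(absolute_path, architecture, max_token_length, masked=False):
    # Mirrors get_cache_path: the image extension is replaced by the cache suffix.
    return os.path.splitext(absolute_path)[0] + te_cache_suffix(architecture, max_token_length, masked)


def tensor_key(base_key, caption_index, dtype_name=None):
    # Mirrors get_key_suffix: "_<caption index>" plus "_<dtype>" for float tensors.
    # dtype_name is passed as a plain string here (an assumption of this sketch).
    key = f"{base_key}_{caption_index}"
    if dtype_name is not None:
        key += "_" + dtype_name
    return key


print(te_cache_path("/data/img/cat.png", "flux", 512, masked=True))
print(tensor_key("t5_out", 0, "bfloat16"))
print(tensor_key("t5_attn_mask", 1))
```

One cache file therefore holds the outputs for every caption of an image, with the caption index and (for float tensors) the dtype encoded in the tensor key.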
class LatentsCachingStrategy:
_strategy = None # strategy instance: actual strategy class
def __init__(
self, architecture: str, latents_stride: int, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
) -> None:
self._architecture = architecture
self._latents_stride = latents_stride
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
self.load_version_warning_printed = False
self.save_version_warning_printed = False
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
return cls._strategy
@property
def architecture(self):
return self._architecture
@property
def latents_stride(self):
return self._latents_stride
@property
def cache_to_disk(self):
return self._cache_to_disk
@property
def batch_size(self):
return self._batch_size
@property
def cache_suffix(self):
return f"_{self.architecture.lower()}.safetensors"
def get_image_size_from_disk_cache_path(self, absolute_path: str, cache_path: str) -> Tuple[Optional[int], Optional[int]]:
w, h = os.path.splitext(cache_path)[0].rsplit("_", 2)[-2].split("x")
return int(w), int(h)
def get_latents_cache_path_from_info(self, info: utils.ImageInfo) -> str:
return self.get_latents_cache_path(info.absolute_path, info.image_size, info.latents_cache_dir)
def get_latents_cache_path(
self, absolute_path_or_archive_img_path: str, image_size: Tuple[int, int], cache_dir: Optional[str] = None
) -> str:
if cache_dir is not None:
if dataset_metadata_utils.is_archive_path(absolute_path_or_archive_img_path):
inner_path = dataset_metadata_utils.get_inner_path(absolute_path_or_archive_img_path)
archive_digest = dataset_metadata_utils.get_archive_digest(absolute_path_or_archive_img_path)
cache_file_base = os.path.join(cache_dir, f"{archive_digest}_{inner_path}")
else:
cache_file_base = os.path.join(cache_dir, os.path.basename(absolute_path_or_archive_img_path))
else:
cache_file_base = absolute_path_or_archive_img_path
return os.path.splitext(cache_file_base)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
cache_path: str,
flip_aug: bool,
alpha_mask: bool,
preferred_dtype: Optional[Union[str, torch.dtype]],
) -> bool:
raise NotImplementedError
def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
raise NotImplementedError
def get_key_suffix(
self,
bucket_reso: Optional[Tuple[int, int]] = None,
latents_size: Optional[Tuple[int, int]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
) -> str:
"""
if dtype is None, it returns "_32x64" for example.
"""
if latents_size is not None:
expected_latents_size = latents_size # H, W
else:
# bucket_reso is (W, H)
expected_latents_size = (bucket_reso[1] // self.latents_stride, bucket_reso[0] // self.latents_stride) # H, W
if dtype is None:
dtype_suffix = ""
else:
dtype_suffix = "_" + utils.dtype_to_normalized_str(dtype)
# e.g. "_32x64_float16", HxW, dtype
key_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}{dtype_suffix}"
return key_suffix
def get_compatible_latents_keys(
self,
keys: set[str],
dtype: Optional[Union[str, torch.dtype]],
flip_aug: bool,
bucket_reso: Optional[Tuple[int, int]] = None,
latents_size: Optional[Tuple[int, int]] = None,
) -> list[Optional[str]]:
"""
bucket_reso is (W, H), latents_size is (H, W)
"""
key_suffix = self.get_key_suffix(bucket_reso, latents_size, None)
keys_without_dtype = ["latents" + key_suffix]
if flip_aug:
keys_without_dtype.append("latents_flipped" + key_suffix)
compatible_keys = get_compatible_dtype_keys(keys, keys_without_dtype, dtype)
# without flip_aug there is only one key; pad with None so callers can always unpack two values
return compatible_keys if flip_aug else compatible_keys + [None]
def _default_is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
latents_cache_path: str,
flip_aug: bool,
alpha_mask: bool,
preferred_dtype: Optional[Union[str, torch.dtype]],
):
# multi_resolution is always enabled for any strategy
if not self.cache_to_disk:
return False
if not os.path.exists(latents_cache_path):
return False
if self.skip_disk_cache_validity_check:
return True
key_suffix_without_dtype = self.get_key_suffix(bucket_reso=bucket_reso, dtype=None)
try:
# safe_open locks the file, so we cannot use it for checking keys
# with safe_open(latents_cache_path, framework="pt") as f:
# keys = f.keys()
with utils.MemoryEfficientSafeOpen(latents_cache_path) as f:
keys = f.keys()
if alpha_mask and "alpha_mask" + key_suffix_without_dtype not in keys:
# print(f"alpha_mask not found: {latents_cache_path}")
return False
# preferred_dtype is None if any dtype is acceptable
latents_key, flipped_latents_key = self.get_compatible_latents_keys(
keys, preferred_dtype, flip_aug, bucket_reso=bucket_reso
)
if latents_key is None or (flip_aug and flipped_latents_key is None):
# print(f"Precise dtype not found: {latents_cache_path}")
return False
except Exception as e:
logger.error(f"Error loading file: {latents_cache_path}")
raise e
return True
# TODO remove circular dependency for ImageInfo
def _default_cache_batch_latents(
self,
encode_by_vae,
vae_device,
vae_dtype,
image_infos: List[utils.ImageInfo],
flip_aug: bool,
alpha_mask: bool,
random_crop: bool,
):
"""
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
"""
from library import train_util # import here to avoid circular import
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
image_infos, alpha_mask, random_crop
)
img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)
with torch.no_grad():
latents_tensors = encode_by_vae(img_tensor).to("cpu")
if flip_aug:
img_tensor = torch.flip(img_tensor, dims=[3])
with torch.no_grad():
flipped_latents = encode_by_vae(img_tensor).to("cpu")
else:
flipped_latents = [None] * len(latents_tensors)
# for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks):
for i in range(len(image_infos)):
info = image_infos[i]
latents = latents_tensors[i]
flipped_latent = flipped_latents[i]
alpha_mask = alpha_masks[i]
original_size = original_sizes[i]
crop_ltrb = crop_ltrbs[i]
if self.cache_to_disk:
self.save_latents_to_disk(info.latents_cache_path, latents, original_size, crop_ltrb, flipped_latent, alpha_mask)
else:
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
info.latents = latents
if flip_aug:
info.latents_flipped = flipped_latent
info.alpha_mask = alpha_mask
def load_latents_from_disk(
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
raise NotImplementedError
def _default_load_latents_from_disk(
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
with safe_open(cache_path, framework="pt") as f:
metadata = f.metadata()
version = metadata.get("format_version", "0.0.0")
major, minor, patch = map(int, version.split("."))
if major > 1: # or (major == 1 and minor > 0):
if not self.load_version_warning_printed:
self.load_version_warning_printed = True
logger.warning(
f"Existing latents cache file has a higher version {version} for {cache_path}. This may cause issues."
)
keys = f.keys()
latents_key, flipped_latents_key = self.get_compatible_latents_keys(keys, None, flip_aug=True, bucket_reso=bucket_reso)
key_suffix_without_dtype = self.get_key_suffix(bucket_reso=bucket_reso, dtype=None)
alpha_mask_key = "alpha_mask" + key_suffix_without_dtype
latents = f.get_tensor(latents_key)
flipped_latents = f.get_tensor(flipped_latents_key) if flipped_latents_key is not None else None
alpha_mask = f.get_tensor(alpha_mask_key) if alpha_mask_key in keys else None
original_size = [int(metadata["width"]), int(metadata["height"])]
crop_ltrb = metadata["crop_ltrb" + key_suffix_without_dtype]
crop_ltrb = list(map(int, crop_ltrb.split(",")))
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk(
self,
cache_path: str,
latents_tensor: torch.Tensor,
original_size: Tuple[int, int],
crop_ltrb: List[int],
flipped_latents_tensor: Optional[torch.Tensor] = None,
alpha_mask: Optional[torch.Tensor] = None,
):
dtype = latents_tensor.dtype
latents_size = latents_tensor.shape[1:3] # H, W
tensor_dict = {}
overwrite = False
if os.path.exists(cache_path):
# load existing safetensors and update it
overwrite = True
# we cannot use safe_open here because it locks the file
# with safe_open(cache_path, framework="pt") as f:
with utils.MemoryEfficientSafeOpen(cache_path) as f:
metadata = f.metadata()
keys = f.keys()
for key in keys:
tensor_dict[key] = f.get_tensor(key)
assert metadata["architecture"] == self.architecture
file_version = metadata.get("format_version", "0.0.0")
major, minor, patch = map(int, file_version.split("."))
if major > 1 or (major == 1 and minor > 0):
self.save_version_warning_printed = True
logger.warning(
f"Existing latents cache file has a higher version {file_version} for {cache_path}. This may cause issues."
)
else:
metadata = {}
metadata["architecture"] = self.architecture
metadata["width"] = f"{original_size[0]}"
metadata["height"] = f"{original_size[1]}"
metadata["format_version"] = "1.0.0"
metadata[f"crop_ltrb_{latents_size[0]}x{latents_size[1]}"] = ",".join(map(str, crop_ltrb))
key_suffix = self.get_key_suffix(latents_size=latents_size, dtype=dtype)
if latents_tensor is not None:
tensor_dict["latents" + key_suffix] = latents_tensor
if flipped_latents_tensor is not None:
tensor_dict["latents_flipped" + key_suffix] = flipped_latents_tensor
if alpha_mask is not None:
key_suffix_without_dtype = self.get_key_suffix(latents_size=latents_size, dtype=None)
tensor_dict["alpha_mask" + key_suffix_without_dtype] = alpha_mask
# remove lower precision latents if higher precision latents are already cached
if overwrite:
suffix_without_dtype = self.get_key_suffix(latents_size=latents_size, dtype=None)
remove_lower_precision_values(tensor_dict, ["latents" + suffix_without_dtype, "latents_flipped" + suffix_without_dtype])
save_file(tensor_dict, cache_path, metadata=metadata)
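The latents cache encodes the original image size in the file name and the latent size in the tensor key. The sketch below reproduces both conventions with standalone functions whose names (`latents_cache_path`, `image_size_from_cache_path`, `latents_key`) are illustrative. The archive-path branch of `get_latents_cache_path` is omitted, and the dtype is passed as a plain string such as `float16`, which is an assumption about what `utils.dtype_to_normalized_str` returns.

```python
import os


def latents_cache_path(image_path, image_size, architecture, cache_dir=None):
    # image_size is (W, H); both are zero-padded to four digits in the file name.
    if cache_dir is not None:
        base = os.path.join(cache_dir, os.path.basename(image_path))
    else:
        base = image_path
    w, h = image_size
    return os.path.splitext(base)[0] + f"_{w:04d}x{h:04d}" + f"_{architecture.lower()}.safetensors"


def image_size_from_cache_path(cache_path):
    # Inverse of the naming above: "<name>_<WWWW>x<HHHH>_<arch>.safetensors"
    w, h = os.path.splitext(cache_path)[0].rsplit("_", 2)[-2].split("x")
    return int(w), int(h)


def latents_key(bucket_reso, latents_stride=8, dtype_name=None):
    # bucket_reso is (W, H) in pixels; the key stores H x W in latent units.
    h = bucket_reso[1] // latents_stride
    w = bucket_reso[0] // latents_stride
    suffix = f"_{h}x{w}"
    if dtype_name is not None:
        suffix += "_" + dtype_name
    return "latents" + suffix


path = latents_cache_path("/data/cat.png", (512, 768), "flux")
print(path)
print(image_size_from_cache_path(path))
print(latents_key((512, 256), 8, "float16"))
```

Note the differing axis order: file names and `bucket_reso` use width first, while tensor keys use height first.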

View File

@@ -1,249 +0,0 @@
import os
import glob
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import flux_utils, train_util, utils
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
class FluxTokenizeStrategy(TokenizeStrategy):
def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
self.t5xxl_max_length = t5xxl_max_length
self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
t5_attn_mask = t5_tokens["attention_mask"]
l_tokens = l_tokens["input_ids"]
t5_tokens = t5_tokens["input_ids"]
return [l_tokens, t5_tokens, t5_attn_mask]
class FluxTextEncodingStrategy(TextEncodingStrategy):
def __init__(self, apply_t5_attn_mask: Optional[bool] = None) -> None:
"""
Args:
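apply_t5_attn_mask: Whether to pass the T5 attention mask to the T5-XXL encoder. Used as the default when encode_tokens() is called without an explicit apply_t5_attn_mask.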
"""
self.apply_t5_attn_mask = apply_t5_attn_mask
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
apply_t5_attn_mask: Optional[bool] = None,
) -> List[torch.Tensor]:
# supports single model inference
if apply_t5_attn_mask is None:
apply_t5_attn_mask = self.apply_t5_attn_mask
clip_l, t5xxl = models if len(models) == 2 else (models[0], None)
l_tokens, t5_tokens = tokens[:2]
t5_attn_mask = tokens[2] if len(tokens) > 2 else None
# clip_l is None when using T5 only
if clip_l is not None and l_tokens is not None:
l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"]
else:
l_pooled = None
# t5xxl is None when using CLIP only
if t5xxl is not None and t5_tokens is not None:
# t5_out is [b, max length, 4096]
attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device)
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True)
# if zero_pad_t5_output:
# t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device)
else:
t5_out = None
txt_ids = None
t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one
return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
KEYS = ["l_pooled", "t5_out", "txt_ids"]
KEYS_MASKED = ["t5_attn_mask", "apply_t5_attn_mask"]
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
max_token_length: int,
masked: bool,
is_partial: bool = False,
) -> None:
super().__init__(
FluxLatentsCachingStrategy.ARCHITECTURE,
cache_to_disk,
batch_size,
skip_disk_cache_validity_check,
max_token_length,
masked,
is_partial,
)
self.warn_fp8_weights = False
def is_disk_cached_outputs_expected(
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
):
keys = list(FluxTextEncoderOutputsCachingStrategy.KEYS)  # copy, so that += does not mutate the class-level list
if self.masked:
keys += FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, keys, preferred_dtype)
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
l_pooled, t5_out, txt_ids = self.load_from_disk_for_keys(
cache_path, caption_index, FluxTextEncoderOutputsCachingStrategy.KEYS
)
if self.masked:
t5_attn_mask = self.load_from_disk_for_keys(
cache_path, caption_index, FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
)[0]
else:
t5_attn_mask = None
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
def cache_batch_outputs(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
text_encoding_strategy: TextEncodingStrategy,
batch: list[tuple[utils.ImageInfo, int, str]],
):
if not self.warn_fp8_weights:
if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
logger.warning(
"T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs."
" / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
)
self.warn_fp8_weights = True
flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
captions = [caption for _, _, caption in batch]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
# attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
l_pooled = l_pooled.cpu()
t5_out = t5_out.cpu()
txt_ids = txt_ids.cpu()
t5_attn_mask = tokens_and_masks[2].cpu()
keys = list(FluxTextEncoderOutputsCachingStrategy.KEYS)  # copy, so that += does not mutate the class-level list
if self.masked:
keys += FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
for i, (info, caption_index, caption) in enumerate(batch):
l_pooled_i = l_pooled[i]
t5_out_i = t5_out[i]
txt_ids_i = txt_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
if self.cache_to_disk:
outputs = [l_pooled_i, t5_out_i, txt_ids_i]
if self.masked:
outputs += [t5_attn_mask_i]
self.save_outputs_to_disk(info.text_encoder_outputs_cache_path, caption_index, caption, keys, outputs)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
while len(info.text_encoder_outputs) <= caption_index:
info.text_encoder_outputs.append(None)
info.text_encoder_outputs[caption_index] = [l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i]
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
ARCHITECTURE = "flux"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(FluxLatentsCachingStrategy.ARCHITECTURE, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
cache_path: str,
flip_aug: bool,
alpha_mask: bool,
preferred_dtype: Optional[torch.dtype] = None,
):
return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)
def load_latents_from_disk(
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
return self._default_load_latents_from_disk(cache_path, bucket_reso)
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)
if __name__ == "__main__":
# test code for FluxTokenizeStrategy
strategy = FluxTokenizeStrategy(256)
text = "hello world"
# tokenize() returns [CLIP-L input ids, T5-XXL input ids, T5-XXL attention mask]
l_tokens, t5_tokens, t5_attn_mask = strategy.tokenize(text)
# print(l_tokens.shape)
print(l_tokens)
print(t5_tokens)
print(t5_attn_mask)
texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens_2 = strategy.t5xxl(
texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
print(l_tokens_2)
print(t5_tokens_2)
# compare
print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0]))
print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0]))
text = ",".join(["hello world! this is long text"] * 50)
l_tokens, t5_tokens, t5_attn_mask = strategy.tokenize(text)
print(l_tokens)
print(t5_tokens)
print(t5_attn_mask)
print(f"model max length l: {strategy.clip_l.model_max_length}")
print(f"model max length t5: {strategy.t5xxl.model_max_length}")

@@ -1,168 +0,0 @@
import glob
import os
from typing import Any, List, Optional, Tuple, Union
import torch
from transformers import CLIPTokenizer
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import train_util, utils
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
TOKENIZER_ID = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # only the tokenizer is used from here; v2 and v2.1 share the same tokenizer spec
class SdTokenizeStrategy(TokenizeStrategy):
def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
"""
max_length does not include <BOS> and <EOS> (None, 75, 150, 225)
"""
logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer")
if v2:
self.tokenizer = self._load_tokenizer(
CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir
)
else:
self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
if max_length is None:
self.max_length = self.tokenizer.model_max_length
else:
self.max_length = max_length + 2
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
text = [text] if isinstance(text, str) else text
tokens_list = []
weights_list = []
for t in text:
tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True)
tokens_list.append(tokens)
weights_list.append(weights)
return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)]
class SdTextEncodingStrategy(TextEncodingStrategy):
def __init__(self, clip_skip: Optional[int] = None) -> None:
self.clip_skip = clip_skip
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
text_encoder = models[0]
tokens = tokens[0]
sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy
# tokens: b,n,77
b_size = tokens.size()[0]
max_token_length = tokens.size()[1] * tokens.size()[2]
model_max_length = sd_tokenize_strategy.tokenizer.model_max_length
tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77
tokens = tokens.to(text_encoder.device)
if self.clip_skip is None:
encoder_hidden_states = text_encoder(tokens)[0]
else:
enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip]
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
# bs*3, 77, 768 or 1024
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
if max_token_length != model_max_length:
v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id
if not v1:
# v2: merge the three <BOS>...<EOS> <PAD> ... chunks back into one <BOS>...<EOS> <PAD> ... sequence (honestly not sure this implementation is right)
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, model_max_length):
chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # from after <BOS> to before the last token
if i > 0:
for j in range(len(chunk)):
if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token:
# empty, i.e. the <BOS> <EOS> <PAD> ... pattern
chunk[j, 0] = chunk[j, 1] # copy the value of the following <PAD>
states_list.append(chunk) # from after <BOS> to before <EOS>
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # either <EOS> or <PAD>
encoder_hidden_states = torch.cat(states_list, dim=1)
else:
# v1: merge the three <BOS>...<EOS> chunks back into one <BOS>...<EOS> sequence
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, model_max_length):
states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # from after <BOS> to before <EOS>
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)
return [encoder_hidden_states]
def encode_tokens_with_weights(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens_list: List[torch.Tensor],
weights_list: List[torch.Tensor],
) -> List[torch.Tensor]:
encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0]
weights = weights_list[0].to(encoder_hidden_states.device)
# apply weights
if weights.shape[1] == 1: # no max_token_length
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2)
else:
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
for i in range(weights.shape[1]):
encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[
:, i, 1:-1
].unsqueeze(-1)
return [encoder_hidden_states]
class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
# and we keep the old npz for the backward compatibility.
ARCHITECTURE_SD = "sd"
ARCHITECTURE_SDXL = "sdxl"
def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
arch = SdSdxlLatentsCachingStrategy.ARCHITECTURE_SD if sd else SdSdxlLatentsCachingStrategy.ARCHITECTURE_SDXL
super().__init__(arch, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
self.sd = sd
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
cache_path: str,
flip_aug: bool,
alpha_mask: bool,
preferred_dtype: Optional[torch.dtype] = None,
) -> bool:
return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)
def load_latents_from_disk(
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
return self._default_load_latents_from_disk(cache_path, bucket_reso)
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample()
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)
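
The long-prompt branch of `encode_tokens` in this file collapses n chunks of 77 positions into a single sequence of n*75+2 positions. The index arithmetic can be sketched with plain Python lists instead of tensors. `merge_chunks` is a hypothetical helper written for illustration, not part of the library, and it assumes every chunk has the layout [BOS, 75 tokens, EOS]:

```python
from typing import List


def merge_chunks(seq: List[int], chunk_len: int = 77) -> List[int]:
    """Collapse n chunks of [BOS, body, EOS] into [BOS, n bodies, EOS].

    Hypothetical helper: mirrors the slicing pattern of the v1 branch,
    but on a flat Python list rather than a (batch, seq, dim) tensor.
    """
    assert len(seq) % chunk_len == 0, "input must be a whole number of chunks"
    merged = [seq[0]]  # keep the first position (BOS of the first chunk)
    for start in range(1, len(seq), chunk_len):
        # body of each chunk: skip its BOS, stop before its EOS
        merged.extend(seq[start : start + chunk_len - 2])
    merged.append(seq[-1])  # keep the last position (EOS of the last chunk)
    return merged


# Three chunks of 77 positions, labelled by their original index.
positions = list(range(3 * 77))
merged = merge_chunks(positions)
print(len(merged))
```

With three chunks the merged length is 3*75+2, and positions 76 and 77 (the EOS of the first chunk and the BOS of the second) are dropped.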

@@ -1,390 +0,0 @@
import os
import glob
import random
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import train_util, utils
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
class Sd3TokenizeStrategy(TokenizeStrategy):
def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None:
self.t5xxl_max_length = t5xxl_max_length
self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.clip_g = self._load_tokenizer(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.clip_g.pad_token_id = 0 # use 0 as pad token for clip_g
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
l_attn_mask = l_tokens["attention_mask"]
g_attn_mask = g_tokens["attention_mask"]
t5_attn_mask = t5_tokens["attention_mask"]
l_tokens = l_tokens["input_ids"]
g_tokens = g_tokens["input_ids"]
t5_tokens = t5_tokens["input_ids"]
return [l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask]
class Sd3TextEncodingStrategy(TextEncodingStrategy):
def __init__(
self,
apply_lg_attn_mask: Optional[bool] = None,
apply_t5_attn_mask: Optional[bool] = None,
l_dropout_rate: float = 0.0,
g_dropout_rate: float = 0.0,
t5_dropout_rate: float = 0.0,
) -> None:
"""
Args:
apply_t5_attn_mask: Default value for apply_t5_attn_mask.
"""
self.apply_lg_attn_mask = apply_lg_attn_mask
self.apply_t5_attn_mask = apply_t5_attn_mask
self.l_dropout_rate = l_dropout_rate
self.g_dropout_rate = g_dropout_rate
self.t5_dropout_rate = t5_dropout_rate
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
apply_lg_attn_mask: Optional[bool] = False,
apply_t5_attn_mask: Optional[bool] = False,
enable_dropout: bool = True,
) -> List[torch.Tensor]:
"""
returned embeddings are not masked
"""
clip_l, clip_g, t5xxl = models
clip_l: Optional[CLIPTextModel]
clip_g: Optional[CLIPTextModelWithProjection]
t5xxl: Optional[T5EncoderModel]
if apply_lg_attn_mask is None:
apply_lg_attn_mask = self.apply_lg_attn_mask
if apply_t5_attn_mask is None:
apply_t5_attn_mask = self.apply_t5_attn_mask
l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens
# dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings
if l_tokens is None or clip_l is None:
assert g_tokens is None, "g_tokens must be None if l_tokens is None"
lg_out = None
lg_pooled = None
l_attn_mask = None
g_attn_mask = None
else:
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
# drop some members of the batch: we do not call clip_l and clip_g for dropped members
batch_size, l_seq_len = l_tokens.shape
g_seq_len = g_tokens.shape[1]
non_drop_l_indices = []
non_drop_g_indices = []
for i in range(l_tokens.shape[0]):
drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
if not drop_l:
non_drop_l_indices.append(i)
if not drop_g:
non_drop_g_indices.append(i)
# filter out dropped members
if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size:
l_tokens = l_tokens[non_drop_l_indices]
l_attn_mask = l_attn_mask[non_drop_l_indices]
if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size:
g_tokens = g_tokens[non_drop_g_indices]
g_attn_mask = g_attn_mask[non_drop_g_indices]
# call clip_l for non-dropped members
if len(non_drop_l_indices) > 0:
nd_l_attn_mask = l_attn_mask.to(clip_l.device)
prompt_embeds = clip_l(
l_tokens.to(clip_l.device), nd_l_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
)
nd_l_pooled = prompt_embeds[0]
nd_l_out = prompt_embeds.hidden_states[-2]
if len(non_drop_g_indices) > 0:
nd_g_attn_mask = g_attn_mask.to(clip_g.device)
prompt_embeds = clip_g(
g_tokens.to(clip_g.device), nd_g_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
)
nd_g_pooled = prompt_embeds[0]
nd_g_out = prompt_embeds.hidden_states[-2]
# fill in the dropped members
if len(non_drop_l_indices) == batch_size:
l_pooled = nd_l_pooled
l_out = nd_l_out
else:
# model output is always float32 because the models are wrapped with Accelerator
l_pooled = torch.zeros((batch_size, 768), device=clip_l.device, dtype=torch.float32)
l_out = torch.zeros((batch_size, l_seq_len, 768), device=clip_l.device, dtype=torch.float32)
l_attn_mask = torch.zeros((batch_size, l_seq_len), device=clip_l.device, dtype=l_attn_mask.dtype)
if len(non_drop_l_indices) > 0:
l_pooled[non_drop_l_indices] = nd_l_pooled
l_out[non_drop_l_indices] = nd_l_out
l_attn_mask[non_drop_l_indices] = nd_l_attn_mask
if len(non_drop_g_indices) == batch_size:
g_pooled = nd_g_pooled
g_out = nd_g_out
else:
g_pooled = torch.zeros((batch_size, 1280), device=clip_g.device, dtype=torch.float32)
g_out = torch.zeros((batch_size, g_seq_len, 1280), device=clip_g.device, dtype=torch.float32)
g_attn_mask = torch.zeros((batch_size, g_seq_len), device=clip_g.device, dtype=g_attn_mask.dtype)
if len(non_drop_g_indices) > 0:
g_pooled[non_drop_g_indices] = nd_g_pooled
g_out[non_drop_g_indices] = nd_g_out
g_attn_mask[non_drop_g_indices] = nd_g_attn_mask
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1)
lg_out = torch.cat([l_out, g_out], dim=-1)
if t5xxl is None or t5_tokens is None:
t5_out = None
t5_attn_mask = None
else:
# drop some members of the batch: we do not call t5xxl for dropped members
batch_size, t5_seq_len = t5_tokens.shape
non_drop_t5_indices = []
for i in range(t5_tokens.shape[0]):
drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
if not drop_t5:
non_drop_t5_indices.append(i)
# filter out dropped members
if len(non_drop_t5_indices) > 0 and len(non_drop_t5_indices) < batch_size:
t5_tokens = t5_tokens[non_drop_t5_indices]
t5_attn_mask = t5_attn_mask[non_drop_t5_indices]
# call t5xxl for non-dropped members
if len(non_drop_t5_indices) > 0:
nd_t5_attn_mask = t5_attn_mask.to(t5xxl.device)
nd_t5_out, _ = t5xxl(
t5_tokens.to(t5xxl.device),
nd_t5_attn_mask if apply_t5_attn_mask else None,
return_dict=False,
output_hidden_states=True,
)
# fill in the dropped members
if len(non_drop_t5_indices) == batch_size:
t5_out = nd_t5_out
else:
t5_out = torch.zeros((batch_size, t5_seq_len, 4096), device=t5xxl.device, dtype=torch.float32)
t5_attn_mask = torch.zeros((batch_size, t5_seq_len), device=t5xxl.device, dtype=t5_attn_mask.dtype)
if len(non_drop_t5_indices) > 0:
t5_out[non_drop_t5_indices] = nd_t5_out
t5_attn_mask[non_drop_t5_indices] = nd_t5_attn_mask
# masks are used for attention masking in transformer
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
def drop_cached_text_encoder_outputs(
self,
lg_out: torch.Tensor,
t5_out: torch.Tensor,
lg_pooled: torch.Tensor,
l_attn_mask: torch.Tensor,
g_attn_mask: torch.Tensor,
t5_attn_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# dropout for cached outputs: dropout means zeroing out the embeddings (and attention masks) of the dropped members
if lg_out is not None:
for i in range(lg_out.shape[0]):
drop_l = self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate
if drop_l:
lg_out[i, :, :768] = torch.zeros_like(lg_out[i, :, :768])
lg_pooled[i, :768] = torch.zeros_like(lg_pooled[i, :768])
if l_attn_mask is not None:
l_attn_mask[i] = torch.zeros_like(l_attn_mask[i])
drop_g = self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate
if drop_g:
lg_out[i, :, 768:] = torch.zeros_like(lg_out[i, :, 768:])
lg_pooled[i, 768:] = torch.zeros_like(lg_pooled[i, 768:])
if g_attn_mask is not None:
g_attn_mask[i] = torch.zeros_like(g_attn_mask[i])
if t5_out is not None:
for i in range(t5_out.shape[0]):
drop_t5 = self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate
if drop_t5:
t5_out[i] = torch.zeros_like(t5_out[i])
if t5_attn_mask is not None:
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
def concat_encodings(
self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
if t5_out is None:
t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype)
return torch.cat([lg_out, t5_out], dim=-2), lg_pooled
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
KEYS = ["lg_out", "t5_out", "lg_pooled"]
KEYS_MASKED = ["clip_l_attn_mask", "clip_g_attn_mask", "t5_attn_mask"]
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
max_token_length: int = 256,
masked: bool = False,
) -> None:
"""
apply_lg_attn_mask and apply_t5_attn_mask must be same
"""
super().__init__(
Sd3LatentsCachingStrategy.ARCHITECTURE_SD3,
cache_to_disk,
batch_size,
skip_disk_cache_validity_check,
max_token_length,
masked=masked,
is_partial=is_partial,
)
def is_disk_cached_outputs_expected(
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
) -> bool:
keys = list(Sd3TextEncoderOutputsCachingStrategy.KEYS) # copy, so that += below does not mutate the class attribute
if self.masked:
keys += Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, keys, preferred_dtype)
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
lg_out, t5_out, lg_pooled = self.load_from_disk_for_keys( # unpack in the same order as KEYS
cache_path, caption_index, Sd3TextEncoderOutputsCachingStrategy.KEYS
)
if self.masked:
l_attn_mask, g_attn_mask, t5_attn_mask = self.load_from_disk_for_keys(
cache_path, caption_index, Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
)
else:
l_attn_mask = g_attn_mask = t5_attn_mask = None
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
def cache_batch_outputs(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
text_encoding_strategy: TextEncodingStrategy,
batch: list[tuple[utils.ImageInfo, int, str]],
):
sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy
captions = [caption for _, _, caption in batch]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
# always disable dropout during caching
lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens(
tokenize_strategy,
models,
tokens_and_masks,
apply_lg_attn_mask=self.masked,
apply_t5_attn_mask=self.masked,
enable_dropout=False,
)
lg_out = lg_out.cpu()
lg_pooled = lg_pooled.cpu()
t5_out = t5_out.cpu()
l_attn_mask = tokens_and_masks[3].cpu()
g_attn_mask = tokens_and_masks[4].cpu()
t5_attn_mask = tokens_and_masks[5].cpu()
keys = list(Sd3TextEncoderOutputsCachingStrategy.KEYS) # copy, so that += below does not mutate the class attribute
if self.masked:
keys += Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
for i, (info, caption_index, caption) in enumerate(batch):
lg_out_i = lg_out[i]
t5_out_i = t5_out[i]
lg_pooled_i = lg_pooled[i]
l_attn_mask_i = l_attn_mask[i]
g_attn_mask_i = g_attn_mask[i]
t5_attn_mask_i = t5_attn_mask[i]
if self.cache_to_disk:
outputs = [lg_out_i, t5_out_i, lg_pooled_i]
if self.masked:
outputs += [l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i]
self.save_outputs_to_disk(info.text_encoder_outputs_cache_path, caption_index, caption, keys, outputs)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
while len(info.text_encoder_outputs) <= caption_index:
info.text_encoder_outputs.append(None)
info.text_encoder_outputs[caption_index] = [
lg_out_i,
t5_out_i,
lg_pooled_i,
l_attn_mask_i,
g_attn_mask_i,
t5_attn_mask_i,
]
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
ARCHITECTURE_SD3 = "sd3"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(Sd3LatentsCachingStrategy.ARCHITECTURE_SD3, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
cache_path: str,
flip_aug: bool,
alpha_mask: bool,
preferred_dtype: Optional[torch.dtype] = None,
):
return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)
def load_latents_from_disk(
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
return self._default_load_latents_from_disk(cache_path, bucket_reso)
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)
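
The per-member dropout in `Sd3TextEncodingStrategy.encode_tokens` follows one pattern for each encoder: pick the batch members that survive, call the encoder only for those, and leave zeros for the rest. A minimal sketch of that pattern on plain Python lists follows. `encode_with_dropout` and `fake_encode` are hypothetical names introduced for illustration, and the stand-in encoder is an assumption, not the real CLIP or T5 call:

```python
import random
from typing import Callable, List, Sequence, Tuple


def encode_with_dropout(
    batch: Sequence[str],
    encode: Callable[[List[str]], List[List[float]]],
    dropout_rate: float,
    dim: int,
    rng: random.Random,
) -> Tuple[List[List[float]], List[int]]:
    # Decide per batch member whether it is dropped.
    keep = [i for i in range(len(batch)) if not (dropout_rate > 0.0 and rng.random() < dropout_rate)]
    # Dropped members keep an all-zero embedding.
    out = [[0.0] * dim for _ in batch]
    if keep:
        # The encoder is called only for the surviving members.
        encoded = encode([batch[i] for i in keep])
        for i, vec in zip(keep, encoded):
            out[i] = vec
    return out, keep


def fake_encode(texts: List[str]) -> List[List[float]]:
    # Stand-in for a text encoder (assumption): embeds each text as [len(text), 1.0].
    return [[float(len(t)), 1.0] for t in texts]


no_drop, kept_all = encode_with_dropout(["a", "bcd"], fake_encode, 0.0, 2, random.Random(0))
all_drop, kept_none = encode_with_dropout(["a", "bcd"], fake_encode, 1.0, 2, random.Random(0))
```

With a rate of 0.0 every member is encoded. With a rate of 1.0 the encoder is never called and every embedding stays zero.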

@@ -1,305 +0,0 @@
import os
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
from library import utils
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
class SdxlTokenizeStrategy(TokenizeStrategy):
def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
"""
max_length: maximum length of the input text, **excluding** the special tokens. None or 150 or 225
"""
self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2
if max_length is None:
self.max_length = self.tokenizer1.model_max_length
else:
self.max_length = max_length + 2
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
return (
torch.stack([self._get_input_ids(self.tokenizer1, t, self.max_length) for t in text], dim=0),
torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0),
)
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
text = [text] if isinstance(text, str) else text
tokens1_list, tokens2_list = [], []
weights1_list, weights2_list = [], []
for t in text:
tokens1, weights1 = self._get_input_ids(self.tokenizer1, t, self.max_length, weighted=True)
tokens2, weights2 = self._get_input_ids(self.tokenizer2, t, self.max_length, weighted=True)
tokens1_list.append(tokens1)
tokens2_list.append(tokens2)
weights1_list.append(weights1)
weights2_list.append(weights2)
return [torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)], [
torch.stack(weights1_list, dim=0),
torch.stack(weights2_list, dim=0),
]
class SdxlTextEncodingStrategy(TextEncodingStrategy):
def __init__(self) -> None:
pass
def _pool_workaround(
self, text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
):
r"""
workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output
instead of the hidden states for the EOS token
If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output
Original code from CLIP's pooling function:
\# text_embeds.shape = [batch_size, sequence_length, transformer.width]
\# take features from the eot embedding (eot_token is the highest number in each sequence)
\# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
]
"""
# input_ids: b*n,77
# find index for EOS token
# Following code is not working if one of the input_ids has multiple EOS tokens (very odd case)
# eos_token_index = torch.where(input_ids == eos_token_id)[1]
# eos_token_index = eos_token_index.to(device=last_hidden_state.device)
# Create a mask where the EOS tokens are
eos_token_mask = (input_ids == eos_token_id).int()
# Use argmax to find the last index of the EOS token for each element in the batch
eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine
eos_token_index = eos_token_index.to(device=last_hidden_state.device)
# get hidden states for EOS token
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index
]
# apply projection: projection may be of different dtype than last_hidden_state
pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype))
pooled_output = pooled_output.to(last_hidden_state.dtype)
return pooled_output
def _get_hidden_states_sdxl(
self,
input_ids1: torch.Tensor,
input_ids2: torch.Tensor,
tokenizer1: CLIPTokenizer,
tokenizer2: CLIPTokenizer,
text_encoder1: Union[CLIPTextModel, torch.nn.Module],
text_encoder2: Union[CLIPTextModelWithProjection, torch.nn.Module],
unwrapped_text_encoder2: Optional[CLIPTextModelWithProjection] = None,
):
# input_ids: b,n,77 -> b*n, 77
b_size = input_ids1.size()[0]
if input_ids1.size()[1] == 1:
max_token_length = None
else:
max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
input_ids1 = input_ids1.to(text_encoder1.device)
input_ids2 = input_ids2.to(text_encoder2.device)
# text_encoder1
enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
hidden_states1 = enc_out["hidden_states"][11]
# text_encoder2
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
hidden_states2 = enc_out["hidden_states"][-2] # penultimate layer
# pool2 = enc_out["text_embeds"]
unwrapped_text_encoder2 = unwrapped_text_encoder2 or text_encoder2
pool2 = self._pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
n_size = 1 if max_token_length is None else max_token_length // 75
hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
if max_token_length is not None:
# bs*3, 77, 768 or 1024
# encoder1: merge the three <BOS>...<EOS> chunks back into one <BOS>...<EOS> sequence
states_list = [hidden_states1[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, tokenizer1.model_max_length):
states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # from after <BOS> to before <EOS>
states_list.append(hidden_states1[:, -1].unsqueeze(1)) # <EOS>
hidden_states1 = torch.cat(states_list, dim=1)
# v2: merge the three <BOS>...<EOS> <PAD> ... chunks back into one <BOS>...<EOS> <PAD> ... sequence (honestly not sure this implementation is right)
states_list = [hidden_states2[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, tokenizer2.model_max_length):
chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # from after <BOS> to before the last token
# this causes an error:
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# if i > 1:
# for j in range(len(chunk)): # batch_size
# if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # empty, i.e. the <BOS> <EOS> <PAD> ... pattern
# chunk[j, 0] = chunk[j, 1] # copy the value of the following <PAD>
states_list.append(chunk) # from after <BOS> to before <EOS>
states_list.append(hidden_states2[:, -1].unsqueeze(1)) # either <EOS> or <PAD>
hidden_states2 = torch.cat(states_list, dim=1)
# use the pool of the first of the n chunks
pool2 = pool2[::n_size]
return hidden_states1, hidden_states2, pool2
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
"""
Args:
tokenize_strategy: TokenizeStrategy
models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)].
If text_encoder2 is wrapped by accelerate, unwrapped_text_encoder2 is required
tokens: List of tokens, for text_encoder1 and text_encoder2
"""
if len(models) == 2:
text_encoder1, text_encoder2 = models
unwrapped_text_encoder2 = None
else:
text_encoder1, text_encoder2, unwrapped_text_encoder2 = models
tokens1, tokens2 = tokens
sdxl_tokenize_strategy = tokenize_strategy # type: SdxlTokenizeStrategy
tokenizer1, tokenizer2 = sdxl_tokenize_strategy.tokenizer1, sdxl_tokenize_strategy.tokenizer2
hidden_states1, hidden_states2, pool2 = self._get_hidden_states_sdxl(
tokens1, tokens2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, unwrapped_text_encoder2
)
return [hidden_states1, hidden_states2, pool2]
def encode_tokens_with_weights(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens_list: List[torch.Tensor],
weights_list: List[torch.Tensor],
) -> List[torch.Tensor]:
hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens_list)
weights_list = [weights.to(hidden_states1.device) for weights in weights_list]
# apply weights
if weights_list[0].shape[1] == 1: # no max_token_length
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
hidden_states1 = hidden_states1 * weights_list[0].squeeze(1).unsqueeze(2)
hidden_states2 = hidden_states2 * weights_list[1].squeeze(1).unsqueeze(2)
else:
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
for weight, hidden_states in zip(weights_list, [hidden_states1, hidden_states2]):
for i in range(weight.shape[1]):
hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[
:, i, 1:-1
].unsqueeze(-1)
return [hidden_states1, hidden_states2, pool2]
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
ARCHITECTURE_SDXL = "sdxl"
KEYS = ["hidden_state1", "hidden_state2", "pool2"]
def __init__(
self,
cache_to_disk: bool,
batch_size: Optional[int],
skip_disk_cache_validity_check: bool,
max_token_length: Optional[int] = None,
is_partial: bool = False,
is_weighted: bool = False,
) -> None:
"""
max_token_length: maximum length of the input text, **excluding** the special tokens. None or 150 or 225
"""
max_token_length = max_token_length or 75
super().__init__(
SdxlTextEncoderOutputsCachingStrategy.ARCHITECTURE_SDXL,
cache_to_disk,
batch_size,
skip_disk_cache_validity_check,
is_partial,
is_weighted,
max_token_length=max_token_length,
)
def is_disk_cached_outputs_expected(
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
) -> bool:
# SDXL does not support attn mask
base_keys = SdxlTextEncoderOutputsCachingStrategy.KEYS
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, base_keys, preferred_dtype)
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
return self.load_from_disk_for_keys(cache_path, caption_index, SdxlTextEncoderOutputsCachingStrategy.KEYS)
def cache_batch_outputs(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
text_encoding_strategy: TextEncodingStrategy,
batch: list[tuple[utils.ImageInfo, int, str]],
):
sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy
captions = [caption for _, _, caption in batch]
if self.is_weighted:
tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions)
with torch.no_grad():
hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, models, tokens_list, weights_list
)
else:
tokens1, tokens2 = tokenize_strategy.tokenize(captions)
with torch.no_grad():
hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, [tokens1, tokens2]
)
hidden_state1 = hidden_state1.cpu()
hidden_state2 = hidden_state2.cpu()
pool2 = pool2.cpu()
for i, (info, caption_index, caption) in enumerate(batch):
hidden_state1_i = hidden_state1[i]
hidden_state2_i = hidden_state2[i]
pool2_i = pool2[i]
if self.cache_to_disk:
self.save_outputs_to_disk(
info.text_encoder_outputs_cache_path,
caption_index,
caption,
SdxlTextEncoderOutputsCachingStrategy.KEYS,
[hidden_state1_i, hidden_state2_i, pool2_i],
)
else:
while len(info.text_encoder_outputs) <= caption_index:
info.text_encoder_outputs.append(None)
info.text_encoder_outputs[caption_index] = [hidden_state1_i, hidden_state2_i, pool2_i]
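
The weighted branch at the top of this hunk scales only the 75 content positions of each chunk and leaves the leading BOS and trailing EOS positions alone. The following is a minimal sketch of that index bookkeeping in plain Python lists, not the repository's tensor code; `chunk_slices` and `apply_chunk_weights` are hypothetical names introduced only for illustration.

```python
# Illustrative only: pure-Python index bookkeeping, no tensors involved.
CHUNK = 75  # content tokens per chunk (77 minus BOS and EOS)


def chunk_slices(num_chunks):
    """Return (start, stop) positions in the concatenated sequence for each chunk."""
    return [(i * CHUNK + 1, i * CHUNK + CHUNK + 1) for i in range(num_chunks)]


def apply_chunk_weights(hidden, weights):
    """hidden: list of length n*75+2; weights: n lists of length 77 each."""
    out = list(hidden)
    for i, (start, stop) in enumerate(chunk_slices(len(weights))):
        inner = weights[i][1:-1]  # drop the BOS/EOS weight of this chunk
        for pos, w in zip(range(start, stop), inner):
            out[pos] = hidden[pos] * w
    return out


n = 3
hidden = [1.0] * (n * CHUNK + 2)  # stand-in for one feature column of hidden states
weights = [[1.0] + [float(i + 2)] * CHUNK + [1.0] for i in range(n)]
scaled = apply_chunk_weights(hidden, weights)
```

Position 0 and the last position keep their original value, which mirrors the `i * 75 + 1 : i * 75 + 76` slices and the `weight[:, i, 1:-1]` indexing in the hunk.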

File diff suppressed because it is too large

View File

@@ -1,85 +1,21 @@
import logging
import sys
import threading
from typing import *
import json
import struct
import torch
import torch.nn as nn
from torchvision import transforms
from typing import *
from diffusers import EulerAncestralDiscreteScheduler
import diffusers.schedulers.scheduling_euler_ancestral_discrete
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
import cv2
from PIL import Image
import numpy as np
from safetensors.torch import load_file
def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, is_reg: bool, absolute_path: str) -> None:
self.image_key: str = image_key
self.num_repeats: int = num_repeats
self.captions: Optional[list[str]] = None
self.caption_weights: Optional[list[float]] = None # weights for each caption in sampling
self.list_of_tags: Optional[list[str]] = None
self.tags_weights: Optional[list[float]] = None
self.is_reg: bool = is_reg
self.absolute_path: str = absolute_path
self.latents_cache_dir: Optional[str] = None
self.image_size: Tuple[int, int] = None
self.resized_size: Tuple[int, int] = None
self.bucket_reso: Tuple[int, int] = None
self.latents: Optional[torch.Tensor] = None
self.latents_flipped: Optional[torch.Tensor] = None
self.latents_cache_path: Optional[str] = None # set in cache_latents
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
# crop left top right bottom in original pixel size, not latents size
self.latents_crop_ltrb: Optional[Tuple[int, int, int, int]] = None
self.cond_img_path: Optional[str] = None
self.image: Optional[Image.Image] = None # optional, original PIL Image. None if not the latents is cached
self.text_encoder_outputs_cache_path: Optional[str] = None # set in cache_text_encoder_outputs
# new
self.text_encoder_outputs: Optional[list[list[torch.Tensor]]] = None
# old
self.text_encoder_outputs1: Optional[torch.Tensor] = None
self.text_encoder_outputs2: Optional[torch.Tensor] = None
self.text_encoder_pool2: Optional[torch.Tensor] = None
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
def __str__(self) -> str:
return f"ImageInfo(image_key={self.image_key}, num_repeats={self.num_repeats}, captions={self.captions}, is_reg={self.is_reg}, absolute_path={self.absolute_path})"
def set_dreambooth_info(self, list_of_tags: list[str]) -> None:
self.list_of_tags = list_of_tags
def set_fine_tuning_info(
self,
captions: Optional[list[str]],
caption_weights: Optional[list[float]],
list_of_tags: Optional[list[str]],
tags_weights: Optional[list[float]],
image_size: Tuple[int, int],
latents_cache_dir: Optional[str],
):
self.captions = captions
self.caption_weights = caption_weights
self.list_of_tags = list_of_tags
self.tags_weights = tags_weights
self.image_size = image_size
self.latents_cache_dir = latents_cache_dir
# region Logging
def add_logging_arguments(parser):
parser.add_argument(
"--console_log_level",
@@ -146,304 +82,6 @@ def setup_logging(args=None, log_level=None, reset=False):
logger.info(msg_init)
# endregion
# region PyTorch utils
def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
stream.synchronize()
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
stream.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
def weighs_to_device(layer: nn.Module, device: torch.device):
for module in layer.modules():
if hasattr(module, "weight") and module.weight is not None:
module.weight.data = module.weight.data.to(device, non_blocking=True)
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
"""
Convert a string to a torch.dtype
Args:
s: string representation of the dtype
default_dtype: default dtype to return if s is None
Returns:
torch.dtype: the corresponding torch.dtype
Raises:
ValueError: if the dtype is not supported
Examples:
>>> str_to_dtype("float32")
torch.float32
>>> str_to_dtype("fp32")
torch.float32
>>> str_to_dtype("float16")
torch.float16
>>> str_to_dtype("fp16")
torch.float16
>>> str_to_dtype("bfloat16")
torch.bfloat16
>>> str_to_dtype("bf16")
torch.bfloat16
>>> str_to_dtype("fp8")
torch.float8_e4m3fn
>>> str_to_dtype("fp8_e4m3fn")
torch.float8_e4m3fn
>>> str_to_dtype("fp8_e4m3fnuz")
torch.float8_e4m3fnuz
>>> str_to_dtype("fp8_e5m2")
torch.float8_e5m2
>>> str_to_dtype("fp8_e5m2fnuz")
torch.float8_e5m2fnuz
"""
if s is None:
return default_dtype
if s in ["bf16", "bfloat16"]:
return torch.bfloat16
elif s in ["fp16", "float16"]:
return torch.float16
elif s in ["fp32", "float32", "float"]:
return torch.float32
elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
return torch.float8_e4m3fn
elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
return torch.float8_e4m3fnuz
elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
return torch.float8_e5m2
elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
return torch.float8_e5m2fnuz
elif s in ["fp8", "float8"]:
return torch.float8_e4m3fn # default fp8
else:
raise ValueError(f"Unsupported dtype: {s}")
def dtype_to_normalized_str(dtype: Union[str, torch.dtype]) -> str:
dtype = str_to_dtype(dtype) if isinstance(dtype, str) else dtype
# get name of the dtype
dtype_name = str(dtype).split(".")[-1]
return dtype_name
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
"""
memory efficient save file
"""
_TYPES = {
torch.float64: "F64",
torch.float32: "F32",
torch.float16: "F16",
torch.bfloat16: "BF16",
torch.int64: "I64",
torch.int32: "I32",
torch.int16: "I16",
torch.int8: "I8",
torch.uint8: "U8",
torch.bool: "BOOL",
getattr(torch, "float8_e5m2", None): "F8_E5M2",
getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
}
_ALIGN = 256
def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
validated = {}
for key, value in metadata.items():
if not isinstance(key, str):
raise ValueError(f"Metadata key must be a string, got {type(key)}")
if not isinstance(value, str):
print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
validated[key] = str(value)
else:
validated[key] = value
return validated
print(f"Using memory efficient save file: {filename}")
header = {}
offset = 0
if metadata:
header["__metadata__"] = validate_metadata(metadata)
for k, v in tensors.items():
if v.numel() == 0: # empty tensor
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
else:
size = v.numel() * v.element_size()
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
offset += size
hjson = json.dumps(header).encode("utf-8")
hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
with open(filename, "wb") as f:
f.write(struct.pack("<Q", len(hjson)))
f.write(hjson)
for k, v in tensors.items():
if v.numel() == 0:
continue
if v.is_cuda:
# Direct GPU to disk save
with torch.cuda.device(v.device):
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
tensor_bytes = v.contiguous().view(torch.uint8)
tensor_bytes.cpu().numpy().tofile(f)
else:
# CPU tensor save
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
v.contiguous().view(torch.uint8).numpy().tofile(f)
class MemoryEfficientSafeOpen:
# does not support metadata loading
def __init__(self, filename):
self.filename = filename
self.file = open(filename, "rb")
self.header, self.header_size = self._read_header()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.file.close()
def keys(self):
return [k for k in self.header.keys() if k != "__metadata__"]
def metadata(self) -> Dict[str, str]:
return self.header.get("__metadata__", {})
def get_tensor(self, key):
if key not in self.header:
raise KeyError(f"Tensor '{key}' not found in the file")
metadata = self.header[key]
offset_start, offset_end = metadata["data_offsets"]
if offset_start == offset_end:
tensor_bytes = None
else:
# adjust offset by header size
self.file.seek(self.header_size + 8 + offset_start)
tensor_bytes = self.file.read(offset_end - offset_start)
return self._deserialize_tensor(tensor_bytes, metadata)
def _read_header(self):
header_size = struct.unpack("<Q", self.file.read(8))[0]
header_json = self.file.read(header_size).decode("utf-8")
return json.loads(header_json), header_size
def _deserialize_tensor(self, tensor_bytes, metadata):
dtype = self._get_torch_dtype(metadata["dtype"])
shape = metadata["shape"]
if tensor_bytes is None:
byte_tensor = torch.empty(0, dtype=torch.uint8)
else:
tensor_bytes = bytearray(tensor_bytes) # make it writable
byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
# process float8 types
if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
return self._convert_float8(byte_tensor, metadata["dtype"], shape)
# convert to the target dtype and reshape
return byte_tensor.view(dtype).reshape(shape)
@staticmethod
def _get_torch_dtype(dtype_str):
dtype_map = {
"F64": torch.float64,
"F32": torch.float32,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I64": torch.int64,
"I32": torch.int32,
"I16": torch.int16,
"I8": torch.int8,
"U8": torch.uint8,
"BOOL": torch.bool,
}
# add float8 types if available
if hasattr(torch, "float8_e5m2"):
dtype_map["F8_E5M2"] = torch.float8_e5m2
if hasattr(torch, "float8_e4m3fn"):
dtype_map["F8_E4M3"] = torch.float8_e4m3fn
return dtype_map.get(dtype_str)
@staticmethod
def _convert_float8(byte_tensor, dtype_str, shape):
if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
return byte_tensor.view(torch.float8_e5m2).reshape(shape)
elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
else:
# # convert to float16 if float8 is not supported
# print(f"Warning: {dtype_str} is not supported in this PyTorch version. Converting to float16.")
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
) -> dict[str, torch.Tensor]:
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
# use experimental loader
# logger.info(f"Loading without mmap (experimental)")
state_dict = {}
with MemoryEfficientSafeOpen(path) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
return state_dict
else:
try:
state_dict = load_file(path, device=device)
except:
state_dict = load_file(path) # prevent device invalid Error
if dtype is not None:
for key in state_dict.keys():
state_dict[key] = state_dict[key].to(dtype=dtype)
return state_dict
# endregion
# region Image utils
def pil_resize(image, size, interpolation=Image.LANCZOS):
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
@@ -463,9 +101,9 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
return resized_cv2
# endregion
# TODO make inf_utils.py
# region Gradual Latent hires fix
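
The removed `mem_eff_save_file` and `MemoryEfficientSafeOpen` helpers in this hunk both depend on the same file layout: an 8-byte little-endian length, a JSON header padded with spaces, then raw tensor bytes addressed by `data_offsets`. The sketch below reproduces only the header part with the standard library, assuming the 256-byte alignment used in the hunk; `build_header` and `parse_header` are hypothetical helper names, and the tensor sizes are made up.

```python
import json
import struct

ALIGN = 256  # same alignment constant as the removed helper


def build_header(entries):
    """entries: list of (name, dtype_code, shape, nbytes). Returns length prefix + padded header."""
    header, offset = {}, 0
    for name, dtype_code, shape, nbytes in entries:
        header[name] = {"dtype": dtype_code, "shape": shape, "data_offsets": [offset, offset + nbytes]}
        offset += nbytes
    hjson = json.dumps(header).encode("utf-8")
    # pad with spaces so that the 8-byte prefix plus the header ends on an ALIGN boundary
    hjson += b" " * (-(len(hjson) + 8) % ALIGN)
    return struct.pack("<Q", len(hjson)) + hjson


def parse_header(blob):
    """Read the length prefix, then decode that many bytes as JSON."""
    (size,) = struct.unpack("<Q", blob[:8])
    return json.loads(blob[8 : 8 + size].decode("utf-8")), size


# a 2x3 float32 tensor (24 bytes) followed by an empty tensor (0 bytes)
blob = build_header([("a", "F32", [2, 3], 24), ("b", "U8", [0], 0)])
header, size = parse_header(blob)
```

A reader then finds tensor data at `8 + size + data_offsets[0]`, which is the `self.header_size + 8 + offset_start` seek in `get_tensor`. An empty tensor gets equal start and end offsets and contributes no bytes.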

View File

@@ -1,434 +0,0 @@
# convert key mapping and data format from some LoRA format to another
"""
Original LoRA format: Based on Black Forest Labs, QKV and MLP are unified into one module
alpha is scalar for each LoRA module
0 to 18
lora_unet_double_blocks_0_img_attn_proj.alpha torch.Size([])
lora_unet_double_blocks_0_img_attn_proj.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_img_attn_proj.lora_up.weight torch.Size([3072, 4])
lora_unet_double_blocks_0_img_attn_qkv.alpha torch.Size([])
lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight torch.Size([9216, 4])
lora_unet_double_blocks_0_img_mlp_0.alpha torch.Size([])
lora_unet_double_blocks_0_img_mlp_0.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_img_mlp_0.lora_up.weight torch.Size([12288, 4])
lora_unet_double_blocks_0_img_mlp_2.alpha torch.Size([])
lora_unet_double_blocks_0_img_mlp_2.lora_down.weight torch.Size([4, 12288])
lora_unet_double_blocks_0_img_mlp_2.lora_up.weight torch.Size([3072, 4])
lora_unet_double_blocks_0_img_mod_lin.alpha torch.Size([])
lora_unet_double_blocks_0_img_mod_lin.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_img_mod_lin.lora_up.weight torch.Size([18432, 4])
lora_unet_double_blocks_0_txt_attn_proj.alpha torch.Size([])
lora_unet_double_blocks_0_txt_attn_proj.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_txt_attn_proj.lora_up.weight torch.Size([3072, 4])
lora_unet_double_blocks_0_txt_attn_qkv.alpha torch.Size([])
lora_unet_double_blocks_0_txt_attn_qkv.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_txt_attn_qkv.lora_up.weight torch.Size([9216, 4])
lora_unet_double_blocks_0_txt_mlp_0.alpha torch.Size([])
lora_unet_double_blocks_0_txt_mlp_0.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_txt_mlp_0.lora_up.weight torch.Size([12288, 4])
lora_unet_double_blocks_0_txt_mlp_2.alpha torch.Size([])
lora_unet_double_blocks_0_txt_mlp_2.lora_down.weight torch.Size([4, 12288])
lora_unet_double_blocks_0_txt_mlp_2.lora_up.weight torch.Size([3072, 4])
lora_unet_double_blocks_0_txt_mod_lin.alpha torch.Size([])
lora_unet_double_blocks_0_txt_mod_lin.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_txt_mod_lin.lora_up.weight torch.Size([18432, 4])
0 to 37
lora_unet_single_blocks_0_linear1.alpha torch.Size([])
lora_unet_single_blocks_0_linear1.lora_down.weight torch.Size([4, 3072])
lora_unet_single_blocks_0_linear1.lora_up.weight torch.Size([21504, 4])
lora_unet_single_blocks_0_linear2.alpha torch.Size([])
lora_unet_single_blocks_0_linear2.lora_down.weight torch.Size([4, 15360])
lora_unet_single_blocks_0_linear2.lora_up.weight torch.Size([3072, 4])
lora_unet_single_blocks_0_modulation_lin.alpha torch.Size([])
lora_unet_single_blocks_0_modulation_lin.lora_down.weight torch.Size([4, 3072])
lora_unet_single_blocks_0_modulation_lin.lora_up.weight torch.Size([9216, 4])
"""
"""
ai-toolkit: Based on Diffusers, QKV and MLP are separated into 3 modules.
A is down, B is up. No alpha for each LoRA module.
0 to 18
transformer.transformer_blocks.0.attn.add_k_proj.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.add_k_proj.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.add_v_proj.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.add_v_proj.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.to_add_out.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.to_add_out.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.ff.net.0.proj.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.ff.net.0.proj.lora_B.weight torch.Size([12288, 16])
transformer.transformer_blocks.0.ff.net.2.lora_A.weight torch.Size([16, 12288])
transformer.transformer_blocks.0.ff.net.2.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.ff_context.net.0.proj.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.ff_context.net.0.proj.lora_B.weight torch.Size([12288, 16])
transformer.transformer_blocks.0.ff_context.net.2.lora_A.weight torch.Size([16, 12288])
transformer.transformer_blocks.0.ff_context.net.2.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.norm1.linear.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.norm1.linear.lora_B.weight torch.Size([18432, 16])
transformer.transformer_blocks.0.norm1_context.linear.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.norm1_context.linear.lora_B.weight torch.Size([18432, 16])
0 to 37
transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([16, 3072])
transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([3072, 16])
transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([16, 3072])
transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([3072, 16])
transformer.single_transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([16, 3072])
transformer.single_transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([3072, 16])
transformer.single_transformer_blocks.0.norm.linear.lora_A.weight torch.Size([16, 3072])
transformer.single_transformer_blocks.0.norm.linear.lora_B.weight torch.Size([9216, 16])
transformer.single_transformer_blocks.0.proj_mlp.lora_A.weight torch.Size([16, 3072])
transformer.single_transformer_blocks.0.proj_mlp.lora_B.weight torch.Size([12288, 16])
transformer.single_transformer_blocks.0.proj_out.lora_A.weight torch.Size([16, 15360])
transformer.single_transformer_blocks.0.proj_out.lora_B.weight torch.Size([3072, 16])
"""
"""
xlabs: Unknown format.
0 to 18
double_blocks.0.processor.proj_lora1.down.weight torch.Size([16, 3072])
double_blocks.0.processor.proj_lora1.up.weight torch.Size([3072, 16])
double_blocks.0.processor.proj_lora2.down.weight torch.Size([16, 3072])
double_blocks.0.processor.proj_lora2.up.weight torch.Size([3072, 16])
double_blocks.0.processor.qkv_lora1.down.weight torch.Size([16, 3072])
double_blocks.0.processor.qkv_lora1.up.weight torch.Size([9216, 16])
double_blocks.0.processor.qkv_lora2.down.weight torch.Size([16, 3072])
double_blocks.0.processor.qkv_lora2.up.weight torch.Size([9216, 16])
"""
import argparse
from safetensors.torch import save_file
from safetensors import safe_open
import torch
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def convert_to_sd_scripts(sds_sd, ait_sd, sds_key, ait_key):
ait_down_key = ait_key + ".lora_A.weight"
if ait_down_key not in ait_sd:
return
ait_up_key = ait_key + ".lora_B.weight"
down_weight = ait_sd.pop(ait_down_key)
sds_sd[sds_key + ".lora_down.weight"] = down_weight
sds_sd[sds_key + ".lora_up.weight"] = ait_sd.pop(ait_up_key)
rank = down_weight.shape[0]
sds_sd[sds_key + ".alpha"] = torch.scalar_tensor(rank, dtype=down_weight.dtype, device=down_weight.device)
def convert_to_sd_scripts_cat(sds_sd, ait_sd, sds_key, ait_keys):
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
if ait_down_keys[0] not in ait_sd:
return
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
down_weights = [ait_sd.pop(k) for k in ait_down_keys]
up_weights = [ait_sd.pop(k) for k in ait_up_keys]
# lora_down is concatenated along dim=0, so rank is multiplied by the number of splits
rank = down_weights[0].shape[0]
num_splits = len(ait_keys)
sds_sd[sds_key + ".lora_down.weight"] = torch.cat(down_weights, dim=0)
merged_up_weights = torch.zeros(
(sum(w.shape[0] for w in up_weights), rank * num_splits),
dtype=up_weights[0].dtype,
device=up_weights[0].device,
)
i = 0
for j, up_weight in enumerate(up_weights):
merged_up_weights[i : i + up_weight.shape[0], j * rank : (j + 1) * rank] = up_weight
i += up_weight.shape[0]
sds_sd[sds_key + ".lora_up.weight"] = merged_up_weights
# set alpha to new_rank
new_rank = rank * num_splits
sds_sd[sds_key + ".alpha"] = torch.scalar_tensor(new_rank, dtype=down_weights[0].dtype, device=down_weights[0].device)
def convert_ai_toolkit_to_sd_scripts(ait_sd):
sds_sd = {}
for i in range(19):
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_out.0"
)
convert_to_sd_scripts_cat(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.to_q",
f"transformer.transformer_blocks.{i}.attn.to_k",
f"transformer.transformer_blocks.{i}.attn.to_v",
],
)
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_0", f"transformer.transformer_blocks.{i}.ff.net.0.proj"
)
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_2", f"transformer.transformer_blocks.{i}.ff.net.2"
)
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mod_lin", f"transformer.transformer_blocks.{i}.norm1.linear"
)
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_add_out"
)
convert_to_sd_scripts_cat(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
],
)
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_0", f"transformer.transformer_blocks.{i}.ff_context.net.0.proj"
)
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_2", f"transformer.transformer_blocks.{i}.ff_context.net.2"
)
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mod_lin", f"transformer.transformer_blocks.{i}.norm1_context.linear"
)
for i in range(38):
convert_to_sd_scripts_cat(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_linear1",
[
f"transformer.single_transformer_blocks.{i}.attn.to_q",
f"transformer.single_transformer_blocks.{i}.attn.to_k",
f"transformer.single_transformer_blocks.{i}.attn.to_v",
f"transformer.single_transformer_blocks.{i}.proj_mlp",
],
)
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_linear2", f"transformer.single_transformer_blocks.{i}.proj_out"
)
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_modulation_lin", f"transformer.single_transformer_blocks.{i}.norm.linear"
)
if len(ait_sd) > 0:
logger.warning(f"Unsupported keys for sd-scripts: {ait_sd.keys()}")
return sds_sd
def convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
# scale weight by alpha and dim
rank = down_weight.shape[0]
alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
# print(f"rank: {rank}, alpha: {alpha}, scale: {scale}")
# calculate scale_down and scale_up so that their product stays equal to scale. e.g. if scale is 0.25, scale_down is 0.5 and scale_up is 0.5; the loop leaves the values unchanged when scale >= 0.5
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
# print(f"scale: {scale}, scale_down: {scale_down}, scale_up: {scale_up}")
ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
sd_lora_rank = down_weight.shape[0]
# scale weight by alpha and dim
alpha = sds_sd.pop(sds_key + ".alpha")
scale = alpha / sd_lora_rank
# calculate scale_down and scale_up
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
down_weight = down_weight * scale_down
up_weight = up_weight * scale_up
# calculate dims if not provided
num_splits = len(ait_keys)
if dims is None:
dims = [up_weight.shape[0] // num_splits] * num_splits
else:
assert sum(dims) == up_weight.shape[0]
# check whether up_weight is block-sparse (all off-diagonal blocks are zero)
is_sparse = False
if sd_lora_rank % num_splits == 0:
ait_rank = sd_lora_rank // num_splits
is_sparse = True
i = 0
for j in range(len(dims)):
for k in range(len(dims)):
if j == k:
continue
is_sparse = is_sparse and torch.all(up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0)
i += dims[j]
if is_sparse:
logger.info(f"weight is sparse: {sds_key}")
# make ai-toolkit weight
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
if not is_sparse:
# down_weight is copied to each split
ait_sd.update({k: down_weight for k in ait_down_keys})
# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})
else:
# down_weight is chunked to each split
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})
# up_weight is sparse: only non-zero values are copied to each split
i = 0
for j in range(len(dims)):
ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
i += dims[j]
def convert_sd_scripts_to_ai_toolkit(sds_sd):
ait_sd = {}
for i in range(19):
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_out.0"
)
convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.to_q",
f"transformer.transformer_blocks.{i}.attn.to_k",
f"transformer.transformer_blocks.{i}.attn.to_v",
],
)
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_0", f"transformer.transformer_blocks.{i}.ff.net.0.proj"
)
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_2", f"transformer.transformer_blocks.{i}.ff.net.2"
)
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mod_lin", f"transformer.transformer_blocks.{i}.norm1.linear"
)
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_add_out"
)
convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
],
)
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_0", f"transformer.transformer_blocks.{i}.ff_context.net.0.proj"
)
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_2", f"transformer.transformer_blocks.{i}.ff_context.net.2"
)
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mod_lin", f"transformer.transformer_blocks.{i}.norm1_context.linear"
)
for i in range(38):
convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_linear1",
[
f"transformer.single_transformer_blocks.{i}.attn.to_q",
f"transformer.single_transformer_blocks.{i}.attn.to_k",
f"transformer.single_transformer_blocks.{i}.attn.to_v",
f"transformer.single_transformer_blocks.{i}.proj_mlp",
],
dims=[3072, 3072, 3072, 12288],
)
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_linear2", f"transformer.single_transformer_blocks.{i}.proj_out"
)
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_modulation_lin", f"transformer.single_transformer_blocks.{i}.norm.linear"
)
if len(sds_sd) > 0:
logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
return ait_sd
def main(args):
# load source safetensors
logger.info(f"Loading source file {args.src_path}")
state_dict = {}
with safe_open(args.src_path, framework="pt") as f:
metadata = f.metadata()
for k in f.keys():
state_dict[k] = f.get_tensor(k)
logger.info(f"Converting {args.src} to {args.dst} format")
if args.src == "ai-toolkit" and args.dst == "sd-scripts":
state_dict = convert_ai_toolkit_to_sd_scripts(state_dict)
elif args.src == "sd-scripts" and args.dst == "ai-toolkit":
state_dict = convert_sd_scripts_to_ai_toolkit(state_dict)
# eliminate 'shared tensors'
for k in list(state_dict.keys()):
state_dict[k] = state_dict[k].detach().clone()
else:
raise NotImplementedError(f"Conversion from {args.src} to {args.dst} is not supported")
# save destination safetensors
logger.info(f"Saving destination file {args.dst_path}")
save_file(state_dict, args.dst_path, metadata=metadata)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert LoRA format")
parser.add_argument("--src", type=str, default="ai-toolkit", help="source format, ai-toolkit or sd-scripts")
parser.add_argument("--dst", type=str, default="sd-scripts", help="destination format, ai-toolkit or sd-scripts")
parser.add_argument("--src_path", type=str, default=None, help="source path")
parser.add_argument("--dst_path", type=str, default=None, help="destination path")
args = parser.parse_args()
main(args)
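
The `convert_to_sd_scripts_cat` function in the removed script merges separate Q/K/V LoRA pairs into one pair by stacking the down matrices and placing the up matrices on a block diagonal. The sketch below checks, on a toy example with nested Python lists instead of tensors, that this merge gives the same weight delta as the separate pairs. `matmul` and `merge_cat` are hypothetical helpers written for this illustration, and the matrix values are arbitrary.

```python
def matmul(a, b):
    """Plain nested-list matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_cat(downs, ups):
    """Stack down matrices along rows; place up matrices on a block diagonal."""
    rank = len(downs[0])
    total_rank = rank * len(downs)
    merged_down = [row for d in downs for row in d]
    merged_up = []
    for j, up in enumerate(ups):
        for row in up:
            new_row = [0.0] * total_rank
            new_row[j * rank : (j + 1) * rank] = row  # only block j is non-zero
            merged_up.append(new_row)
    return merged_down, merged_up


# two modules ("q" and "k"), each with rank 2, in_dim 3, out_dim 2
down_q = [[1.0, 2.0, 0.0], [0.0, 1.0, 3.0]]
up_q = [[1.0, 0.0], [2.0, 1.0]]
down_k = [[0.0, 1.0, 1.0], [2.0, 0.0, 1.0]]
up_k = [[1.0, 1.0], [0.0, 3.0]]

merged_down, merged_up = merge_cat([down_q, down_k], [up_q, up_k])
merged_delta = matmul(merged_up, merged_down)
separate_delta = matmul(up_q, down_q) + matmul(up_k, down_k)  # rows of q, then rows of k
```

Because the merged rank is `rank * num_splits`, the script sets `alpha` to that new rank so the `alpha / rank` factor stays at 1. The reverse direction, `convert_to_ai_toolkit_cat`, looks for exactly this zero pattern in its sparsity check before splitting the down matrix back into chunks.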

View File

@@ -268,7 +268,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
class DyLoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
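
The one-line change above adds `CLIPSdpaAttention` as an accepted name next to `CLIPAttention`. The sketch below shows why a missing name would silently skip attention modules, under the assumption that targets are selected by comparing each module's class name against the list (the selection code itself is not part of this hunk). The classes here are empty stand-ins, not the real transformers classes, and `select_targets` is a hypothetical helper.

```python
# Empty stand-in classes; only their names matter for this illustration.
class CLIPSdpaAttention:
    pass


class CLIPMLP:
    pass


class LayerNorm:
    pass


OLD_TARGETS = ["CLIPAttention", "CLIPMLP"]
NEW_TARGETS = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]


def select_targets(named_modules, target_names):
    """Return the names of modules whose class name is in target_names."""
    return [name for name, module in named_modules if module.__class__.__name__ in target_names]


# a text encoder layer whose attention class reports the alternative name
named_modules = [
    ("layer0.self_attn", CLIPSdpaAttention()),
    ("layer0.mlp", CLIPMLP()),
    ("layer0.norm", LayerNorm()),
]

old_selection = select_targets(named_modules, OLD_TARGETS)
new_selection = select_targets(named_modules, NEW_TARGETS)
```

With the old list only the MLP is picked up, so no LoRA modules would be created for attention. Listing both names accepts either class.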

View File

@@ -1,219 +0,0 @@
# extract approximating LoRA by svd from two FLUX models
# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo!
import argparse
import json
import os
import time
import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm
from library import flux_utils, sai_model_spec, model_util, sdxl_model_util
import lora
from library.utils import MemoryEfficientSafeOpen
from library.utils import setup_logging
from networks import lora_flux
setup_logging()
import logging
logger = logging.getLogger(__name__)
# CLAMP_QUANTILE = 0.99
# MIN_DIFF = 1e-1
def save_to_file(file_name, state_dict, metadata, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if isinstance(state_dict[key], torch.Tensor):
state_dict[key] = state_dict[key].to(dtype)
save_file(state_dict, file_name, metadata=metadata)
def svd(
model_org=None,
model_tuned=None,
save_to=None,
dim=4,
device=None,
save_precision=None,
clamp_quantile=0.99,
min_diff=0.01,
no_metadata=False,
mem_eff_safe_open=False,
):
def str_to_dtype(p):
if p == "float":
return torch.float
if p == "fp16":
return torch.float16
if p == "bf16":
return torch.bfloat16
return None
calc_dtype = torch.float
save_dtype = str_to_dtype(save_precision)
store_device = "cpu"
# open models
lora_weights = {}
if not mem_eff_safe_open:
# use original safetensors.safe_open
open_fn = lambda fn: safe_open(fn, framework="pt")
else:
logger.info("Using memory efficient safe_open")
open_fn = lambda fn: MemoryEfficientSafeOpen(fn)
with open_fn(model_org) as f_org:
# filter keys
keys = []
for key in f_org.keys():
if not ("single_block" in key or "double_block" in key):
continue
if ".bias" in key:
continue
if "norm" in key:
continue
keys.append(key)
with open_fn(model_tuned) as f_tuned:
for key in tqdm(keys):
# get tensors and calculate difference
value_o = f_org.get_tensor(key)
value_t = f_tuned.get_tensor(key)
mat = value_t.to(calc_dtype) - value_o.to(calc_dtype)
del value_o, value_t
# extract LoRA weights
if device:
mat = mat.to(device)
out_dim, in_dim = mat.size()[0:2]
rank = min(dim, in_dim, out_dim) # LoRA rank cannot exceed the original dim
mat = mat.squeeze()
U, S, Vh = torch.linalg.svd(mat)
U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)
Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, clamp_quantile)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
U = U.to(store_device, dtype=save_dtype).contiguous()
Vh = Vh.to(store_device, dtype=save_dtype).contiguous()
# print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}")
lora_weights[key] = (U, Vh)
del mat, U, S, Vh
# make state dict for LoRA
lora_sd = {}
for key, (up_weight, down_weight) in lora_weights.items():
lora_name = key.replace(".weight", "").replace(".", "_")
lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + lora_name
lora_sd[lora_name + ".lora_up.weight"] = up_weight
lora_sd[lora_name + ".lora_down.weight"] = down_weight
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0]) # same as rank
# minimum metadata
net_kwargs = {}
metadata = {
"ss_v2": str(False),
"ss_base_model_version": flux_utils.MODEL_VERSION_FLUX_V1,
"ss_network_module": "networks.lora_flux",
"ss_network_dim": str(dim),
"ss_network_alpha": str(float(dim)),
"ss_network_args": json.dumps(net_kwargs),
}
if not no_metadata:
title = os.path.splitext(os.path.basename(save_to))[0]
sai_metadata = sai_model_spec.build_metadata(lora_sd, False, False, False, True, False, time.time(), title, flux="dev")
metadata.update(sai_metadata)
save_to_file(save_to, lora_sd, metadata, save_dtype)
logger.info(f"LoRA weights saved to {save_to}")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_precision",
type=str,
default=None,
choices=[None, "float", "fp16", "bf16"],
help="precision in saving, float if omitted / 保存時に精度を変更して保存する、省略時はfloat",
)
parser.add_argument(
"--model_org",
type=str,
default=None,
required=True,
help="Original model: safetensors file / 元モデル、safetensors",
)
parser.add_argument(
"--model_tuned",
type=str,
default=None,
required=True,
help="Tuned model, LoRA is difference of `original to tuned`: safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors",
)
parser.add_argument(
"--mem_eff_safe_open",
action="store_true",
help="use memory efficient safe_open. This is an experimental feature, use only when memory is not enough."
" / メモリ効率の良いsafe_openを使用する。実装は実験的なものなので、メモリが足りない場合のみ使用してください。",
)
parser.add_argument(
"--save_to",
type=str,
default=None,
required=True,
help="destination file name: safetensors file / 保存先のファイル名、safetensors",
)
parser.add_argument(
"--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数rankデフォルト4"
)
parser.add_argument(
"--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
)
parser.add_argument(
"--clamp_quantile",
type=float,
default=0.99,
help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
)
# parser.add_argument(
# "--min_diff",
# type=float,
# default=0.01,
# help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
# + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
# )
parser.add_argument(
"--no_metadata",
action="store_true",
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
svd(**vars(args))
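
The key handling in the removed extraction script can be summarized with a small sketch. It mirrors the filter in `svd()` (block weights only, no bias or norm entries), the LoRA name construction, and the rank clamp. The prefix value `lora_unet` is an assumption for illustration; the script itself reads it from `lora_flux.LoRANetwork.LORA_PREFIX_FLUX`, whose value is not shown in this diff. The helper names are hypothetical.

```python
# Assumed value for illustration only; the script takes it from
# lora_flux.LoRANetwork.LORA_PREFIX_FLUX, which this diff does not show.
ASSUMED_PREFIX = "lora_unet"


def is_extractable(key: str) -> bool:
    """Mirror the key filter in svd(): block weights only, no bias or norm."""
    if not ("single_block" in key or "double_block" in key):
        return False
    if ".bias" in key or "norm" in key:
        return False
    return True


def to_lora_name(key: str, prefix: str = ASSUMED_PREFIX) -> str:
    """Build the LoRA module name the same way the state dict loop does."""
    return prefix + "_" + key.replace(".weight", "").replace(".", "_")


def effective_rank(dim: int, in_dim: int, out_dim: int) -> int:
    """The LoRA rank cannot exceed either dimension of the original weight."""
    return min(dim, in_dim, out_dim)


keys = [
    "double_blocks.0.img_attn.qkv.weight",
    "double_blocks.0.img_attn.qkv.bias",
    "double_blocks.0.img_attn.norm.query_norm.scale",
    "single_blocks.3.linear1.weight",
    "final_layer.linear.weight",
]
selected = [k for k in keys if is_extractable(k)]
names = [to_lora_name(k) for k in selected]
print(selected)
print(names)
```

Only the two block weight keys survive the filter; each then yields a `.lora_up.weight`, `.lora_down.weight`, and `.alpha` entry under the derived name.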

View File

@@ -1,765 +0,0 @@
import argparse
import math
import os
import time
from typing import Any, Dict, Union
import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file
setup_logging()
import logging
logger = logging.getLogger(__name__)
import lora_flux as lora_flux
from library import sai_model_spec, train_util
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
sd = load_file(file_name)
metadata = train_util.load_metadata_from_safetensors(file_name)
else:
sd = torch.load(file_name, map_location="cpu")
metadata = {}
for key in list(sd.keys()):
if isinstance(sd[key], torch.Tensor):
sd[key] = sd[key].to(dtype)
return sd, metadata
def save_to_file(file_name, state_dict: Dict[str, Union[Any, torch.Tensor]], dtype, metadata, mem_eff_save=False):
if dtype is not None:
logger.info(f"converting to {dtype}...")
for key in tqdm(list(state_dict.keys())):
if type(state_dict[key]) == torch.Tensor and state_dict[key].dtype.is_floating_point:
state_dict[key] = state_dict[key].to(dtype)
logger.info(f"saving to: {file_name}")
if mem_eff_save:
mem_eff_save_file(state_dict, file_name, metadata=metadata)
else:
save_file(state_dict, file_name, metadata=metadata)
def merge_to_flux_model(
loading_device,
working_device,
flux_path: str,
clip_l_path: str,
t5xxl_path: str,
models,
ratios,
merge_dtype,
save_dtype,
mem_eff_load_save=False,
):
# create module map without loading state_dict
lora_name_to_module_key = {}
if flux_path is not None:
logger.info(f"loading keys from FLUX.1 model: {flux_path}")
with safe_open(flux_path, framework="pt", device=loading_device) as flux_file:
keys = list(flux_file.keys())
for key in keys:
if key.endswith(".weight"):
module_name = ".".join(key.split(".")[:-1])
lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_")
lora_name_to_module_key[lora_name] = key
lora_name_to_clip_l_key = {}
if clip_l_path is not None:
logger.info(f"loading keys from clip_l model: {clip_l_path}")
with safe_open(clip_l_path, framework="pt", device=loading_device) as clip_l_file:
keys = list(clip_l_file.keys())
for key in keys:
if key.endswith(".weight"):
module_name = ".".join(key.split(".")[:-1])
lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP + "_" + module_name.replace(".", "_")
lora_name_to_clip_l_key[lora_name] = key
lora_name_to_t5xxl_key = {}
if t5xxl_path is not None:
logger.info(f"loading keys from t5xxl model: {t5xxl_path}")
with safe_open(t5xxl_path, framework="pt", device=loading_device) as t5xxl_file:
keys = list(t5xxl_file.keys())
for key in keys:
if key.endswith(".weight"):
module_name = ".".join(key.split(".")[:-1])
lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5 + "_" + module_name.replace(".", "_")
lora_name_to_t5xxl_key[lora_name] = key
flux_state_dict = {}
clip_l_state_dict = {}
t5xxl_state_dict = {}
if mem_eff_load_save:
if flux_path is not None:
with MemoryEfficientSafeOpen(flux_path) as flux_file:
for key in tqdm(flux_file.keys()):
flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed
if clip_l_path is not None:
with MemoryEfficientSafeOpen(clip_l_path) as clip_l_file:
for key in tqdm(clip_l_file.keys()):
clip_l_state_dict[key] = clip_l_file.get_tensor(key).to(loading_device)
if t5xxl_path is not None:
with MemoryEfficientSafeOpen(t5xxl_path) as t5xxl_file:
for key in tqdm(t5xxl_file.keys()):
t5xxl_state_dict[key] = t5xxl_file.get_tensor(key).to(loading_device)
else:
if flux_path is not None:
flux_state_dict = load_file(flux_path, device=loading_device)
if clip_l_path is not None:
clip_l_state_dict = load_file(clip_l_path, device=loading_device)
if t5xxl_path is not None:
t5xxl_state_dict = load_file(t5xxl_path, device=loading_device)
for model, ratio in zip(models, ratios):
logger.info(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU
logger.info("merging...")
for key in tqdm(list(lora_sd.keys())):
if "lora_down" in key:
lora_name = key[: key.rfind(".lora_down")]
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
if lora_name in lora_name_to_module_key:
module_weight_key = lora_name_to_module_key[lora_name]
state_dict = flux_state_dict
elif lora_name in lora_name_to_clip_l_key:
module_weight_key = lora_name_to_clip_l_key[lora_name]
state_dict = clip_l_state_dict
elif lora_name in lora_name_to_t5xxl_key:
module_weight_key = lora_name_to_t5xxl_key[lora_name]
state_dict = t5xxl_state_dict
else:
logger.warning(
f"no module found for LoRA weight: {key}. Skipping... / "
f"LoRAの重みに対応するモジュールが見つかりませんでした。スキップします。"
)
continue
down_weight = lora_sd.pop(key)
up_weight = lora_sd.pop(up_key)
dim = down_weight.size()[0]
alpha = lora_sd.pop(alpha_key, dim)
scale = alpha / dim
# W <- W + U * D
weight = state_dict[module_weight_key]
weight = weight.to(working_device, merge_dtype)
up_weight = up_weight.to(working_device, merge_dtype)
down_weight = down_weight.to(working_device, merge_dtype)
# logger.info(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
state_dict[module_weight_key] = weight.to(loading_device, save_dtype)
del up_weight
del down_weight
del weight
if len(lora_sd) > 0:
logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}")
return flux_state_dict, clip_l_state_dict, t5xxl_state_dict
def merge_to_flux_model_diffusers(
loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False
):
logger.info(f"loading keys from FLUX.1 model: {flux_model}")
if mem_eff_load_save:
flux_state_dict = {}
with MemoryEfficientSafeOpen(flux_model) as flux_file:
for key in tqdm(flux_file.keys()):
flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed
else:
flux_state_dict = load_file(flux_model, device=loading_device)
def create_key_map(n_double_layers, n_single_layers):
key_map = {}
for index in range(n_double_layers):
prefix_from = f"transformer_blocks.{index}"
prefix_to = f"double_blocks.{index}"
for end in ("weight", "bias"):
k = f"{prefix_from}.attn."
qkv_img = f"{prefix_to}.img_attn.qkv.{end}"
qkv_txt = f"{prefix_to}.txt_attn.qkv.{end}"
key_map[f"{k}to_q.{end}"] = qkv_img
key_map[f"{k}to_k.{end}"] = qkv_img
key_map[f"{k}to_v.{end}"] = qkv_img
key_map[f"{k}add_q_proj.{end}"] = qkv_txt
key_map[f"{k}add_k_proj.{end}"] = qkv_txt
key_map[f"{k}add_v_proj.{end}"] = qkv_txt
block_map = {
"attn.to_out.0.weight": "img_attn.proj.weight",
"attn.to_out.0.bias": "img_attn.proj.bias",
"norm1.linear.weight": "img_mod.lin.weight",
"norm1.linear.bias": "img_mod.lin.bias",
"norm1_context.linear.weight": "txt_mod.lin.weight",
"norm1_context.linear.bias": "txt_mod.lin.bias",
"attn.to_add_out.weight": "txt_attn.proj.weight",
"attn.to_add_out.bias": "txt_attn.proj.bias",
"ff.net.0.proj.weight": "img_mlp.0.weight",
"ff.net.0.proj.bias": "img_mlp.0.bias",
"ff.net.2.weight": "img_mlp.2.weight",
"ff.net.2.bias": "img_mlp.2.bias",
"ff_context.net.0.proj.weight": "txt_mlp.0.weight",
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
"ff_context.net.2.weight": "txt_mlp.2.weight",
"ff_context.net.2.bias": "txt_mlp.2.bias",
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
}
for k, v in block_map.items():
key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}"
for index in range(n_single_layers):
prefix_from = f"single_transformer_blocks.{index}"
prefix_to = f"single_blocks.{index}"
for end in ("weight", "bias"):
k = f"{prefix_from}.attn."
qkv = f"{prefix_to}.linear1.{end}"
key_map[f"{k}to_q.{end}"] = qkv
key_map[f"{k}to_k.{end}"] = qkv
key_map[f"{k}to_v.{end}"] = qkv
key_map[f"{prefix_from}.proj_mlp.{end}"] = qkv
block_map = {
"norm.linear.weight": "modulation.lin.weight",
"norm.linear.bias": "modulation.lin.bias",
"proj_out.weight": "linear2.weight",
"proj_out.bias": "linear2.bias",
"attn.norm_q.weight": "norm.query_norm.scale",
"attn.norm_k.weight": "norm.key_norm.scale",
}
for k, v in block_map.items():
key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}"
# add as-is keys
values = list([(v if isinstance(v, str) else v[0]) for v in set(key_map.values())])
values.sort()
key_map.update({v: v for v in values})
return key_map
key_map = create_key_map(18, 38) # 18 double layers, 38 single layers
def find_matching_key(flux_dict, lora_key):
lora_key = lora_key.replace("diffusion_model.", "")
lora_key = lora_key.replace("transformer.", "")
lora_key = lora_key.replace("lora_A", "lora_down").replace("lora_B", "lora_up")
lora_key = lora_key.replace("single_transformer_blocks", "single_blocks")
lora_key = lora_key.replace("transformer_blocks", "double_blocks")
double_block_map = {
"attn.to_out.0": "img_attn.proj",
"norm1.linear": "img_mod.lin",
"norm1_context.linear": "txt_mod.lin",
"attn.to_add_out": "txt_attn.proj",
"ff.net.0.proj": "img_mlp.0",
"ff.net.2": "img_mlp.2",
"ff_context.net.0.proj": "txt_mlp.0",
"ff_context.net.2": "txt_mlp.2",
"attn.norm_q": "img_attn.norm.query_norm",
"attn.norm_k": "img_attn.norm.key_norm",
"attn.norm_added_q": "txt_attn.norm.query_norm",
"attn.norm_added_k": "txt_attn.norm.key_norm",
"attn.to_q": "img_attn.qkv",
"attn.to_k": "img_attn.qkv",
"attn.to_v": "img_attn.qkv",
"attn.add_q_proj": "txt_attn.qkv",
"attn.add_k_proj": "txt_attn.qkv",
"attn.add_v_proj": "txt_attn.qkv",
}
single_block_map = {
"norm.linear": "modulation.lin",
"proj_out": "linear2",
"attn.norm_q": "norm.query_norm",
"attn.norm_k": "norm.key_norm",
"attn.to_q": "linear1",
"attn.to_k": "linear1",
"attn.to_v": "linear1",
"proj_mlp": "linear1",
}
# same key exists in both single_block_map and double_block_map, so we must care about single/double
# print("lora_key before double_block_map", lora_key)
for old, new in double_block_map.items():
if "double" in lora_key:
lora_key = lora_key.replace(old, new)
# print("lora_key before single_block_map", lora_key)
for old, new in single_block_map.items():
if "single" in lora_key:
lora_key = lora_key.replace(old, new)
# print("lora_key after mapping", lora_key)
if lora_key in key_map:
flux_key = key_map[lora_key]
logger.info(f"Found matching key: {flux_key}")
return flux_key
# If not found in key_map, try partial matching
potential_key = lora_key + ".weight"
logger.info(f"Searching for key: {potential_key}")
matches = [k for k in flux_dict.keys() if potential_key in k]
if matches:
logger.info(f"Found matching key: {matches[0]}")
return matches[0]
return None
merged_keys = set()
for model, ratio in zip(models, ratios):
logger.info(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype)
logger.info("merging...")
for key in lora_sd.keys():
if "lora_down" in key or "lora_A" in key:
lora_name = key[: key.rfind(".lora_down" if "lora_down" in key else ".lora_A")]
up_key = key.replace("lora_down", "lora_up").replace("lora_A", "lora_B")
alpha_key = key[: key.index("lora_down" if "lora_down" in key else "lora_A")] + "alpha"
logger.info(f"Processing LoRA key: {lora_name}")
flux_key = find_matching_key(flux_state_dict, lora_name)
if flux_key is None:
logger.warning(f"no module found for LoRA weight: {key}")
continue
logger.info(f"Merging LoRA key {lora_name} into Flux key {flux_key}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
weight = flux_state_dict[flux_key]
weight = weight.to(working_device, merge_dtype)
up_weight = up_weight.to(working_device, merge_dtype)
down_weight = down_weight.to(working_device, merge_dtype)
# print(up_weight.size(), down_weight.size(), weight.size())
if lora_name.startswith("transformer."):
if "qkv" in flux_key or "linear1" in flux_key: # combined qkv or qkv+mlp
update = ratio * (up_weight @ down_weight) * scale
# print(update.shape)
if "img_attn" in flux_key or "txt_attn" in flux_key:
q, k, v = torch.chunk(weight, 3, dim=0)
if "to_q" in lora_name or "add_q_proj" in lora_name:
q += update.reshape(q.shape)
elif "to_k" in lora_name or "add_k_proj" in lora_name:
k += update.reshape(k.shape)
elif "to_v" in lora_name or "add_v_proj" in lora_name:
v += update.reshape(v.shape)
weight = torch.cat([q, k, v], dim=0)
elif "linear1" in flux_key:
q, k, v = torch.chunk(weight[: int(update.shape[-1] * 3)], 3, dim=0)
mlp = weight[int(update.shape[-1] * 3) :]
# print(q.shape, k.shape, v.shape, mlp.shape)
if "to_q" in lora_name:
q += update.reshape(q.shape)
elif "to_k" in lora_name:
k += update.reshape(k.shape)
elif "to_v" in lora_name:
v += update.reshape(v.shape)
elif "proj_mlp" in lora_name:
mlp += update.reshape(mlp.shape)
weight = torch.cat([q, k, v, mlp], dim=0)
else:
if len(weight.size()) == 2:
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
weight = weight + ratio * conved * scale
else:
if len(weight.size()) == 2:
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
weight = weight + ratio * conved * scale
flux_state_dict[flux_key] = weight.to(loading_device, save_dtype)
merged_keys.add(flux_key)
del up_weight
del down_weight
del weight
logger.info(f"Merged keys: {sorted(list(merged_keys))}")
return flux_state_dict
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
base_alphas = {} # alpha for merged model
base_dims = {}
merged_sd = {}
base_model = None
for model, ratio in zip(models, ratios):
logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lora_metadata is not None:
if base_model is None:
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
# get alpha and dim
alphas = {} # alpha for current model
dims = {} # dims for current model
for key in lora_sd.keys():
if "alpha" in key:
lora_module_name = key[: key.rfind(".alpha")]
alpha = float(lora_sd[key].item())  # .item() also works for dtypes NumPy lacks, such as bfloat16
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
elif "lora_down" in key:
lora_module_name = key[: key.rfind(".lora_down")]
dim = lora_sd[key].size()[0]
dims[lora_module_name] = dim
if lora_module_name not in base_dims:
base_dims[lora_module_name] = dim
for lora_module_name in dims.keys():
if lora_module_name not in alphas:
alpha = dims[lora_module_name]
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
# merge
logger.info("merging...")
for key in tqdm(lora_sd.keys()):
if "alpha" in key:
continue
if "lora_up" in key and concat:
concat_dim = 1
elif "lora_down" in key and concat:
concat_dim = 0
else:
concat_dim = None
lora_module_name = key[: key.rfind(".lora_")]
base_alpha = base_alphas[lora_module_name]
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
scale = abs(scale) if "lora_up" in key else scale # support negative ratios: the sign is applied once, through lora_down
if key in merged_sd:
assert (
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
), "weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。"
if concat_dim is not None:
merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
else:
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else:
merged_sd[key] = lora_sd[key] * scale
# set alpha to sd
for lora_module_name, alpha in base_alphas.items():
key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha)
if shuffle:
key_down = lora_module_name + ".lora_down.weight"
key_up = lora_module_name + ".lora_up.weight"
dim = merged_sd[key_down].shape[0]
perm = torch.randperm(dim)
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:, perm]
logger.info("merged model")
logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
# check all dims are same
dims_list = list(set(base_dims.values()))
alphas_list = list(set(base_alphas.values()))
all_same_dims = True
all_same_alphas = True
for dims in dims_list:
if dims != dims_list[0]:
all_same_dims = False
break
for alphas in alphas_list:
if alphas != alphas_list[0]:
all_same_alphas = False
break
# build minimum metadata
dims = f"{dims_list[0]}" if all_same_dims else "Dynamic"
alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic"
metadata = train_util.build_minimum_network_metadata(str(False), base_model, "networks.lora", dims, alphas, None)
return merged_sd, metadata
def merge(args):
if args.models is None:
args.models = []
if args.ratios is None:
args.ratios = []
assert len(args.models) == len(
args.ratios
), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
assert (
args.save_to or args.clip_l_save_to or args.t5xxl_save_to
), "save_to or clip_l_save_to or t5xxl_save_to must be specified / save_toまたはclip_l_save_toまたはt5xxl_save_toを指定してください"
dest_dir = os.path.dirname(args.save_to or args.clip_l_save_to or args.t5xxl_save_to)
if dest_dir and not os.path.exists(dest_dir):  # dirname is "" when the save path is a bare file name
logger.info(f"creating directory: {dest_dir}")
os.makedirs(dest_dir)
if args.flux_model is not None or args.clip_l is not None or args.t5xxl is not None:
if not args.diffusers:
assert (args.clip_l is None and args.clip_l_save_to is None) or (
args.clip_l is not None and args.clip_l_save_to is not None
), "clip_l_save_to must be specified if clip_l is specified / clip_lが指定されている場合はclip_l_save_toも指定してください"
assert (args.t5xxl is None and args.t5xxl_save_to is None) or (
args.t5xxl is not None and args.t5xxl_save_to is not None
), "t5xxl_save_to must be specified if t5xxl is specified / t5xxlが指定されている場合はt5xxl_save_toも指定してください"
flux_state_dict, clip_l_state_dict, t5xxl_state_dict = merge_to_flux_model(
args.loading_device,
args.working_device,
args.flux_model,
args.clip_l,
args.t5xxl,
args.models,
args.ratios,
merge_dtype,
save_dtype,
args.mem_eff_load_save,
)
else:
assert (
args.clip_l is None and args.t5xxl is None
), "clip_l and t5xxl are not supported with --diffusers / clip_l、t5xxlはDiffusersではサポートされていません"
flux_state_dict = merge_to_flux_model_diffusers(
args.loading_device,
args.working_device,
args.flux_model,
args.models,
args.ratios,
merge_dtype,
save_dtype,
args.mem_eff_load_save,
)
clip_l_state_dict = None
t5xxl_state_dict = None
if args.no_metadata or (flux_state_dict is None or len(flux_state_dict) == 0):
sai_metadata = None
else:
merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models)
title = os.path.splitext(os.path.basename(args.save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev"
)
if flux_state_dict is not None and len(flux_state_dict) > 0:
logger.info(f"saving FLUX model to: {args.save_to}")
save_to_file(args.save_to, flux_state_dict, save_dtype, sai_metadata, args.mem_eff_load_save)
if clip_l_state_dict is not None and len(clip_l_state_dict) > 0:
logger.info(f"saving clip_l model to: {args.clip_l_save_to}")
save_to_file(args.clip_l_save_to, clip_l_state_dict, save_dtype, None, args.mem_eff_load_save)
if t5xxl_state_dict is not None and len(t5xxl_state_dict) > 0:
logger.info(f"saving t5xxl model to: {args.t5xxl_save_to}")
save_to_file(args.t5xxl_save_to, t5xxl_state_dict, save_dtype, None, args.mem_eff_load_save)
else:
flux_state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
logger.info("calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(flux_state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
if not args.no_metadata:
merged_from = sai_model_spec.build_merged_from(args.models)
title = os.path.splitext(os.path.basename(args.save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
flux_state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev"
)
metadata.update(sai_metadata)
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, flux_state_dict, save_dtype, metadata)
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_precision",
type=str,
default=None,
help="precision in saving, same as merging precision if omitted. supported types: "
"float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz"
" / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
)
parser.add_argument(
"--precision",
type=str,
default="float",
help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨",
)
parser.add_argument(
"--flux_model",
type=str,
default=None,
help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする",
)
parser.add_argument(
"--clip_l",
type=str,
default=None,
help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス*.sftまたは*.safetensors",
)
parser.add_argument(
"--t5xxl",
type=str,
default=None,
help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス*.sftまたは*.safetensors",
)
parser.add_argument(
"--mem_eff_load_save",
action="store_true",
help="use custom memory efficient load and save functions for FLUX.1 model"
" / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する",
)
parser.add_argument(
"--loading_device",
type=str,
default="cpu",
help="device to load FLUX.1 model. LoRA models are loaded on CPU / FLUX.1モデルを読み込むデバイス。LoRAモデルはCPUで読み込まれます",
)
parser.add_argument(
"--working_device",
type=str,
default="cpu",
help="device to work (merge). Merging of LoRA models is done on CPU."
+ " / 作業マージするデバイス。LoRAモデルのマージはCPUで行われます。",
)
parser.add_argument(
"--save_to",
type=str,
default=None,
help="destination file name: safetensors file / 保存先のファイル名、safetensorsファイル",
)
parser.add_argument(
"--clip_l_save_to",
type=str,
default=None,
help="destination file name for clip_l: safetensors file / clip_lの保存先のファイル名、safetensorsファイル",
)
parser.add_argument(
"--t5xxl_save_to",
type=str,
default=None,
help="destination file name for t5xxl: safetensors file / t5xxlの保存先のファイル名、safetensorsファイル",
)
parser.add_argument(
"--models",
type=str,
nargs="*",
help="LoRA models to merge: safetensors file / マージするLoRAモデル、safetensorsファイル",
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument(
"--no_metadata",
action="store_true",
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
parser.add_argument(
"--concat",
action="store_true",
help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
+ "マージの代わりに結合するLoRAのdim(rank)は入力dimの合計になる",
)
parser.add_argument(
"--shuffle",
action="store_true",
help="shuffle LoRA weights / " + "LoRAの重みをシャッフルする",
)
parser.add_argument(
"--diffusers",
action="store_true",
help="merge Diffusers (?) LoRA models / Diffusers (?) LoRAモデルをマージする",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
merge(args)
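
For reference, the linear-layer update used by `merge_to_flux_model` (`W <- W + ratio * (up @ down) * alpha / dim`) can be sketched without PyTorch. This is a minimal illustration on nested lists, not the removed implementation: `matmul` and `merge_linear` are hypothetical helpers, and the conv2d branches are left out.

```python
def matmul(a, b):
    """Plain nested-list matrix product: (m x k) @ (k x n) -> (m x n)."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def merge_linear(weight, up, down, ratio, alpha=None):
    """Return weight + ratio * (up @ down) * (alpha / dim) for a linear layer."""
    dim = len(down)  # rank = number of rows of lora_down
    if alpha is None:
        alpha = dim  # same fallback as the script: a missing alpha means scale 1
    scale = alpha / dim
    delta = matmul(up, down)
    return [
        [w + ratio * d * scale for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


weight = [[1.0, 0.0], [0.0, 1.0]]
up = [[1.0], [2.0]]    # shape (out_dim=2, rank=1)
down = [[3.0, 4.0]]    # shape (rank=1, in_dim=2)
merged = merge_linear(weight, up, down, ratio=0.5, alpha=0.5)
print(merged)
```

With rank 1 and `alpha=0.5`, the scale is 0.5, so the low-rank product is added at a quarter of its magnitude when `ratio=0.5`.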

View File

@@ -866,7 +866,7 @@ class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
-TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
+TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
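
A rough sketch of why the extra name matters, assuming the network selects target modules by comparing each module's class name against this list (the matching code itself is not part of this diff). The classes below are empty stand-ins, not the real `transformers` classes, and `find_targets` is a hypothetical helper.

```python
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]


# Empty stand-ins for the real text encoder modules.
class CLIPSdpaAttention:
    pass


class CLIPMLP:
    pass


class LayerNorm:
    pass


def find_targets(named_modules, targets):
    """Return the names of modules whose class name is listed in targets."""
    return [name for name, module in named_modules if module.__class__.__name__ in targets]


named_modules = [
    ("layers.0.self_attn", CLIPSdpaAttention()),
    ("layers.0.mlp", CLIPMLP()),
    ("layers.0.layer_norm1", LayerNorm()),
]

old = find_targets(named_modules, ["CLIPAttention", "CLIPMLP"])
new = find_targets(named_modules, TEXT_ENCODER_TARGET_REPLACE_MODULE)
print(old)
print(new)
```

Under that assumption, an encoder whose attention class is named `CLIPSdpaAttention` would silently lose its attention modules with the old list, which matches the symptom described in the commit message.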

View File

@@ -278,7 +278,7 @@ def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
-TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
+TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"

View File

@@ -755,7 +755,7 @@ class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
-TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
+TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"

File diff suppressed because it is too large

View File

@@ -1,839 +0,0 @@
# temporary minimum implementation of LoRA
# SD3 doesn't have Conv2d, so we ignore it
# TODO commonize with the original/SD3/FLUX implementation
# LoRA network module
# reference:
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from transformers import CLIPTextModelWithProjection, T5EncoderModel
import numpy as np
import torch
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from networks.lora_flux import LoRAModule, LoRAInfModule
from library import sd3_models
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae: sd3_models.SDVAE,
text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
mmdit,
neuron_dropout: Optional[float] = None,
**kwargs,
):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
network_alpha = 1.0
# extract dim/alpha for conv2d, and block dim
conv_dim = kwargs.get("conv_dim", None)
conv_alpha = kwargs.get("conv_alpha", None)
if conv_dim is not None:
conv_dim = int(conv_dim)
if conv_alpha is None:
conv_alpha = 1.0
else:
conv_alpha = float(conv_alpha)
# attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv
context_attn_dim = kwargs.get("context_attn_dim", None)
context_mlp_dim = kwargs.get("context_mlp_dim", None)
context_mod_dim = kwargs.get("context_mod_dim", None)
x_attn_dim = kwargs.get("x_attn_dim", None)
x_mlp_dim = kwargs.get("x_mlp_dim", None)
x_mod_dim = kwargs.get("x_mod_dim", None)
if context_attn_dim is not None:
context_attn_dim = int(context_attn_dim)
if context_mlp_dim is not None:
context_mlp_dim = int(context_mlp_dim)
if context_mod_dim is not None:
context_mod_dim = int(context_mod_dim)
if x_attn_dim is not None:
x_attn_dim = int(x_attn_dim)
if x_mlp_dim is not None:
x_mlp_dim = int(x_mlp_dim)
if x_mod_dim is not None:
x_mod_dim = int(x_mod_dim)
type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim]
if all([d is None for d in type_dims]):
type_dims = None
# emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear]
emb_dims = kwargs.get("emb_dims", None)
if emb_dims is not None:
emb_dims = emb_dims.strip()
if emb_dims.startswith("[") and emb_dims.endswith("]"):
emb_dims = emb_dims[1:-1]
emb_dims = [int(d) for d in emb_dims.split(",")] # is it better to use ast.literal_eval?
assert len(emb_dims) == 6, f"invalid emb_dims: {emb_dims}, must be 6 dimensions (context, t, x, y, final_mod, final_linear)"
# double/single train blocks
def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
"""
Parse a block selection string and return a list of booleans.
Args:
selection (str): A string specifying which blocks to select.
total_blocks (int): The total number of blocks available.
Returns:
List[bool]: A list of booleans indicating which blocks are selected.
"""
if selection == "all":
return [True] * total_blocks
if selection == "none" or selection == "":
return [False] * total_blocks
selected = [False] * total_blocks
ranges = selection.split(",")
for r in ranges:
if "-" in r:
start, end = map(str.strip, r.split("-"))
start = int(start)
end = int(end)
assert 0 <= start < total_blocks, f"invalid start index: {start}"
assert 0 <= end < total_blocks, f"invalid end index: {end}"
assert start <= end, f"invalid range: {start}-{end}"
for i in range(start, end + 1):
selected[i] = True
else:
index = int(r)
assert 0 <= index < total_blocks, f"invalid index: {index}"
selected[index] = True
return selected
train_block_indices = kwargs.get("train_block_indices", None)
if train_block_indices is not None:
train_block_indices = parse_block_selection(train_block_indices, 999) # 999 is a dummy number
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
if rank_dropout is not None:
rank_dropout = float(rank_dropout)
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
module_dropout = float(module_dropout)
# split qkv
split_qkv = kwargs.get("split_qkv", False)
if split_qkv is not None:
split_qkv = True if split_qkv == "True" else False
# train T5XXL
train_t5xxl = kwargs.get("train_t5xxl", False)
if train_t5xxl is not None:
train_t5xxl = True if train_t5xxl == "True" else False
# verbose
verbose = kwargs.get("verbose", False)
if verbose is not None:
verbose = True if verbose == "True" else False
# that is a lot of arguments...
network = LoRANetwork(
text_encoders,
mmdit,
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
dropout=neuron_dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
split_qkv=split_qkv,
train_t5xxl=train_t5xxl,
type_dims=type_dims,
emb_dims=emb_dims,
train_block_indices=train_block_indices,
verbose=verbose,
)
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
return network
# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, ae, text_encoders, mmdit, weights_sd=None, for_inference=False, **kwargs):
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# get dim/alpha mapping, and train t5xxl
modules_dim = {}
modules_alpha = {}
train_t5xxl = None
for key, value in weights_sd.items():
if "." not in key:
continue
lora_name = key.split(".")[0]
if "alpha" in key:
modules_alpha[lora_name] = value
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# logger.info(lora_name, value.size(), dim)
if train_t5xxl is None or train_t5xxl is False:
train_t5xxl = "lora_te3" in lora_name
if train_t5xxl is None:
train_t5xxl = False
split_qkv = False # no need to handle split_qkv here, because the state_dict stores qkv combined
module_class = LoRAInfModule if for_inference else LoRAModule
network = LoRANetwork(
text_encoders,
mmdit,
multiplier=multiplier,
modules_dim=modules_dim,
modules_alpha=modules_alpha,
module_class=module_class,
split_qkv=split_qkv,
train_t5xxl=train_t5xxl,
)
return network, weights_sd
class LoRANetwork(torch.nn.Module):
SD3_TARGET_REPLACE_MODULE = ["SingleDiTBlock"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"]
LORA_PREFIX_SD3 = "lora_unet" # make ComfyUI compatible
LORA_PREFIX_TEXT_ENCODER_CLIP_L = "lora_te1"
LORA_PREFIX_TEXT_ENCODER_CLIP_G = "lora_te2"
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible
def __init__(
self,
text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
unet: sd3_models.MMDiT,
multiplier: float = 1.0,
lora_dim: int = 4,
alpha: float = 1,
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
module_class: Type[object] = LoRAModule,
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
split_qkv: bool = False,
train_t5xxl: bool = False,
type_dims: Optional[List[int]] = None,
emb_dims: Optional[List[int]] = None,
train_block_indices: Optional[List[bool]] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
self.multiplier = multiplier
self.lora_dim = lora_dim
self.alpha = alpha
self.conv_lora_dim = conv_lora_dim
self.conv_alpha = conv_alpha
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.split_qkv = split_qkv
self.train_t5xxl = train_t5xxl
self.type_dims = type_dims
self.emb_dims = emb_dims
self.train_block_indices = train_block_indices
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
logger.info(f"create LoRA network from weights")
self.emb_dims = [0] * 6 # create emb_dims
# verbose = True
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
# if self.conv_lora_dim is not None:
# logger.info(
# f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
# )
qkv_dim = 0
if self.split_qkv:
logger.info(f"split qkv for LoRA")
qkv_dim = unet.joint_blocks[0].context_block.attn.qkv.weight.size(0)
if train_t5xxl:
logger.info(f"train T5XXL as well")
# create module instances
def create_modules(
is_mmdit: bool,
text_encoder_idx: Optional[int],
root_module: torch.nn.Module,
target_replace_modules: List[str],
filter: Optional[str] = None,
default_dim: Optional[int] = None,
include_conv2d_if_filter: bool = False,
) -> List[LoRAModule]:
prefix = (
self.LORA_PREFIX_SD3
if is_mmdit
else [self.LORA_PREFIX_TEXT_ENCODER_CLIP_L, self.LORA_PREFIX_TEXT_ENCODER_CLIP_G, self.LORA_PREFIX_TEXT_ENCODER_T5][
text_encoder_idx
]
)
loras = []
skipped = []
for name, module in root_module.named_modules():
if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
if target_replace_modules is None: # dirty hack for all modules
module = root_module # search all modules
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
lora_name = prefix + "." + (name + "." if name else "") + child_name
lora_name = lora_name.replace(".", "_")
force_incl_conv2d = False
if filter is not None:
if not filter in lora_name:
continue
force_incl_conv2d = include_conv2d_if_filter
dim = None
alpha = None
if modules_dim is not None:
# modules are specified explicitly
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
else:
# normally, target all modules
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha
if is_mmdit and type_dims is not None:
# type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim]
identifier = [
("context_block", "attn"),
("context_block", "mlp"),
("context_block", "adaLN_modulation"),
("x_block", "attn"),
("x_block", "mlp"),
("x_block", "adaLN_modulation"),
]
for i, d in enumerate(type_dims):
if d is not None and all([id in lora_name for id in identifier[i]]):
dim = d # may be 0 for skip
break
if is_mmdit and dim and self.train_block_indices is not None and "joint_blocks" in lora_name:
# "lora_unet_joint_blocks_0_x_block_attn_proj..."
block_index = int(lora_name.split("_")[4]) # bit dirty
if self.train_block_indices is not None and not self.train_block_indices[block_index]:
dim = 0
elif self.conv_lora_dim is not None:
dim = self.conv_lora_dim
alpha = self.conv_alpha
elif force_incl_conv2d:
# x_embedder
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha
if dim is None or dim == 0:
# record the skipped modules so they can be reported
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None):
skipped.append(lora_name)
continue
# qkv split
split_dims = None
if is_mmdit and split_qkv:
if "joint_blocks" in lora_name and "qkv" in lora_name:
split_dims = [qkv_dim // 3] * 3
lora = module_class(
lora_name,
child_module,
self.multiplier,
dim,
alpha,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
split_dims=split_dims,
)
loras.append(lora)
if target_replace_modules is None:
break # all modules are searched
return loras, skipped
# create LoRA for text encoder
# creating all modules every time is wasteful, so this needs reconsideration
self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
skipped_te = []
for i, text_encoder in enumerate(text_encoders):
index = i
if not train_t5xxl and index >= 2: # 0: CLIP-L, 1: CLIP-G, 2: T5XXL, so we skip T5XXL if train_t5xxl is False
break
logger.info(f"create LoRA for Text Encoder {index+1}:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
# create LoRA for U-Net
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.SD3_TARGET_REPLACE_MODULE)
# emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear]
if self.emb_dims:
for filter, in_dim in zip(
[
"context_embedder",
"_t_embedder", # don't use "t_embedder" because it's used in "context_embedder"
"x_embedder",
"y_embedder",
"final_layer_adaLN_modulation",
"final_layer_linear",
],
self.emb_dims,
):
# x_embedder is conv2d, so we need to include it
loras, _ = create_modules(
True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder"
)
# if len(loras) > 0:
# logger.info(f"create LoRA for {filter}: {len(loras)} modules.")
self.unet_loras.extend(loras)
logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.")
if verbose:
for lora in self.unet_loras:
logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")
skipped = skipped_te + skipped_un
if verbose and len(skipped) > 0:
logger.warning(
f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
)
for name in skipped:
logger.info(f"\t{name}")
# assertion
names = set()
for lora in self.text_encoder_loras + self.unet_loras:
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def set_enabled(self, is_enabled):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def load_state_dict(self, state_dict, strict=True):
# override to convert original weight to split qkv
if not self.split_qkv:
return super().load_state_dict(state_dict, strict)
# split qkv
for key in list(state_dict.keys()):
if not ("joint_blocks" in key and "qkv" in key):
continue
weight = state_dict[key]
lora_name = key.split(".")[0]
if "lora_down" in key and "weight" in key:
# dense weight (rank*3, in_dim)
split_weight = torch.chunk(weight, 3, dim=0)
for i, split_w in enumerate(split_weight):
state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w
del state_dict[key]
# print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}")
elif "lora_up" in key and "weight" in key:
# sparse weight (out_dim=sum(split_dims), rank*3)
rank = weight.size(1) // 3
i = 0
split_dim = weight.shape[0] // 3
for j in range(3):
state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dim, j * rank : (j + 1) * rank]
i += split_dim
del state_dict[key]
# alpha is unchanged
return super().load_state_dict(state_dict, strict)
def state_dict(self, destination=None, prefix="", keep_vars=False):
if not self.split_qkv:
return super().state_dict(destination, prefix, keep_vars)
# merge qkv
state_dict = super().state_dict(destination, prefix, keep_vars)
new_state_dict = {}
for key in list(state_dict.keys()):
if not ("joint_blocks" in key and "qkv" in key):
new_state_dict[key] = state_dict[key]
continue
if key not in state_dict:
continue # already merged
lora_name = key.split(".")[0]
# (rank, in_dim) * 3
down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(3)]
# (split dim, rank) * 3
up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(3)]
alpha = state_dict.pop(f"{lora_name}.alpha")
# merge down weight
down_weight = torch.cat(down_weights, dim=0) # (rank, in_dim) * 3 -> (rank*3, in_dim)
# merge up weight (sum of split_dim, rank*3)
split_dim, rank = up_weights[0].size()
qkv_dim = split_dim * 3
up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
i = 0
for j in range(3):
up_weight[i : i + split_dim, j * rank : (j + 1) * rank] = up_weights[j]
i += split_dim
new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
new_state_dict[f"{lora_name}.alpha"] = alpha
# print(
# f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
# )
print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha")
return new_state_dict
def apply_to(self, text_encoders, mmdit, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
lora.apply_to()
self.add_module(lora.lora_name, lora)
# returns whether this network can be merged
def is_mergeable(self):
return True
# TODO refactor to common function with apply_to
def merge_to(self, text_encoders, mmdit, weights_sd, dtype=None, device=None):
apply_text_encoder = apply_unet = False
for key in weights_sd.keys():
if (
key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_L)
or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_G)
or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5)
):
apply_text_encoder = True
elif key.startswith(LoRANetwork.LORA_PREFIX_SD3):
apply_unet = True
if apply_text_encoder:
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
logger.info(f"weights are merged")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
# make sure text_encoder_lr as list of three elements
# if float, use the same value for all three
if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
text_encoder_lr = [default_lr, default_lr, default_lr]
elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr), float(text_encoder_lr)]
elif len(text_encoder_lr) == 1:
text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0], text_encoder_lr[0]]
elif len(text_encoder_lr) == 2:
text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[1], text_encoder_lr[1]]
self.requires_grad_(True)
all_params = []
lr_descriptions = []
def assemble_params(loras, lr, loraplus_ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
for name, param in lora.named_parameters():
if loraplus_ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
params = []
descriptions = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0:
continue
if lr is not None:
if key == "plus":
param_data["lr"] = lr * loraplus_ratio
else:
param_data["lr"] = lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
logger.info("NO LR skipping!")
continue
params.append(param_data)
descriptions.append("plus" if key == "plus" else "")
return params, descriptions
if self.text_encoder_loras:
loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
# split text encoder loras into te1, te2 and te3
te1_loras = [
lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_L)
]
te2_loras = [
lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_G)
]
te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)]
if len(te1_loras) > 0:
logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio)
all_params.extend(params)
lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions])
if len(te2_loras) > 0:
logger.info(f"Text Encoder 2 (CLIP-G): {len(te2_loras)} modules, LR {text_encoder_lr[1]}")
params, descriptions = assemble_params(te2_loras, text_encoder_lr[1], loraplus_lr_ratio)
all_params.extend(params)
lr_descriptions.extend(["textencoder 2 " + (" " + d if d else "") for d in descriptions])
if len(te3_loras) > 0:
logger.info(f"Text Encoder 3 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[2]}")
params, descriptions = assemble_params(te3_loras, text_encoder_lr[2], loraplus_lr_ratio)
all_params.extend(params)
lr_descriptions.extend(["textencoder 3 " + (" " + d if d else "") for d in descriptions])
if self.unet_loras:
params, descriptions = assemble_params(
self.unet_loras,
unet_lr if unet_lr is not None else default_lr,
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
return all_params, lr_descriptions
def enable_gradient_checkpointing(self):
# not supported
pass
def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)
def on_epoch_start(self, text_encoder, unet):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
from library import train_util
# Precalculate model hashes to save time on indexing
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def backup_weights(self):
# back up the weights
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not hasattr(org_module, "_lora_org_weight"):
sd = org_module.state_dict()
org_module._lora_org_weight = sd["weight"].detach().clone()
org_module._lora_restored = True
def restore_weights(self):
# restore the weights
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not org_module._lora_restored:
sd = org_module.state_dict()
sd["weight"] = org_module._lora_org_weight
org_module.load_state_dict(sd)
org_module._lora_restored = True
def pre_calculation(self):
# perform the pre-calculation
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
sd = org_module.state_dict()
org_weight = sd["weight"]
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
sd["weight"] = org_weight + lora_weight
assert sd["weight"].shape == org_weight.shape
org_module.load_state_dict(sd)
org_module._lora_restored = False
lora.enabled = False
def apply_max_norm_regularization(self, max_norm_value, device):
downkeys = []
upkeys = []
alphakeys = []
norms = []
keys_scaled = 0
state_dict = self.state_dict()
for key in state_dict.keys():
if "lora_down" in key and "weight" in key:
downkeys.append(key)
upkeys.append(key.replace("lora_down", "lora_up"))
alphakeys.append(key.replace("lora_down.weight", "alpha"))
for i in range(len(downkeys)):
down = state_dict[downkeys[i]].to(device)
up = state_dict[upkeys[i]].to(device)
alpha = state_dict[alphakeys[i]].to(device)
dim = down.shape[0]
scale = alpha / dim
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
else:
updown = up @ down
updown *= scale
norm = updown.norm().clamp(min=max_norm_value / 2)
desired = torch.clamp(norm, max=max_norm_value)
ratio = desired.cpu() / norm.cpu()
sqrt_ratio = ratio**0.5
if ratio != 1:
keys_scaled += 1
state_dict[upkeys[i]] *= sqrt_ratio
state_dict[downkeys[i]] *= sqrt_ratio
scalednorm = updown.norm() * ratio
norms.append(scalednorm.item())
return keys_scaled, sum(norms) / len(norms), max(norms)
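
Note on `apply_max_norm_regularization` above: the scaling arithmetic can be checked with plain scalars. This is a minimal sketch, assuming 1x1 "matrices" so that `up @ down` reduces to a product; `max_norm_ratio` is a hypothetical helper name that mirrors the two clamp calls in the method and is not part of the repository.

```python
import math


def max_norm_ratio(norm, max_norm_value):
    """Mirror the clamps in the method: floor the norm at max/2, cap the target at max."""
    clamped = max(norm, max_norm_value / 2)  # norm.clamp(min=max_norm_value / 2)
    desired = min(clamped, max_norm_value)   # torch.clamp(norm, max=max_norm_value)
    return desired / clamped


# Scalar stand-ins for the LoRA factors (assumption: 1x1 weights).
up, down = 3.0, 2.0
norm_before = abs(up * down)  # 6.0

ratio = max_norm_ratio(norm_before, max_norm_value=1.5)
sqrt_ratio = math.sqrt(ratio)

# Scaling BOTH factors by sqrt(ratio) scales their product by ratio.
up_scaled = up * sqrt_ratio
down_scaled = down * sqrt_ratio
norm_after = abs(up_scaled * down_scaled)

# A norm already below the limit yields ratio 1, so nothing is rescaled.
ratio_small = max_norm_ratio(0.1, max_norm_value=1.5)
```

Splitting the correction evenly across `lora_up` and `lora_down` is what the `sqrt_ratio` multiplication in the method does; the product of the two factors ends up at the requested maximum norm.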

View File

@@ -51,7 +51,7 @@ class OFTModule(torch.nn.Module):
alpha = alpha.detach().numpy()
# constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility
# original alpha is 1e-5, so we use 1e-2 or 1e-4 for alpha
# original alpha is 1e-6, so we use 1e-3 or 1e-4 for alpha
self.constraint = alpha * out_dim
self.register_buffer("alpha", torch.tensor(alpha))
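
Note on the `constraint` value in this hunk: in the FLUX OFT module later in this diff, `get_weight` clamps the norm of a skew-symmetric matrix Q to the constraint and then builds R = (I + Q)(I - Q)^-1. The sketch below works that through for a single 2x2 block in pure Python, assuming the module in this hunk follows the same scheme. The helper names are hypothetical, and the closed form for R is derived by hand for the 2x2 case rather than taken from the source.

```python
import math


def clamp_skew_2x2(a, constraint):
    """Rescale Q = [[0, a], [-a, 0]] so its Frobenius norm is at most `constraint`."""
    norm = math.sqrt(2.0) * abs(a)   # Frobenius norm of Q
    new_norm = min(norm, constraint)
    return a * ((new_norm + 1e-8) / (norm + 1e-8))


def cayley_2x2(a):
    """Closed form of R = (I + Q)(I - Q)^-1 for Q = [[0, a], [-a, 0]]."""
    d = 1.0 + a * a
    c = (1.0 - a * a) / d
    s = 2.0 * a / d
    return [[c, s], [-s, c]]


def matmul_transpose_2x2(r):
    """Compute R @ R^T for a 2x2 matrix given as nested lists."""
    return [[sum(r[i][k] * r[j][k] for k in range(2)) for j in range(2)] for i in range(2)]


a_clamped = clamp_skew_2x2(1.0, constraint=0.5)
R = cayley_2x2(a_clamped)
RRt = matmul_transpose_2x2(R)  # should be close to the identity
```

A smaller constraint keeps Q close to zero, and therefore keeps R close to the identity, which is why the comments in this hunk discuss small alpha values.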

View File

@@ -1,482 +0,0 @@
# OFT network module
import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
import einops
from transformers import CLIPTextModel
import numpy as np
import torch
import torch.nn.functional as F
import re
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class OFTModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
def __init__(
self,
oft_name,
org_module: torch.nn.Module,
multiplier=1.0,
dim=4,
alpha=1,
split_dims: Optional[List[int]] = None,
):
"""
dim -> num blocks
alpha -> constraint
split_dims is used to mimic the split qkv of FLUX, the same way Diffusers does
"""
super().__init__()
self.oft_name = oft_name
self.num_blocks = dim
if type(alpha) == torch.Tensor:
alpha = alpha.detach().numpy()
self.register_buffer("alpha", torch.tensor(alpha))
# No conv2d in FLUX
# if "Linear" in org_module.__class__.__name__:
self.out_dim = org_module.out_features
# elif "Conv" in org_module.__class__.__name__:
# out_dim = org_module.out_channels
if split_dims is None:
split_dims = [self.out_dim]
else:
assert sum(split_dims) == self.out_dim, "sum of split_dims must be equal to out_dim"
self.split_dims = split_dims
# assert all dim is divisible by num_blocks
for split_dim in self.split_dims:
assert split_dim % self.num_blocks == 0, "split_dim must be divisible by num_blocks"
self.constraint = [alpha * split_dim for split_dim in self.split_dims]
self.block_size = [split_dim // self.num_blocks for split_dim in self.split_dims]
self.oft_blocks = torch.nn.ParameterList(
[torch.nn.Parameter(torch.zeros(self.num_blocks, block_size, block_size)) for block_size in self.block_size]
)
self.I = [torch.eye(block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) for block_size in self.block_size]
self.shape = org_module.weight.shape
self.multiplier = multiplier
self.org_module = [org_module] # wrap in a list so it is not registered as a submodule
def apply_to(self):
self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward
def get_weight(self, multiplier=None):
if multiplier is None:
multiplier = self.multiplier
if self.I[0].device != self.oft_blocks[0].device:
self.I = [I.to(self.oft_blocks[0].device) for I in self.I]
block_R_weighted_list = []
for i in range(len(self.oft_blocks)):
block_Q = self.oft_blocks[i] - self.oft_blocks[i].transpose(1, 2)
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=self.constraint[i])
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
I = self.I[i]
block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse())
block_R_weighted = multiplier * (block_R - I) + I
block_R_weighted_list.append(block_R_weighted)
return block_R_weighted_list
def forward(self, x, scale=None):
if self.multiplier == 0.0:
return self.org_forward(x)
org_module = self.org_module[0]
org_dtype = x.dtype
R = self.get_weight()
W = org_module.weight.to(torch.float32)
B = org_module.bias.to(torch.float32)
# split W to match R
results = []
d2 = 0
for i in range(len(R)):
d1 = d2
d2 += self.split_dims[i]
W1 = W[d1:d2]
W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i])
RW_1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped)
RW_1 = einops.rearrange(RW_1, "k m p -> (k m) p")
B1 = B[d1:d2]
result = F.linear(x, RW_1.to(org_dtype), B1.to(org_dtype))
results.append(result)
result = torch.cat(results, dim=-1)
return result
class OFTInfModule(OFTModule):
def __init__(
self,
oft_name,
org_module: torch.nn.Module,
multiplier=1.0,
dim=4,
alpha=1,
split_dims: Optional[List[int]] = None,
**kwargs,
):
# no dropout for inference
super().__init__(oft_name, org_module, multiplier, dim, alpha, split_dims)
self.enabled = True
self.network: OFTNetwork = None
def set_network(self, network):
self.network = network
def forward(self, x, scale=None):
if not self.enabled:
return self.org_forward(x)
return super().forward(x, scale)
def merge_to(self, multiplier=None):
# get org weight
org_sd = self.org_module[0].state_dict()
W = org_sd["weight"].to(torch.float32)
R = [r.to(torch.float32) for r in self.get_weight(multiplier)] # get_weight returns a list of tensors
d2 = 0
W_list = []
for i in range(len(self.oft_blocks)):
d1 = d2
d2 += self.split_dims[i]
W1 = W[d1:d2]
W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i])
W1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped)
W1 = einops.rearrange(W1, "k m p -> (k m) p")
W_list.append(W1)
W = torch.cat(W_list, dim=0) # W was split along the output (row) dimension
# convert back to original dtype
W = W.to(org_sd["weight"].dtype)
# set weight to org_module
org_sd["weight"] = W
self.org_module[0].load_state_dict(org_sd)
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae: AutoencoderKL,
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
unet,
neuron_dropout: Optional[float] = None,
**kwargs,
):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None: # should be set
logger.info(
"network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します"
)
network_alpha = 1e-3
elif network_alpha >= 1:
logger.warning(
"network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3"
" / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨"
)
# attn only or all linear (FFN) layers
enable_all_linear = kwargs.get("enable_all_linear", None)
# enable_conv = kwargs.get("enable_conv", None)
if enable_all_linear is not None:
enable_all_linear = bool(enable_all_linear)
# if enable_conv is not None:
# enable_conv = bool(enable_conv)
network = OFTNetwork(
text_encoder,
unet,
multiplier=multiplier,
dim=network_dim,
alpha=network_alpha,
enable_all_linear=enable_all_linear,
varbose=True,
)
return network
# Create network from weights for inference; weights are not loaded here (because they can be merged instead)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# check dim, alpha and whether the weights include MLP modules (all-linear mode)
dim = None
alpha = None
all_linear = None
for name, param in weights_sd.items():
if name.endswith(".alpha"):
if alpha is None:
alpha = param.item()
elif "qkv" in name:
continue # ignore qkv
else:
if dim is None:
dim = param.size()[0]
if all_linear is None and "_mlp" in name:
all_linear = True
if dim is not None and alpha is not None and all_linear is not None:
break
if all_linear is None:
all_linear = False
module_class = OFTInfModule if for_inference else OFTModule
network = OFTNetwork(
text_encoder,
unet,
multiplier=multiplier,
dim=dim,
alpha=alpha,
enable_all_linear=all_linear,
module_class=module_class,
)
return network, weights_sd
class OFTNetwork(torch.nn.Module):
FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR = ["DoubleStreamBlock", "SingleStreamBlock"]
FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY = ["SelfAttention"]
OFT_PREFIX_UNET = "oft_unet"
def __init__(
self,
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
unet,
multiplier: float = 1.0,
dim: int = 4,
alpha: float = 1,
enable_all_linear: Optional[bool] = False,
module_class: Union[Type[OFTModule], Type[OFTInfModule]] = OFTModule,
varbose: Optional[bool] = False,
) -> None:
super().__init__()
self.train_t5xxl = False # make compatible with LoRA
self.multiplier = multiplier
self.dim = dim
self.alpha = alpha
logger.info(
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_all_linear: {enable_all_linear}"
)
# create module instances
def create_modules(
root_module: torch.nn.Module,
target_replace_modules: List[torch.nn.Module],
) -> List[OFTModule]:
prefix = self.OFT_PREFIX_UNET
ofts = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = "Linear" in child_module.__class__.__name__
if is_linear:
oft_name = prefix + "." + name + "." + child_name
oft_name = oft_name.replace(".", "_")
# logger.info(oft_name)
if "double" in oft_name and "qkv" in oft_name:
split_dims = [3072] * 3
elif "single" in oft_name and "linear1" in oft_name:
split_dims = [3072] * 3 + [12288]
else:
split_dims = None
oft = module_class(oft_name, child_module, self.multiplier, dim, alpha, split_dims)
ofts.append(oft)
return ofts
# target all linear layers if enable_all_linear is set (or detected from weights), otherwise attention only
if enable_all_linear:
target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR
else:
target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY
self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
logger.info(f"create OFT for Flux: {len(self.unet_ofts)} modules.")
# assertion
names = set()
for oft in self.unet_ofts:
assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}"
names.add(oft.oft_name)
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for oft in self.unet_ofts:
oft.multiplier = self.multiplier
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
assert apply_unet, "apply_unet must be True"
for oft in self.unet_ofts:
oft.apply_to()
self.add_module(oft.oft_name, oft)
# returns whether this network can be merged
def is_mergeable(self):
return True
# TODO refactor to common function with apply_to
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
logger.info("enable OFT for U-Net")
for oft in self.unet_ofts:
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(oft.oft_name):
sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key]
oft.load_state_dict(sd_for_lora, False)
oft.merge_to()
logger.info("weights are merged")
# it might be good to allow separate learning rates for the two Text Encoders
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []
def enumerate_params(ofts):
params = []
for oft in ofts:
params.extend(oft.parameters())
# log the number of params
num_params = 0
for p in params:
num_params += p.numel()
logger.info(f"OFT params: {num_params}")
return params
param_data = {"params": enumerate_params(self.unet_ofts)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
return all_params
def enable_gradient_checkpointing(self):
# not supported
pass
def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)
def on_epoch_start(self, text_encoder, unet):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
from library import train_util
# Precalculate model hashes to save time on indexing
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def backup_weights(self):
# back up the weights
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
if not hasattr(org_module, "_lora_org_weight"):
sd = org_module.state_dict()
org_module._lora_org_weight = sd["weight"].detach().clone()
org_module._lora_restored = True
def restore_weights(self):
# restore the weights
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
if not org_module._lora_restored:
sd = org_module.state_dict()
sd["weight"] = org_module._lora_org_weight
org_module.load_state_dict(sd)
org_module._lora_restored = True
def pre_calculation(self):
# perform pre-calculation
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
oft.merge_to()
# sd = org_module.state_dict()
# org_weight = sd["weight"]
# lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype)
# sd["weight"] = org_weight + lora_weight
# assert sd["weight"].shape == org_weight.shape
# org_module.load_state_dict(sd)
org_module._lora_restored = False
oft.enabled = False
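
The per-split merge performed by `OFTInfModule.merge_to` can be sketched in plain Python, without torch or einops. This is only an illustration under stated assumptions: the helper names `apply_blocks` and `merge_split_weight` are hypothetical, matrices are nested lists, and the sketch assumes the transformed slices are rejoined along the same row axis they were split on so that the merged weight keeps its original shape.

```python
from typing import List

Matrix = List[List[float]]


def apply_blocks(blocks: List[Matrix], w: Matrix) -> Matrix:
    """Compute out[k, m, p] = sum_n blocks[k][n][m] * w_k[n][p], stacking the k results by rows."""
    n = len(blocks[0])  # block size
    assert len(w) == n * len(blocks), "rows must equal num_blocks * block_size"
    out: Matrix = []
    for k, block in enumerate(blocks):
        rows = w[k * n : (k + 1) * n]  # rows owned by block k
        for m in range(n):
            out.append([sum(block[i][m] * rows[i][p] for i in range(n)) for p in range(len(w[0]))])
    return out


def merge_split_weight(w: Matrix, split_dims: List[int], blocks_per_split: List[List[Matrix]]) -> Matrix:
    """Split w by rows, transform each slice, and join the slices by rows again."""
    assert sum(split_dims) == len(w), "split_dims must cover every row"
    merged: Matrix = []
    start = 0
    for size, blocks in zip(split_dims, blocks_per_split):
        merged.extend(apply_blocks(blocks, w[start : start + size]))  # row-wise concatenation
        start += size
    return merged


# Toy example: two splits of two rows each, one 2x2 block per split.
identity = [[1.0, 0.0], [0.0, 1.0]]
quarter_turn = [[0.0, -1.0], [1.0, 0.0]]
w = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
merged = merge_split_weight(w, [2, 2], [[identity], [quarter_turn]])
```

The first slice passes through the identity block unchanged and the second is rotated, while the result keeps the same number of rows and columns as `w`, which is what loading it back into the original module's state dict requires.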

View File

@@ -1,8 +0,0 @@
[pytest]
minversion = 6.0
testpaths =
tests
filterwarnings =
ignore::DeprecationWarning
ignore::UserWarning
ignore::FutureWarning

View File

@@ -1,4 +1,4 @@
accelerate==0.33.0
accelerate==0.30.0
transformers==4.44.0
diffusers[torch]==0.25.0
ftfy==6.1.1
@@ -9,9 +9,8 @@ pytorch-lightning==1.9.0
bitsandbytes==0.44.0
prodigyopt==1.0
lion-pytorch==0.0.6
schedulefree==1.4
tensorboard
safetensors==0.4.4
safetensors==0.4.2
# gradio==3.16.2
altair==4.2.2
easygui==0.98.3
@@ -20,7 +19,6 @@ voluptuous==0.13.1
huggingface-hub==0.24.5
# for Image utils
imagesize==1.4.1
numpy<=2.0
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12
@@ -40,7 +38,5 @@ numpy<=2.0
# open-clip-torch==2.20.0
# For logging
rich==13.7.0
# for T5XXL tokenizer (SD3/FLUX)
sentencepiece==0.2.0
# for kohya_ss library
-e .

View File

@@ -1,407 +0,0 @@
# Minimum Inference Code for SD3
import argparse
import datetime
import math
import os
import random
from typing import Optional, Tuple
import numpy as np
import torch
from safetensors.torch import safe_open, load_file
import torch.amp
from tqdm import tqdm
from PIL import Image
from transformers import CLIPTextModelWithProjection, T5EncoderModel
from library.device_utils import init_ipex, get_preferred_device
from networks import lora_sd3
init_ipex()
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import sd3_models, sd3_utils, strategy_sd3
from library.utils import load_safetensors
def get_noise(seed, latent, device="cpu"):
# generator = torch.manual_seed(seed)
generator = torch.Generator(device)
generator.manual_seed(seed)
return torch.randn(latent.size(), dtype=latent.dtype, layout=latent.layout, generator=generator, device=device)
def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
start = sampling.timestep(sampling.sigma_max)
end = sampling.timestep(sampling.sigma_min)
timesteps = torch.linspace(start, end, steps)
sigs = []
for x in range(len(timesteps)):
ts = timesteps[x]
sigs.append(sampling.sigma(ts))
sigs += [0.0]
return torch.FloatTensor(sigs)
def max_denoise(model_sampling, sigmas):
max_sigma = float(model_sampling.sigma_max)
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
def do_sample(
height: int,
width: int,
initial_latent: Optional[torch.Tensor],
seed: int,
cond: Tuple[torch.Tensor, torch.Tensor],
neg_cond: Tuple[torch.Tensor, torch.Tensor],
mmdit: sd3_models.MMDiT,
steps: int,
cfg_scale: float,
dtype: torch.dtype,
device: torch.device,  # device.type is read for autocast below
):
if initial_latent is None:
# latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 # this seems to be a bug in the original code. thanks to furusu for pointing it out
latent = torch.zeros(1, 16, height // 8, width // 8, device=device)
else:
latent = initial_latent
latent = latent.to(dtype).to(device)
noise = get_noise(seed, latent, device)
model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3
sigmas = get_sigmas(model_sampling, steps).to(device)
# sigmas = sigmas[int(steps * (1 - denoise)) :] # do not support i2i
# conditioning = fix_cond(conditioning)
# neg_cond = fix_cond(neg_cond)
# extra_args = {"cond": cond, "uncond": neg_cond, "cond_scale": guidance_scale}
noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas))
c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype)
y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype)
x = noise_scaled.to(device).to(dtype)
# print(x.shape)
with torch.no_grad():
for i in tqdm(range(len(sigmas) - 1)):
sigma_hat = sigmas[i]
timestep = model_sampling.timestep(sigma_hat).float()
timestep = torch.FloatTensor([timestep, timestep]).to(device)
x_c_nc = torch.cat([x, x], dim=0)
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
with torch.autocast(device_type=device.type, dtype=dtype):
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
model_output = model_output.float()
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
pos_out, neg_out = batched.chunk(2)
denoised = neg_out + (pos_out - neg_out) * cfg_scale
# print(denoised.shape)
# d = to_d(x, sigma_hat, denoised)
dims_to_append = x.ndim - sigma_hat.ndim
sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append]
# print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape)
# Converts a denoiser output to a Karras ODE derivative.
d = (x - denoised) / sigma_hat_dims
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
x = x.to(dtype)
latent = x
latent = vae.process_out(latent)  # NOTE: relies on the module-level `vae` created under __main__
return latent
def generate_image(
mmdit: sd3_models.MMDiT,
vae: sd3_models.SDVAE,
clip_l: CLIPTextModelWithProjection,
clip_g: CLIPTextModelWithProjection,
t5xxl: T5EncoderModel,
steps: int,
prompt: str,
seed: int,
target_width: int,
target_height: int,
device: torch.device,  # device.type is read for autocast below
negative_prompt: str,
cfg_scale: float,
):
# prepare embeddings
logger.info("Encoding prompts...")
# TODO support one-by-one offloading
clip_l.to(device)
clip_g.to(device)
t5xxl.to(device)
with torch.autocast(device_type=device.type, dtype=mmdit.dtype), torch.no_grad():
tokens_and_masks = tokenize_strategy.tokenize(prompt)
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
)
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
lg_out, t5_out, pooled, neg_l_attn_mask, neg_g_attn_mask, neg_t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
)
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
# attn masks are not used currently
if args.offload:
clip_l.to("cpu")
clip_g.to("cpu")
t5xxl.to("cpu")
# generate image
logger.info("Generating image...")
mmdit.to(device)
latent_sampled = do_sample(target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, cfg_scale, sd3_dtype, device)
if args.offload:
mmdit.to("cpu")
# latent to image
vae.to(device)
with torch.no_grad():
image = vae.decode(latent_sampled)
if args.offload:
vae.to("cpu")
image = image.float()
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
decoded_np = decoded_np.astype(np.uint8)
out_image = Image.fromarray(decoded_np)
# save image
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
out_image.save(output_path)
logger.info(f"Saved image to {output_path}")
if __name__ == "__main__":
target_height = 1024
target_width = 1024
# steps = 50 # 28 # 50
# cfg_scale = 5
# seed = 1 # None # 1
device = get_preferred_device()
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--clip_g", type=str, required=False)
parser.add_argument("--clip_l", type=str, required=False)
parser.add_argument("--t5xxl", type=str, required=False)
parser.add_argument("--t5xxl_token_length", type=int, default=256, help="t5xxl token length, default: 256")
parser.add_argument("--apply_lg_attn_mask", action="store_true")
parser.add_argument("--apply_t5_attn_mask", action="store_true")
parser.add_argument("--prompt", type=str, default="A photo of a cat")
# parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--cfg_scale", type=float, default=5.0)
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
parser.add_argument("--output_dir", type=str, default=".")
# parser.add_argument("--do_not_use_t5xxl", action="store_true")
# parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--steps", type=int, default=50)
parser.add_argument(
"--lora_weights",
type=str,
nargs="*",
default=[],
help="LoRA weights, only supports networks.lora_sd3, each argument is a `path;multiplier` (semi-colon separated)",
)
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
parser.add_argument("--width", type=int, default=target_width)
parser.add_argument("--height", type=int, default=target_height)
parser.add_argument("--interactive", action="store_true")
args = parser.parse_args()
seed = args.seed
steps = args.steps
sd3_dtype = torch.float32
if args.fp16:
sd3_dtype = torch.float16
elif args.bf16:
sd3_dtype = torch.bfloat16
loading_device = "cpu" if args.offload else device
# load state dict
logger.info(f"Loading SD3 models from {args.ckpt_path}...")
# state_dict = load_file(args.ckpt_path)
state_dict = load_safetensors(args.ckpt_path, loading_device, disable_mmap=True, dtype=sd3_dtype)
# load text encoders
clip_l = sd3_utils.load_clip_l(args.clip_l, sd3_dtype, loading_device, state_dict=state_dict)
clip_g = sd3_utils.load_clip_g(args.clip_g, sd3_dtype, loading_device, state_dict=state_dict)
t5xxl = sd3_utils.load_t5xxl(args.t5xxl, sd3_dtype, loading_device, state_dict=state_dict)
# MMDiT and VAE
vae = sd3_utils.load_vae(None, sd3_dtype, loading_device, state_dict=state_dict)
mmdit = sd3_utils.load_mmdit(state_dict, sd3_dtype, loading_device)
clip_l.to(sd3_dtype)
clip_g.to(sd3_dtype)
t5xxl.to(sd3_dtype)
vae.to(sd3_dtype)
mmdit.to(sd3_dtype)
if not args.offload:
# make sure to move to the device: some tensors are created in the constructor on the CPU
clip_l.to(device)
clip_g.to(device)
t5xxl.to(device)
vae.to(device)
mmdit.to(device)
clip_l.eval()
clip_g.eval()
t5xxl.eval()
mmdit.eval()
vae.eval()
# load tokenizers
logger.info("Loading tokenizers...")
tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
# LoRA
lora_models: list[lora_sd3.LoRANetwork] = []
for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")
multiplier = float(multiplier)
else:
multiplier = 1.0
weights_sd = load_file(weights_file)
module = lora_sd3
lora_model, _ = module.create_network_from_weights(multiplier, None, vae, [clip_l, clip_g, t5xxl], mmdit, weights_sd, True)
if args.merge_lora_weights:
lora_model.merge_to([clip_l, clip_g, t5xxl], mmdit, weights_sd)
else:
lora_model.apply_to([clip_l, clip_g, t5xxl], mmdit)
info = lora_model.load_state_dict(weights_sd, strict=True)
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
lora_model.eval()
lora_model.to(device)
lora_models.append(lora_model)
if not args.interactive:
generate_image(
mmdit,
vae,
clip_l,
clip_g,
t5xxl,
args.steps,
args.prompt,
args.seed,
args.width,
args.height,
device,
args.negative_prompt,
args.cfg_scale,
)
else:
# loop for interactive
width = args.width
height = args.height
steps = None
cfg_scale = args.cfg_scale
while True:
print(
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --c <cfg scale>"
" --m <LoRA multipliers, comma separated> --n <negative prompt>, `--n -` for empty negative prompt. "
"Options are kept for the next prompt. Current options:"
f" width={width}, height={height}, steps={steps}, seed={seed}, cfg_scale={cfg_scale}"
)
prompt = input()
if prompt == "":
break
# parse options
options = prompt.split("--")
prompt = options[0].strip()
seed = None
negative_prompt = None
for opt in options[1:]:
try:
opt = opt.strip()
if opt.startswith("w"):
width = int(opt[1:].strip())
elif opt.startswith("h"):
height = int(opt[1:].strip())
elif opt.startswith("s"):
steps = int(opt[1:].strip())
elif opt.startswith("d"):
seed = int(opt[1:].strip())
elif opt.startswith("m"):
multipliers = opt[1:].strip().split(",")
if len(multipliers) != len(lora_models):
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
continue
for i, lora_model in enumerate(lora_models):
lora_model.set_multiplier(float(multipliers[i]))
elif opt.startswith("n"):
negative_prompt = opt[1:].strip()
if negative_prompt == "-":
negative_prompt = ""
elif opt.startswith("c"):
cfg_scale = float(opt[1:].strip())
except ValueError as e:
logger.error(f"Invalid option: {opt}, {e}")
generate_image(
mmdit,
vae,
clip_l,
clip_g,
t5xxl,
steps if steps is not None else args.steps,
prompt,
seed if seed is not None else args.seed,
width,
height,
device,
negative_prompt if negative_prompt is not None else args.negative_prompt,
cfg_scale,
)
logger.info("Done!")
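
The sampling loop in `do_sample` combines classifier-free guidance with an Euler step on the derivative `(x - denoised) / sigma`. The sketch below reproduces just that arithmetic on scalars so it can be run without torch. The names `cfg_combine` and `euler_sample` and the constant "denoiser" are hypothetical stand-ins, not part of the script.

```python
import math
from typing import Callable, List


def cfg_combine(pos: float, neg: float, cfg_scale: float) -> float:
    """Classifier-free guidance: move from the negative output toward the positive one."""
    return neg + (pos - neg) * cfg_scale


def euler_sample(x: float, sigmas: List[float], denoise: Callable[[float, float], float]) -> float:
    """Integrate from sigmas[0] down to sigmas[-1] with plain Euler steps."""
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        denoised = denoise(x, sigma)
        d = (x - denoised) / sigma  # derivative of x with respect to sigma
        x = x + d * (sigma_next - sigma)  # Euler update; dt is negative
    return x


# Toy denoiser that always predicts the same clean value.
target = 0.25
sigmas = [1.0, 0.75, 0.5, 0.25, 0.0]  # the trailing 0.0 mirrors `sigs += [0.0]`
result = euler_sample(1.5, sigmas, lambda x, sigma: target)
```

With a denoiser that always returns the same value, the final step (to sigma 0) lands on that value up to rounding, which is a quick sanity check for the sign conventions of `d` and `dt`.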

File diff suppressed because it is too large

View File

@@ -1,490 +0,0 @@
import argparse
import copy
import math
import random
from typing import Any, Optional, Union
import torch
from accelerate import Accelerator
from library import sd3_models, strategy_sd3, utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, sd3_utils, strategy_base, train_util
import train_network
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class Sd3NetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
# super().assert_extra_args(args, train_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)
if args.fp8_base_unet:
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for SD3
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
)
args.cache_text_encoder_outputs = True
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder outputs, caption_dropout_rate, shuffle_caption, token_warmup_step and caption_tag_dropout_rate cannot be used"
assert args.apply_lg_attn_mask == args.apply_t5_attn_mask, (
"apply_lg_attn_mask and apply_t5_attn_mask must be the same when caching text encoder outputs"
)
# prepare CLIP-L/CLIP-G/T5XXL training flags
self.train_clip = not args.network_train_unet_only
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
if args.max_token_length is not None:
logger.warning("max_token_length is not used in SD3 training")
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
# enumerate resolutions from dataset for positional embeddings
resolutions = train_dataset_group.get_resolutions()
if val_dataset_group is not None:
resolutions = resolutions + val_dataset_group.get_resolutions()
self.resolutions = resolutions
def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
loading_dtype = None if args.fp8_base else weight_dtype
# if we load to cpu, mmdit.to(fp8) takes a long time, so we should load to gpu in future
state_dict = utils.load_safetensors(
args.pretrained_model_name_or_path, "cpu", disable_mmap=args.disable_mmap_load_safetensors, dtype=loading_dtype
)
mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu")
self.model_type = mmdit.model_type
mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)
# set resolutions for positional embeddings
if args.enable_scaled_pos_embed:
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent
latent_sizes = list(set(latent_sizes)) # remove duplicates
logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}")
mmdit.enable_scaled_pos_embed(True, latent_sizes)
if args.fp8_base:
# check dtype of model
if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {mmdit.dtype}")
elif mmdit.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 SD3 model")
else:
logger.info(
"Cast SD3 model to fp8. This may take a while. You can reduce the time by using an fp8 checkpoint."
)
mmdit.to(torch.float8_e4m3fn)
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if self.is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
mmdit.enable_block_swap(args.blocks_to_swap, accelerator.device)
clip_l = sd3_utils.load_clip_l(
args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
)
clip_l.eval()
clip_g = sd3_utils.load_clip_g(
args.clip_g, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
)
clip_g.eval()
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
if args.fp8_base and not args.fp8_base_unet:
loading_dtype = None # as is
else:
loading_dtype = weight_dtype
# loading t5xxl to cpu takes a long time, so we should load to gpu in future
t5xxl = sd3_utils.load_t5xxl(
args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
)
t5xxl.eval()
if args.fp8_base and not args.fp8_base_unet:
# check dtype of model
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
elif t5xxl.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 T5XXL model")
vae = sd3_utils.load_vae(
args.vae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
)
return mmdit.model_type, [clip_l, clip_g, t5xxl], vae, mmdit
def get_tokenize_strategy(self, args):
logger.info(f"t5xxl_max_token_length: {args.t5xxl_max_token_length}")
return strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length, args.tokenizer_cache_dir)
def get_tokenizers(self, tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy):
return [tokenize_strategy.clip_l, tokenize_strategy.clip_g, tokenize_strategy.t5xxl]
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return latents_caching_strategy
def get_text_encoding_strategy(self, args):
return strategy_sd3.Sd3TextEncodingStrategy(
args.apply_lg_attn_mask,
args.apply_t5_attn_mask,
args.clip_l_dropout_rate,
args.clip_g_dropout_rate,
args.t5_dropout_rate,
)
def post_process_network(self, args, accelerator, network, text_encoders, unet):
# check t5xxl is trained or not
self.train_t5xxl = network.train_t5xxl
if self.train_t5xxl and args.cache_text_encoder_outputs:
raise ValueError(
"T5XXL is trained, so cache_text_encoder_outputs cannot be used"
)
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
if args.cache_text_encoder_outputs:
if self.train_clip and not self.train_t5xxl:
return text_encoders[0:2] + [None] # only CLIP-L/CLIP-G are needed for encoding because T5XXL is cached
else:
return None # no text encoders are needed for encoding because both are cached
else:
return text_encoders # CLIP-L, CLIP-G and T5XXL are needed for encoding
def get_text_encoders_train_flags(self, args, text_encoders):
return [self.train_clip, self.train_clip, self.train_t5xxl]
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
# if the text encoders are trained, we need tokenization, so is_partial is True
return strategy_sd3.Sd3TextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=self.train_clip or self.train_t5xxl,
max_token_length=args.t5xxl_max_token_length,
apply_lg_attn_mask=args.apply_lg_attn_mask,
)
else:
return None
def cache_text_encoder_outputs_if_needed(
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
):
if args.cache_text_encoder_outputs:
if not args.lowram:
# reduce memory consumption
logger.info("move vae and unet to cpu to save memory")
org_vae_device = vae.device
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
clean_memory_on_device(accelerator.device)
# When the TE is not trained, it will not be prepared, so we need to use explicit autocast
logger.info("move text encoders to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
text_encoders[1].to(accelerator.device, dtype=weight_dtype) # always not fp8
text_encoders[2].to(accelerator.device) # may be fp8
if text_encoders[2].dtype == torch.float8_e4m3fn:
# if we load fp8 weights, the model is already fp8, so we use it as is
self.prepare_text_encoder_fp8(2, text_encoders[2], text_encoders[2].dtype, weight_dtype)
else:
# otherwise, we need to convert it to target dtype
text_encoders[2].to(weight_dtype)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
# cache sample prompts
if args.sample_prompts is not None:
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
if p not in sample_prompts_te_outputs:
logger.info(f"cache Text Encoder outputs for prompt: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy,
text_encoders,
tokens_and_masks,
args.apply_lg_attn_mask,
args.apply_t5_attn_mask,
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
accelerator.wait_for_everyone()
# move back to cpu
if not self.is_train_text_encoder(args):
logger.info("move CLIP-L back to cpu")
text_encoders[0].to("cpu")
logger.info("move CLIP-G back to cpu")
text_encoders[1].to("cpu")
logger.info("move t5XXL back to cpu")
text_encoders[2].to("cpu")
clean_memory_on_device(accelerator.device)
if not args.lowram:
logger.info("move vae and unet back to original device")
vae.to(org_vae_device)
unet.to(org_unet_device)
else:
# outputs are fetched from the Text Encoders at every step, so keep them on the GPU
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
text_encoders[2].to(accelerator.device)
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
# # get size embeddings
# orig_size = batch["original_sizes_hw"]
# crop_size = batch["crop_top_lefts"]
# target_size = batch["target_sizes_hw"]
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
# # concat embeddings
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
# return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, mmdit):
text_encoders = text_encoder # for compatibility
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
sd3_train_utils.sample_images(
accelerator, args, epoch, global_step, mmdit, vae, text_encoders, self.sample_prompts_te_outputs
)
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
# this scheduler is not used in training, but used to get num_train_timesteps etc.
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift)
return noise_scheduler
def encode_images_to_latents(self, args, accelerator, vae, images):
return vae.encode(images)
def shift_scale_latents(self, args, latents):
return sd3_models.SDVAE.process_in(latents)
def get_noise_pred_and_target(
self,
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet: sd3_models.MMDiT,
network,
weight_dtype,
train_unet,
is_train=True
):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps(
args, latents, noise, accelerator.device, weight_dtype
)
# ensure the hidden state will require grad
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)
for t in text_encoder_conds:
if t is not None and t.dtype.is_floating_point:
t.requires_grad_(True)
# Predict the noise residual
lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled)
if not args.apply_lg_attn_mask:
l_attn_mask = None
g_attn_mask = None
if not args.apply_t5_attn_mask:
t5_attn_mask = None
# call model
with torch.set_grad_enabled(is_train), accelerator.autocast():
# TODO support attention mask
model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
# Preconditioning of the model outputs.
model_pred = model_pred * (-sigmas) + noisy_model_input
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
# flow matching loss
target = latents
# differential output preservation
if "custom_attributes" in batch:
diff_output_pr_indices = []
for i, custom_attributes in enumerate(batch["custom_attributes"]):
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
diff_output_pr_indices.append(i)
if len(diff_output_pr_indices) > 0:
network.set_multiplier(0.0)
with torch.no_grad(), accelerator.autocast():
model_pred_prior = unet(
noisy_model_input[diff_output_pr_indices],
timesteps[diff_output_pr_indices],
context=context[diff_output_pr_indices],
y=lg_pooled[diff_output_pr_indices],
)
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
model_pred_prior = model_pred_prior * (-sigmas[diff_output_pr_indices]) + noisy_model_input[diff_output_pr_indices]
# weighting for differential output preservation is not needed because it is already applied
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
def get_sai_model_spec(self, args):
return train_util.get_sai_model_spec(None, args, False, True, False, sd3=self.model_type)
def update_metadata(self, metadata, args):
metadata["ss_apply_lg_attn_mask"] = args.apply_lg_attn_mask
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
metadata["ss_mode_scale"] = args.mode_scale
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
if index == 0 or index == 1: # CLIP-L/CLIP-G
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
else: # T5XXL
text_encoder.encoder.embed_tokens.requires_grad_(True)
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
if index == 0 or index == 1: # CLIP-L/CLIP-G
clip_type = "CLIP-L" if index == 0 else "CLIP-G"
logger.info(f"prepare {clip_type} for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
text_encoder.to(te_weight_dtype) # fp8
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
else: # T5XXL
def prepare_fp8(text_encoder, target_dtype):
def forward_hook(module):
def forward(hidden_states):
hidden_gelu = module.act(module.wi_0(hidden_states))
hidden_linear = module.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = module.dropout(hidden_states)
hidden_states = module.wo(hidden_states)
return hidden_states
return forward
for module in text_encoder.modules():
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
# print("set", module.__class__.__name__, "to", target_dtype)
module.to(target_dtype)
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
# print("set", module.__class__.__name__, "hooks")
module.forward = forward_hook(module)
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
logger.info("T5XXL already prepared for fp8")
else:
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
# drop cached text encoder outputs
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
if not self.is_swapping_blocks:
return super().prepare_unet_with_accelerator(args, accelerator, unet)
# if we don't swap blocks, we can move the whole model to the device; when swapping, the blocks are placed manually below
mmdit: sd3_models.MMDiT = unet
mmdit = accelerator.prepare(mmdit, device_placement=[not self.is_swapping_blocks])
accelerator.unwrap_model(mmdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
accelerator.unwrap_model(mmdit).prepare_block_swap_before_forward()
return mmdit
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
sd3_train_utils.add_sd3_training_arguments(parser)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = Sd3NetworkTrainer()
trainer.train(args)
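The preconditioning step in `get_noise_pred_and_target` (`model_pred * (-sigmas) + noisy_model_input`, compared against `target = latents`) can be illustrated with a small pure-Python sketch. It assumes the noisy input is the linear interpolation `(1 - sigma) * latents + sigma * noise`, which is not shown in this file, and it uses a hypothetical perfect model that predicts `noise - latents`. The helper names below are illustrative only and are not part of `sd3_train_utils`.

```python
import random


def make_noisy_input(latent, noise, sigma):
    # Assumed interpolation: sigma = 0 gives the clean latent, sigma = 1 gives pure noise.
    return [(1.0 - sigma) * x + sigma * n for x, n in zip(latent, noise)]


def precondition(model_pred, noisy_input, sigma):
    # Mirrors `model_pred * (-sigmas) + noisy_model_input` from the trainer.
    return [-sigma * v + z for v, z in zip(model_pred, noisy_input)]


rng = random.Random(0)
latent = [rng.gauss(0.0, 1.0) for _ in range(4)]
noise = [rng.gauss(0.0, 1.0) for _ in range(4)]
sigma = 0.7

noisy = make_noisy_input(latent, noise, sigma)
# Hypothetical perfect model: it predicts the velocity (noise - latent).
ideal_pred = [n - x for x, n in zip(latent, noise)]
recovered = precondition(ideal_pred, noisy, sigma)

# With a perfect prediction the preconditioned output equals the clean latent,
# which is why the loss target can simply be `latents`.
max_error = max(abs(r - x) for r, x in zip(recovered, latent))
print(f"max reconstruction error: {max_error:.2e}")
```

Under these assumptions the reconstruction error is at floating-point rounding level, so any remaining loss comes from the model's prediction error rather than from the preconditioning itself.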

View File

@@ -17,7 +17,7 @@ init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl
from library import deepspeed_utils, sdxl_model_util
import library.train_util as train_util
@@ -104,8 +104,8 @@ def train(args):
setup_logging(args, reset=True)
assert (
not args.weighted_captions or not args.cache_text_encoder_outputs
), "weighted_captions is not supported when caching text encoder outputs / cache_text_encoder_outputsを使うときはweighted_captionsはサポートされていません"
not args.weighted_captions
), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
assert (
not args.train_text_encoder or not args.cache_text_encoder_outputs
), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
@@ -124,16 +124,7 @@ def train(args):
if args.seed is not None:
set_seed(args.seed) # initialize the random number sequence
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
tokenizers = [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] # will be removed in the future
# prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy for initialization.
if args.cache_latents:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
# prepare the dataset
if args.dataset_class is None:
@@ -175,11 +166,10 @@ def train(args):
]
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2])
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -272,9 +262,8 @@ def train(args):
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -287,9 +276,6 @@ def train(args):
train_text_encoder1 = False
train_text_encoder2 = False
text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
if args.train_text_encoder:
# TODO each option for two text encoders?
accelerator.print("enable text encoder training")
@@ -321,21 +307,16 @@ def train(args):
# cache Text Encoder outputs
if args.cache_text_encoder_outputs:
# Text Encoders are in eval mode with no grad
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
None,
args.skip_cache_check,
args.max_token_length,
is_weighted=args.weighted_captions,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
accelerator.wait_for_everyone()
with torch.no_grad(), accelerator.autocast():
train_dataset_group.cache_text_encoder_outputs(
(tokenizer1, tokenizer2),
(text_encoder1, text_encoder2),
accelerator.device,
None,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
accelerator.wait_for_everyone()
if not cache_latents:
vae.requires_grad_(False)
@@ -422,11 +403,7 @@ def train(args):
else:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# prepare the dataloader
# note: persistent_workers cannot be used when the number of DataLoader workers is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
@@ -620,11 +597,8 @@ def train(args):
# For --sample_at_first
sdxl_train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, [text_encoder1, text_encoder2], unet
accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -654,39 +628,57 @@ def train(args):
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
# Text Encoder outputs are cached
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list
encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype)
encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype)
pool2 = pool2.to(accelerator.device, dtype=weight_dtype)
else:
input_ids1, input_ids2 = batch["input_ids_list"]
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
if args.weighted_captions:
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoder_hidden_states1, encoder_hidden_states2, pool2 = (
text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
[text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)],
input_ids_list,
weights_list,
)
)
else:
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
tokenize_strategy,
[text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)],
[input_ids1, input_ids2],
)
if args.full_fp16:
encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype)
pool2 = pool2.to(weight_dtype)
# TODO support weighted captions
# if args.weighted_captions:
# encoder_hidden_states = get_weighted_text_embeddings(
# tokenizer,
# text_encoder,
# batch["captions"],
# accelerator.device,
# args.max_token_length // 75 if args.max_token_length else 1,
# clip_skip=args.clip_skip,
# )
# else:
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
# unwrap_model is fine for models not wrapped by accelerator
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
input_ids2,
tokenizer1,
tokenizer2,
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
# # verify that the text encoder outputs are correct
# ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
# args.max_token_length,
# batch["input_ids"].to(text_encoder1.device),
# batch["input_ids2"].to(text_encoder1.device),
# tokenizer1,
# tokenizer2,
# text_encoder1,
# text_encoder2,
# None if not args.full_fp16 else weight_dtype,
# )
# b_size = encoder_hidden_states1.shape[0]
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# logger.info("text encoder outputs verified")
# get size embeddings
orig_size = batch["original_sizes_hw"]
@@ -700,7 +692,9 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -714,7 +708,6 @@ def train(args):
else:
target = noise
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
if (
args.min_snr_gamma
or args.scale_v_pred_loss_like_noise_pred
@@ -723,7 +716,9 @@ def train(args):
or args.masked_loss
):
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
@@ -739,7 +734,9 @@ def train(args):
loss = loss.mean() # mean over batch dimension
else:
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
)
accelerator.backward(loss)
@@ -772,7 +769,7 @@ def train(args):
global_step,
accelerator.device,
vae,
tokenizers,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
)
@@ -802,7 +799,7 @@ def train(args):
)
current_loss = loss.detach().item() # this is a mean, so the batch size should not matter
if len(accelerator.trackers) > 0:
if args.logging_dir is not None:
logs = {"loss": current_loss}
if block_lrs is None:
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
@@ -819,7 +816,7 @@ def train(args):
if global_step >= args.max_train_steps:
break
if len(accelerator.trackers) > 0:
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
@@ -854,7 +851,7 @@ def train(args):
global_step,
accelerator.device,
vae,
tokenizers,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
)
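The two `conditional_loss` branches in the training loop above differ only in their reduction: `"none"` keeps one loss value per element so that per-sample weights can be applied before averaging, while `"mean"` collapses everything at once. A minimal sketch of the per-sample path follows. It uses plain mean squared error in place of `train_util.conditional_loss` and flat lists in place of `[C, H, W]` tensors; both are assumptions made for illustration, and the function names are hypothetical.

```python
def elementwise_l2(pred, target):
    # Equivalent of reduction="none": one squared error per element.
    return [[(p - t) ** 2 for p, t in zip(ps, ts)] for ps, ts in zip(pred, target)]


def weighted_batch_loss(pred, target, sample_weights):
    per_element = elementwise_l2(pred, target)
    # Average inside each sample first (the role of `loss.mean([1, 2, 3])`).
    per_sample = [sum(row) / len(row) for row in per_element]
    # Apply per-sample weights, e.g. SNR-based ones, before the batch mean.
    weighted = [value * w for value, w in zip(per_sample, sample_weights)]
    return sum(weighted) / len(weighted)


pred = [[1.0, 2.0], [0.0, 0.0]]
target = [[0.0, 0.0], [0.0, 2.0]]
loss = weighted_batch_loss(pred, target, [1.0, 0.5])
print(loss)
```

With a `"mean"` reduction the batch dimension is gone before any weight can be applied, which is why the weighted branch defers the batch mean to the very end.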

View File

@@ -1,723 +0,0 @@
import argparse
import math
import os
import random
from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from accelerate import init_empty_weights
from diffusers import DDPMScheduler
from diffusers.utils.torch_utils import is_compiled_module
from safetensors.torch import load_file
from library import (
deepspeed_utils,
sai_model_spec,
sdxl_model_util,
sdxl_train_util,
strategy_base,
strategy_sd,
strategy_sdxl,
)
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
add_v_prediction_like_loss,
apply_snr_weight,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
from library.sdxl_original_control_net import SdxlControlNet, SdxlControlledUNet
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
# TODO share this with the other scripts
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
logs = {
"loss/current": current_loss,
"loss/average": avr_loss,
"lr": lr_scheduler.get_last_lr()[0],
}
if args.optimizer_type.lower().startswith("DAdapt".lower()):
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
return logs
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
use_user_config = args.dataset_config is not None
if args.seed is None:
args.seed = random.randint(0, 2**32 - 1) # randint is inclusive; NumPy seeds must be below 2**32
set_seed(args.seed)
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
tokenizer1, tokenizer2 = tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2 # this is used for sampling images
# prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy for initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# prepare the dataset
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
user_config = {
"datasets": [
{
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
args.train_data_dir,
args.conditioning_data_dir,
args.caption_extension,
)
}
]
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(32)
if args.debug_dataset:
train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
else:
logger.warning(
"WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません"
)
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# prepare the accelerator
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# prepare dtypes for mixed precision and cast as needed
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# load the models
(
load_stable_diffusion_format,
text_encoder1,
text_encoder2,
vae,
unet,
logit_scale,
ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
unet.to(accelerator.device) # reduce main memory usage
# convert U-Net to Controlled U-Net
logger.info("convert U-Net to Controlled U-Net")
unet_sd = unet.state_dict()
with init_empty_weights():
unet = SdxlControlledUNet()
unet.load_state_dict(unet_sd, strict=True, assign=True)
del unet_sd
# make control net
logger.info("make ControlNet")
if args.controlnet_model_name_or_path:
with init_empty_weights():
control_net = SdxlControlNet()
logger.info(f"load ControlNet from {args.controlnet_model_name_or_path}")
filename = args.controlnet_model_name_or_path
if os.path.splitext(filename)[1] == ".safetensors":
state_dict = load_file(filename)
else:
state_dict = torch.load(filename)
info = control_net.load_state_dict(state_dict, strict=True, assign=True)
logger.info(f"ControlNet loaded from {filename}: {info}")
else:
control_net = SdxlControlNet()
logger.info("initialize ControlNet from U-Net")
info = control_net.init_from_unet(unet)
logger.info(f"ControlNet initialized from U-Net: {info}")
# prepare for training
if cache_latents:
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# cache Text Encoder outputs
if args.cache_text_encoder_outputs:
# Text Encoders are in eval mode with no grad
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
None,
args.skip_cache_check,
args.max_token_length,
is_weighted=args.weighted_captions,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
accelerator.wait_for_everyone()
# enable xformers or memory efficient attention in the models
# train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
if args.xformers:
unet.set_use_memory_efficient_attention(True, False)
control_net.set_use_memory_efficient_attention(True, False)
elif args.sdpa:
unet.set_use_sdpa(True)
control_net.set_use_sdpa(True)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
control_net.enable_gradient_checkpointing()
# prepare the classes required for training
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = []
ctrlnet_params = []
unet_params = []
for name, param in control_net.named_parameters():
if name.startswith("controlnet_"):
ctrlnet_params.append(param)
else:
unet_params.append(param)
trainable_params.append({"params": ctrlnet_params, "lr": args.control_net_lr})
trainable_params.append({"params": unet_params, "lr": args.learning_rate})
all_params = ctrlnet_params + unet_params
logger.info(f"trainable params count: {len(all_params)}")
logger.info(f"number of trainable parameters: {sum(p.numel() for p in all_params)}")
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# note: persistent_workers cannot be used when the number of DataLoader workers is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# calculate the number of training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# also pass the number of training steps to the dataset
train_dataset_group.set_max_train_steps(args.max_train_steps)
# prepare the lr scheduler
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# experimental feature: train in fp16/bf16 including gradients; cast the whole model to fp16/bf16
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
control_net.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
control_net.to(weight_dtype)
# accelerator apparently takes care of the rest
control_net, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
control_net, optimizer, train_dataloader, lr_scheduler
)
if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group in optimizer.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
def __grad_hook(tensor: torch.Tensor, param_group=param_group):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None
parameter.register_post_accumulate_grad_hook(__grad_hook)
unet.requires_grad_(False)
text_encoder1.requires_grad_(False)
text_encoder2.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
unet.eval()
control_net.train()
# move the Text Encoders to CPU when their outputs are cached
if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=vae_dtype)
# experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resume training if specified
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# calculate the number of epochs
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# start training
# TODO: find a way to handle total batch size when there are multiple datasets
accelerator.print("running training / 学習開始")
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
("sdxl_control_net_train" if args.log_tracker_name is None else args.log_tracker_name),
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
loss_recorder = train_util.LossRecorder()
del train_dataset_group
# function for saving/removing
def save_model(ckpt_name, model, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet"
state_dict = model.state_dict()
if save_dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
if os.path.splitext(ckpt_file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, ckpt_file, sai_metadata)
else:
torch.save(state_dict, ckpt_file)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
def remove_model(old_ckpt_name):
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# For --sample_at_first
sdxl_train_util.sample_images(
accelerator,
args,
0,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
unet,
controlnet=control_net,
)
# training loop
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
control_net.train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(control_net):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# convert to latents
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
# if NaNs are found, show a warning and replace them with zeros
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
# Text Encoder outputs are cached
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list
encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype)
encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype)
pool2 = pool2.to(accelerator.device, dtype=weight_dtype)
else:
input_ids1, input_ids2 = batch["input_ids_list"]
with torch.no_grad():
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], [input_ids1, input_ids2]
)
if args.full_fp16:
encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype)
pool2 = pool2.to(weight_dtype)
# get size embeddings
orig_size = batch["original_sizes_hw"]
crop_size = batch["crop_top_lefts"]
target_size = batch["target_sizes_hw"]
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
# concat embeddings
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
# '-1 to +1' to '0 to 1'
controlnet_image = (controlnet_image + 1) / 2
with accelerator.autocast():
input_resi_add, mid_add = control_net(
noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image
)
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, input_resi_add, mid_add)
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # weight for each sample
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = loss.mean() # this is a mean, so there is no need to divide by batch_size
accelerator.backward(loss)
if not args.fused_backward_pass:
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = control_net.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
sdxl_train_util.sample_images(
accelerator,
args,
None,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
unet,
controlnet=control_net,
)
# save the model every specified number of steps
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, unwrap_model(control_net))
if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
remove_step_no = train_util.get_remove_step_no(args, global_step)
if remove_step_no is not None:
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if len(accelerator.trackers) > 0:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
# save the model every specified number of epochs
if args.save_every_n_epochs is not None:
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, unwrap_model(control_net))
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
remove_model(remove_ckpt_name)
if args.save_state:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
sdxl_train_util.sample_images(
accelerator,
args,
epoch + 1,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
unet,
controlnet=control_net,
)
# end of epoch
if is_main_process:
control_net = unwrap_model(control_net)
accelerator.end_training()
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, control_net, force_sync_upload=True)
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
# train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
# train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument(
"--controlnet_model_name_or_path",
type=str,
default=None,
help="controlnet model name or path / controlnetのモデル名またはパス",
)
parser.add_argument(
"--conditioning_data_dir",
type=str,
default=None,
help="conditioning data directory / 条件付けデータのディレクトリ",
)
parser.add_argument(
"--save_model_as",
type=str,
default="safetensors",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
parser.add_argument(
"--control_net_lr",
type=float,
default=1e-4,
help="learning rate for controlnet modules / controlnetモジュールの学習率",
)
return parser
if __name__ == "__main__":
# sdxl_original_unet.USE_REENTRANT = False
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -23,16 +23,7 @@ from accelerate.utils import set_seed
import accelerate
from diffusers import DDPMScheduler, ControlNetModel
from safetensors.torch import load_file
from library import (
deepspeed_utils,
sai_model_spec,
sdxl_model_util,
sdxl_original_unet,
sdxl_train_util,
strategy_base,
strategy_sd,
strategy_sdxl,
)
from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
import library.model_util as model_util
import library.train_util as train_util
@@ -88,14 +79,7 @@ def train(args):
args.seed = random.randint(0, 2**32)
set_seed(args.seed)
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
# prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy for initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
# prepare the dataset
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
@@ -122,8 +106,8 @@ def train(args):
]
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -180,34 +164,30 @@ def train(args):
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
with torch.no_grad():
train_dataset_group.cache_latents(
vae,
args.vae_batch_size,
args.cache_latents_to_disk,
accelerator.is_main_process,
)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# cache the Text Encoder outputs
if args.cache_text_encoder_outputs:
# Text Encoders are in eval mode with no grad
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
None,
args.skip_cache_check,
args.max_token_length,
is_weighted=args.weighted_captions,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
with torch.no_grad():
train_dataset_group.cache_text_encoder_outputs(
(tokenizer1, tokenizer2),
(text_encoder1, text_encoder2),
accelerator.device,
None,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
accelerator.wait_for_everyone()
# prepare ControlNet-LLLite
@@ -262,11 +242,7 @@ def train(args):
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# prepare the dataloader
# number of DataLoader processes: note that persistent_workers cannot be used when this is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
@@ -314,7 +290,7 @@ def train(args):
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if isinstance(unet, DDP):
unet._set_static_graph() # avoid error for multiple use of the parameter
unet._set_static_graph() # avoid error for multiple use of the parameter
if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> this is now the original U-Net, so it could actually be removed
@@ -381,9 +357,7 @@ def train(args):
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
)
loss_recorder = train_util.LossRecorder()
@@ -435,25 +409,27 @@ def train(args):
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
# Text Encoder outputs are cached
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list
encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype)
encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype)
pool2 = pool2.to(accelerator.device, dtype=weight_dtype)
else:
input_ids1, input_ids2 = batch["input_ids_list"]
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
with torch.no_grad():
# Get the text embedding for conditioning
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2]
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
input_ids2,
tokenizer1,
tokenizer2,
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
)
if args.full_fp16:
encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype)
pool2 = pool2.to(weight_dtype)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
# get size embeddings
orig_size = batch["original_sizes_hw"]
@@ -467,7 +443,9 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -486,8 +464,9 @@ def train(args):
else:
target = noise
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # weight for each sample
@@ -541,14 +520,14 @@ def train(args):
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if len(accelerator.trackers) > 0:
if args.logging_dir is not None:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if len(accelerator.trackers) > 0:
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

View File

@@ -12,7 +12,6 @@ from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -103,7 +102,7 @@ def train(args):
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -325,9 +324,7 @@ def train(args):
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
)
loss_recorder = train_util.LossRecorder()
@@ -409,7 +406,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -429,8 +426,7 @@ def train(args):
else:
target = noise
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # weight for each sample
@@ -484,14 +480,14 @@ def train(args):
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if len(accelerator.trackers) > 0:
if args.logging_dir is not None:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if len(accelerator.trackers) > 0:
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

View File

@@ -1,30 +1,24 @@
import argparse
from typing import List, Optional, Union
import torch
from accelerate import Accelerator
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import sdxl_model_util, sdxl_train_util, strategy_base, strategy_sd, strategy_sdxl, train_util
from library import sdxl_model_util, sdxl_train_util, train_util
import train_network
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class SdxlNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args)
if args.cache_text_encoder_outputs:
@@ -37,8 +31,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
train_dataset_group.verify_bucket_reso_steps(32)
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32)
def load_target_model(self, args, weight_dtype, accelerator):
(
@@ -55,45 +47,17 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
self.logit_scale = logit_scale
self.ckpt_info = ckpt_info
# incorporate xformers or memory efficient attention into the model
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
if torch.__version__ >= "2.0.0": # the following is available with an xformers build that supports PyTorch 2.0.0 or later
vae.set_use_memory_efficient_attention_xformers(args.xformers)
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
def get_tokenize_strategy(self, args):
return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
def load_tokenizer(self, args):
tokenizer = sdxl_train_util.load_tokenizers(args)
return tokenizer
def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy):
return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2]
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return latents_caching_strategy
def get_text_encoding_strategy(self, args):
return strategy_sdxl.SdxlTextEncodingStrategy()
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
return text_encoders + [accelerator.unwrap_model(text_encoders[-1])]
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
None,
args.skip_cache_check,
args.max_token_length,
is_weighted=args.weighted_captions,
)
else:
return None
def is_text_encoder_outputs_cached(self, args):
return args.cache_text_encoder_outputs
def cache_text_encoder_outputs_if_needed(
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
):
if args.cache_text_encoder_outputs:
if not args.lowram:
@@ -106,11 +70,15 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
clean_memory_on_device(accelerator.device)
# When the TE is not being trained, it will not be prepared, so we need to use explicit autocast
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator)
accelerator.wait_for_everyone()
dataset.cache_text_encoder_outputs(
tokenizers,
text_encoders,
accelerator.device,
weight_dtype,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
text_encoders[1].to("cpu", dtype=torch.float32)
@@ -179,18 +147,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
return encoder_hidden_states1, encoder_hidden_states2, pool2
def call_unet(
self,
args,
accelerator,
unet,
noisy_latents,
timesteps,
text_conds,
batch,
weight_dtype,
indices: Optional[List[int]] = None,
):
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
# get size embeddings
@@ -204,12 +161,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
if indices is not None and len(indices) > 0:
noisy_latents = noisy_latents[indices]
timesteps = timesteps[indices]
text_embedding = text_embedding[indices]
vector_embedding = vector_embedding[indices]
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
return noise_pred

View File

@@ -1,15 +1,14 @@
import argparse
import os
from typing import Optional, Union
import regex
import torch
from library.device_utils import init_ipex
init_ipex()
from library import sdxl_model_util, sdxl_train_util, strategy_sd, strategy_sdxl, train_util
from library import sdxl_model_util, sdxl_train_util, train_util
import train_textual_inversion
@@ -19,13 +18,11 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
train_dataset_group.verify_bucket_reso_steps(32)
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32)
def load_target_model(self, args, weight_dtype, accelerator):
(
@@ -44,20 +41,28 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
def get_tokenize_strategy(self, args):
return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
def load_tokenizer(self, args):
tokenizer = sdxl_train_util.load_tokenizers(args)
return tokenizer
def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy):
return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2]
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return latents_caching_strategy
def get_text_encoding_strategy(self, args):
return strategy_sdxl.SdxlTextEncodingStrategy()
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
with torch.enable_grad():
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
input_ids2,
tokenizers[0],
tokenizers[1],
text_encoders[0],
text_encoders[1],
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
return encoder_hidden_states1, encoder_hidden_states2, pool2
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -76,11 +81,9 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
return noise_pred
def sample_images(
self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
):
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
sdxl_train_util.sample_images(
accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
)
def save_weights(self, file, updated_embs, save_dtype, metadata):
@@ -119,7 +122,8 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
def setup_parser() -> argparse.ArgumentParser:
parser = train_textual_inversion.setup_parser()
sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False)
# don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching
# sdxl_train_util.add_sdxl_training_arguments(parser)
return parser

View File

@@ -1,41 +0,0 @@
# Tests
## Install
```
pip install pytest
```
## Usage
```
pytest
```
## Contribution
Pytest is configured to run tests in this directory. It might be a good idea to add tests closer to the code, as well as doctests.
Tests are functions starting with `test_` and files with the pattern `test_*.py`.
```
def test_x():
    assert 1 == 2, "Invalid test response"
```
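For comparison, a passing test might look like the sketch below. The `add` helper and the file name `test_example.py` are hypothetical and used only for illustration; they are not part of this repository.

```python
# test_example.py (hypothetical file name matching the test_*.py pattern)


def add(a, b):
    # Hypothetical helper, defined here only so the example is self-contained
    return a + b


def test_add():
    # pytest collects this function because its name starts with "test_"
    assert add(1, 2) == 3, "add should return the sum of its arguments"
```

Running `pytest` in this directory should collect `test_add` and report it as passed.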
## Resources
### pytest
- https://docs.pytest.org/en/stable/index.html
- https://docs.pytest.org/en/stable/how-to/assert.html
- https://docs.pytest.org/en/stable/how-to/doctest.html
### PyTorch testing
- https://circleci.com/blog/testing-pytorch-model-with-pytest/
- https://pytorch.org/docs/stable/testing.html
- https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
- https://github.com/huggingface/pytorch-image-models/tree/main/tests
- https://github.com/pytorch/pytorch/tree/main/test


@@ -1,153 +0,0 @@
from unittest.mock import patch
from library.train_util import get_optimizer
from train_network import setup_parser
import torch
from torch.nn import Parameter
# Optimizer libraries
import bitsandbytes as bnb
from lion_pytorch import lion_pytorch
import schedulefree
import dadaptation
import dadaptation.experimental as dadapt_experimental
import prodigyopt
import schedulefree as sf
import transformers
def test_default_get_optimizer():
with patch("sys.argv", [""]):
parser = setup_parser()
args = parser.parse_args()
params_t = torch.tensor([1.5, 1.5])
param = Parameter(params_t)
optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param])
assert optimizer_name == "torch.optim.adamw.AdamW"
assert optimizer_args == ""
assert isinstance(optimizer, torch.optim.AdamW)
def test_get_schedulefree_optimizer():
with patch("sys.argv", ["", "--optimizer_type", "AdamWScheduleFree"]):
parser = setup_parser()
args = parser.parse_args()
params_t = torch.tensor([1.5, 1.5])
param = Parameter(params_t)
optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param])
assert optimizer_name == "schedulefree.adamw_schedulefree.AdamWScheduleFree"
assert optimizer_args == ""
assert isinstance(optimizer, schedulefree.adamw_schedulefree.AdamWScheduleFree)
def test_all_supported_optimizers():
optimizers = [
{
"name": "bitsandbytes.optim.adamw.AdamW8bit",
"alias": "AdamW8bit",
"instance": bnb.optim.AdamW8bit,
},
{
"name": "lion_pytorch.lion_pytorch.Lion",
"alias": "Lion",
"instance": lion_pytorch.Lion,
},
{
"name": "torch.optim.adamw.AdamW",
"alias": "AdamW",
"instance": torch.optim.AdamW,
},
{
"name": "bitsandbytes.optim.lion.Lion8bit",
"alias": "Lion8bit",
"instance": bnb.optim.Lion8bit,
},
{
"name": "bitsandbytes.optim.adamw.PagedAdamW8bit",
"alias": "PagedAdamW8bit",
"instance": bnb.optim.PagedAdamW8bit,
},
{
"name": "bitsandbytes.optim.lion.PagedLion8bit",
"alias": "PagedLion8bit",
"instance": bnb.optim.PagedLion8bit,
},
{
"name": "bitsandbytes.optim.adamw.PagedAdamW",
"alias": "PagedAdamW",
"instance": bnb.optim.PagedAdamW,
},
{
"name": "bitsandbytes.optim.adamw.PagedAdamW32bit",
"alias": "PagedAdamW32bit",
"instance": bnb.optim.PagedAdamW32bit,
},
{"name": "torch.optim.sgd.SGD", "alias": "SGD", "instance": torch.optim.SGD},
{
"name": "dadaptation.experimental.dadapt_adam_preprint.DAdaptAdamPreprint",
"alias": "DAdaptAdamPreprint",
"instance": dadapt_experimental.DAdaptAdamPreprint,
},
{
"name": "dadaptation.dadapt_adagrad.DAdaptAdaGrad",
"alias": "DAdaptAdaGrad",
"instance": dadaptation.DAdaptAdaGrad,
},
{
"name": "dadaptation.dadapt_adan.DAdaptAdan",
"alias": "DAdaptAdan",
"instance": dadaptation.DAdaptAdan,
},
{
"name": "dadaptation.experimental.dadapt_adan_ip.DAdaptAdanIP",
"alias": "DAdaptAdanIP",
"instance": dadapt_experimental.DAdaptAdanIP,
},
{
"name": "dadaptation.dadapt_lion.DAdaptLion",
"alias": "DAdaptLion",
"instance": dadaptation.DAdaptLion,
},
{
"name": "dadaptation.dadapt_sgd.DAdaptSGD",
"alias": "DAdaptSGD",
"instance": dadaptation.DAdaptSGD,
},
{
"name": "prodigyopt.prodigy.Prodigy",
"alias": "Prodigy",
"instance": prodigyopt.Prodigy,
},
{
"name": "transformers.optimization.Adafactor",
"alias": "Adafactor",
"instance": transformers.optimization.Adafactor,
},
{
"name": "schedulefree.adamw_schedulefree.AdamWScheduleFree",
"alias": "AdamWScheduleFree",
"instance": sf.AdamWScheduleFree,
},
{
"name": "schedulefree.sgd_schedulefree.SGDScheduleFree",
"alias": "SGDScheduleFree",
"instance": sf.SGDScheduleFree,
},
]
for opt in optimizers:
with patch("sys.argv", ["", "--optimizer_type", opt.get("alias")]):
parser = setup_parser()
args = parser.parse_args()
params_t = torch.tensor([1.5, 1.5])
param = Parameter(params_t)
optimizer_name, _, optimizer = get_optimizer(args, [param])
assert optimizer_name == opt.get("name")
instance = opt.get("instance")
assert instance is not None
assert isinstance(optimizer, instance)


@@ -1,17 +0,0 @@
from library.train_util import split_train_val
def test_split_train_val():
paths = ["path1", "path2", "path3", "path4", "path5", "path6", "path7"]
sizes = [(1, 1), (2, 2), None, (4, 4), (5, 5), (6, 6), None]
result_paths, result_sizes = split_train_val(paths, sizes, True, 0.2, 1234)
assert result_paths == ["path2", "path3", "path6", "path5", "path1", "path4"], result_paths
assert result_sizes == [(2, 2), None, (6, 6), (5, 5), (1, 1), (4, 4)], result_sizes
result_paths, result_sizes = split_train_val(paths, sizes, False, 0.2, 1234)
assert result_paths == ["path7"], result_paths
assert result_sizes == [None], result_sizes
if __name__ == "__main__":
test_split_train_val()


@@ -9,7 +9,7 @@ from accelerate.utils import set_seed
import torch
from tqdm import tqdm
from library import config_util, flux_train_utils, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl
from library import config_util
from library import train_util
from library import sdxl_train_util
from library.config_util import (
@@ -17,74 +17,42 @@ from library.config_util import (
BlueprintGenerator,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argparse.Namespace) -> None:
if is_flux:
_, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
else:
is_schnell = False
if is_sd:
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
elif is_sdxl:
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
else:
if args.t5xxl_max_token_length is None:
if is_schnell:
t5xxl_max_token_length = 256
else:
t5xxl_max_token_length = 512
else:
t5xxl_max_token_length = args.t5xxl_max_token_length
logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
def cache_to_disk(args: argparse.Namespace) -> None:
setup_logging(args, reset=True)
train_util.prepare_dataset_args(args, True)
train_util.enable_high_vram(args)
# assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
args.cache_latents = True
args.cache_latents_to_disk = True
# check cache latents arg
assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed) # initialize the random number sequence
is_sd = not args.sdxl and not args.flux
is_sdxl = args.sdxl
is_flux = args.flux
set_tokenize_strategy(is_sd, is_sdxl, is_flux, args)
if is_sd or is_sdxl:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(is_sd, True, args.vae_batch_size, args.skip_cache_check)
# prepare the tokenizer: required to run the dataset
if args.sdxl:
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
tokenizers = [tokenizer1, tokenizer2]
else:
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(True, args.vae_batch_size, args.skip_cache_check)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
tokenizer = train_util.load_tokenizer(args)
tokenizers = [tokenizer]
# prepare the dataset
use_user_config = args.dataset_config is not None
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if use_user_config:
logger.info(f"Loading dataset config from {args.dataset_config}")
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
@@ -115,12 +83,17 @@ def cache_to_disk(args: argparse.Namespace) -> None:
]
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
# if the dataset's cache_latents is not called, raw images are returned
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# prepare the accelerator
logger.info("prepare accelerator")
@@ -133,27 +106,72 @@ def cache_to_disk(args: argparse.Namespace) -> None:
# load the model
logger.info("load model")
if is_sd:
_, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
elif is_sdxl:
if args.sdxl:
(_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
else:
vae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
if is_sd or is_sdxl:
if torch.__version__ >= "2.0.0": # usable with xformers builds that support PyTorch 2.0.0 or later
vae.set_use_memory_efficient_attention_xformers(args.xformers)
_, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
if torch.__version__ >= "2.0.0": # usable with xformers builds that support PyTorch 2.0.0 or later
vae.set_use_memory_efficient_attention_xformers(args.xformers)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
# cache latents with dataset
# TODO use DataLoader to speed up
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
# prepare the dataloader
train_dataset_group.set_caching_mode("latents")
# note: persistent_workers cannot be used when the number of DataLoader workers is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# prepare the model with accelerator; this should make it usable on multiple GPUs
train_dataloader = accelerator.prepare(train_dataloader)
# loop for fetching data
for batch in tqdm(train_dataloader):
b_size = len(batch["images"])
vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
flip_aug = batch["flip_aug"]
alpha_mask = batch["alpha_mask"]
random_crop = batch["random_crop"]
bucket_reso = batch["bucket_reso"]
# split the batch and process it in pieces
for i in range(0, b_size, vae_batch_size):
images = batch["images"][i : i + vae_batch_size]
absolute_paths = batch["absolute_paths"][i : i + vae_batch_size]
resized_sizes = batch["resized_sizes"][i : i + vae_batch_size]
image_infos = []
for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)):
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
image_info.image = image
image_info.bucket_reso = bucket_reso
image_info.resized_size = resized_size
image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz"
if args.skip_existing:
if train_util.is_disk_cached_latents_is_expected(
image_info.bucket_reso, image_info.latents_npz, flip_aug, alpha_mask
):
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
continue
image_infos.append(image_info)
if len(image_infos) > 0:
train_util.cache_batch_latents(vae, True, image_infos, flip_aug, alpha_mask, random_crop)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents to disk.")
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
def setup_parser() -> argparse.ArgumentParser:
@@ -163,13 +181,8 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_masked_loss_arguments(parser)
config_util.add_config_arguments(parser)
train_util.add_dit_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する")
parser.add_argument(
"--no_half_vae",
action="store_true",
@@ -178,8 +191,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--skip_existing",
action="store_true",
help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check."
" / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ",
)
return parser


@@ -9,69 +9,55 @@ from accelerate.utils import set_seed
import torch
from tqdm import tqdm
from library import (
config_util,
flux_train_utils,
flux_utils,
sdxl_model_util,
strategy_base,
strategy_flux,
strategy_sd,
strategy_sdxl,
)
from library import config_util
from library import train_util
from library import sdxl_train_util
from library import utils
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging, add_logging_arguments
from cache_latents import set_tokenize_strategy
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
setup_logging(args, reset=True)
train_util.prepare_dataset_args(args, True)
train_util.enable_high_vram(args)
args.cache_text_encoder_outputs = True
args.cache_text_encoder_outputs_to_disk = True
# check cache arg
assert (
args.cache_text_encoder_outputs_to_disk
), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります"
# prepare as much as possible, but for now only SDXL works
assert (
args.sdxl
), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です"
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed) # initialize the random number sequence
is_sd = not args.sdxl and not args.flux
is_sdxl = args.sdxl
is_flux = args.flux
assert (
is_sdxl or is_flux
), "Cache text encoder outputs to disk is only supported for SDXL and FLUX models / テキストエンコーダ出力のディスクキャッシュはSDXLまたはFLUXでのみ有効です"
assert (
is_sdxl or args.weighted_captions is None
), "Weighted captions are only supported for SDXL models / 重み付きキャプションはSDXLモデルでのみ有効です"
set_tokenize_strategy(is_sd, is_sdxl, is_flux, args)
# prepare the tokenizer: required to run the dataset
if args.sdxl:
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
tokenizers = [tokenizer1, tokenizer2]
else:
tokenizer = train_util.load_tokenizer(args)
tokenizers = [tokenizer]
# prepare the dataset
use_user_config = args.dataset_config is not None
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if use_user_config:
logger.info(f"Loading dataset config from {args.dataset_config}")
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
@@ -102,12 +88,15 @@ def cache_to_disk(args: argparse.Namespace) -> None:
]
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# prepare the accelerator
logger.info("prepare accelerator")
@@ -116,71 +105,69 @@ def cache_to_disk(args: argparse.Namespace) -> None:
# prepare the dtype for mixed precision and cast as needed
weight_dtype, _ = train_util.prepare_dtype(args)
t5xxl_dtype = utils.str_to_dtype(args.t5xxl_dtype, weight_dtype)
# load the model
logger.info("load model")
if is_sdxl:
_, text_encoder1, text_encoder2, _, _, _, _ = sdxl_train_util.load_target_model(
args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype
)
text_encoder1.to(accelerator.device, weight_dtype)
text_encoder2.to(accelerator.device, weight_dtype)
if args.sdxl:
(_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
text_encoders = [text_encoder1, text_encoder2]
else:
clip_l = flux_utils.load_clip_l(
args.clip_l, weight_dtype, accelerator.device, disable_mmap=args.disable_mmap_load_safetensors
)
t5xxl = flux_utils.load_t5xxl(args.t5xxl, None, accelerator.device, disable_mmap=args.disable_mmap_load_safetensors)
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
elif t5xxl.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 T5XXL model")
if t5xxl.dtype != t5xxl_dtype:
if t5xxl.dtype == torch.float8_e4m3fn and t5xxl_dtype.itemsize >= 2:
logger.warning(
"The loaded model is fp8, but the specified T5XXL dtype is larger than fp8. This may cause a performance drop."
" / ロードされたモデルはfp8ですが、指定されたT5XXLのdtypeがfp8より高精度です。精度低下が発生する可能性があります。"
)
logger.info(f"Casting T5XXL model to {t5xxl_dtype}")
t5xxl.to(t5xxl_dtype)
text_encoders = [clip_l, t5xxl]
text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
text_encoders = [text_encoder1]
for text_encoder in text_encoders:
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
# build text encoder outputs caching strategy
if is_sdxl:
text_encoder_outputs_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions
)
else:
text_encoder_outputs_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=False,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy)
# prepare the dataloader
train_dataset_group.set_caching_mode("text")
# build text encoding strategy
if is_sdxl:
text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
else:
text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# note: persistent_workers cannot be used when the number of DataLoader workers is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
# cache text encoder outputs
train_dataset_group.new_cache_text_encoder_outputs(text_encoders, accelerator)
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# prepare the model with accelerator; this should make it usable on multiple GPUs
train_dataloader = accelerator.prepare(train_dataloader)
# loop for fetching data
for batch in tqdm(train_dataloader):
absolute_paths = batch["absolute_paths"]
input_ids1_list = batch["input_ids1_list"]
input_ids2_list = batch["input_ids2_list"]
image_infos = []
for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list):
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
image_info
if args.skip_existing:
if os.path.exists(image_info.text_encoder_outputs_npz):
logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
continue
image_info.input_ids1 = input_ids1
image_info.input_ids2 = input_ids2
image_infos.append(image_info)
if len(image_infos) > 0:
b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos])
train_util.cache_batch_text_encoder_outputs(
image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype
)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching text encoder outputs to disk.")
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
def setup_parser() -> argparse.ArgumentParser:
@@ -190,30 +177,13 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_masked_loss_arguments(parser)
config_util.add_config_arguments(parser)
train_util.add_dit_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する")
parser.add_argument(
"--t5xxl_dtype",
type=str,
default=None,
help="T5XXL model dtype, default: None (use mixed precision dtype) / T5XXLモデルのdtype, デフォルト: None (mixed precisionのdtypeを使用)",
)
parser.add_argument(
"--skip_existing",
action="store_true",
help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check."
" / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。",
)
parser.add_argument(
"--weighted_captions",
action="store_true",
default=False,
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ",
)
return parser


@@ -1,149 +0,0 @@
# This script converts the diffusers of a Flux model to a safetensors file of a Flux.1 model.
# It is based on the implementation by 2kpr. Thanks to 2kpr!
# Major changes:
# - Iterates over three safetensors files to reduce memory usage, not loading all tensors at once.
# - Makes reverse map from diffusers map to avoid loading all tensors.
# - Removes dependency on .json file for weights mapping.
# - Adds support for custom memory efficient load and save functions.
# - Supports saving with different precision.
# - Supports .safetensors file as input.
# Copyright 2024 2kpr. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import argparse
import os
from pathlib import Path
import safetensors
from safetensors.torch import safe_open
import torch
from tqdm import tqdm
from library import flux_utils
from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file
setup_logging()
import logging
logger = logging.getLogger(__name__)
def convert(args):
# if diffusers_path is folder, get safetensors file
diffusers_path = Path(args.diffusers_path)
if diffusers_path.is_dir():
diffusers_path = Path.joinpath(diffusers_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
flux_path = Path(args.save_to)
if not os.path.exists(flux_path.parent):
os.makedirs(flux_path.parent)
if not diffusers_path.exists():
logger.error(f"Error: Missing transformer safetensors file: {diffusers_path}")
return
mem_eff_flag = args.mem_eff_load_save
save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None
# make reverse map from diffusers map
diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map()
# iterate over three safetensors files to reduce memory usage
flux_sd = {}
for i in range(3):
# replace 00001 with 0000i
current_diffusers_path = Path(str(diffusers_path).replace("00001", f"0000{i+1}"))
logger.info(f"Loading diffusers file: {current_diffusers_path}")
open_func = MemoryEfficientSafeOpen if mem_eff_flag else (lambda x: safe_open(x, framework="pt"))
with open_func(current_diffusers_path) as f:
for diffusers_key in tqdm(f.keys()):
if diffusers_key in diffusers_to_bfl_map:
tensor = f.get_tensor(diffusers_key).to("cpu")
if save_dtype is not None:
tensor = tensor.to(save_dtype)
index, bfl_key = diffusers_to_bfl_map[diffusers_key]
if bfl_key not in flux_sd:
flux_sd[bfl_key] = []
flux_sd[bfl_key].append((index, tensor))
else:
logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}")
return
# concat tensors if multiple tensors are mapped to a single key, sort by index
for key, values in flux_sd.items():
if len(values) == 1:
flux_sd[key] = values[0][1]
else:
flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])])
# special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
if "final_layer.adaLN_modulation.1.weight" in flux_sd:
flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"])
if "final_layer.adaLN_modulation.1.bias" in flux_sd:
flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"])
# save flux_sd to safetensors file
logger.info(f"Saving Flux safetensors file: {flux_path}")
if mem_eff_flag:
mem_eff_save_file(flux_sd, flux_path)
else:
safetensors.torch.save_file(flux_sd, flux_path)
logger.info("Conversion completed.")
def setup_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--diffusers_path",
default=None,
type=str,
required=True,
help="Path to the original Flux diffusers folder or *-00001-of-00003.safetensors file."
" / 元のFlux diffusersフォルダーまたは*-00001-of-00003.safetensorsファイルへのパス",
)
parser.add_argument(
"--save_to",
default=None,
type=str,
required=True,
help="Output path for the Flux safetensors file. / Flux safetensorsファイルの出力先",
)
parser.add_argument(
"--mem_eff_load_save",
action="store_true",
help="use custom memory efficient load and save functions for FLUX.1 model"
" / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する",
)
parser.add_argument(
"--save_precision",
type=str,
default=None,
help="precision in saving, default is same as loading precision"
"float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz"
" / 保存時に精度を変更して保存する、デフォルトは読み込み時と同じ精度",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
convert(args)


@@ -1,669 +0,0 @@
import argparse
import json
import math
import os
import random
import time
from multiprocessing import Value
# from omegaconf import OmegaConf
import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel
from safetensors.torch import load_file
import library.model_util as model_util
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
pyramid_noise_like,
apply_noise_offset,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
# TODO share this with the other scripts
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
logs = {
"loss/current": current_loss,
"loss/average": avr_loss,
"lr": lr_scheduler.get_last_lr()[0],
}
if args.optimizer_type.lower().startswith("DAdapt".lower()):
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
return logs
def train(args):
# session_id = random.randint(0, 2**32)
# training_started_at = time.time()
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
use_user_config = args.dataset_config is not None
if args.seed is None:
args.seed = random.randint(0, 2**32)
set_seed(args.seed)
tokenizer = train_util.load_tokenizer(args)
# prepare the dataset
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
user_config = {
"datasets": [
{
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
args.train_data_dir,
args.conditioning_data_dir,
args.caption_extension,
)
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(64)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# Prepare the accelerator
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
# Prepare the dtypes for mixed precision and cast as needed
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# Load the models
text_encoder, vae, unet, _ = train_util.load_target_model(
args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True
)
# Prepare the data used by the Diffusers ControlNet
if args.v2:
unet.config = {
"act_fn": "silu",
"attention_head_dim": [5, 10, 20, 20],
"block_out_channels": [320, 640, 1280, 1280],
"center_input_sample": False,
"cross_attention_dim": 1024,
"down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"],
"downsample_padding": 1,
"dual_cross_attention": False,
"flip_sin_to_cos": True,
"freq_shift": 0,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": [5, 10, 20, 20],
"num_class_embeds": None,
"only_cross_attention": False,
"out_channels": 4,
"sample_size": 96,
"up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
"use_linear_projection": True,
"upcast_attention": True,
"only_cross_attention": False,
"downsample_padding": 1,
"use_linear_projection": True,
"class_embed_type": None,
"num_class_embeds": None,
"resnet_time_scale_shift": "default",
"projection_class_embeddings_input_dim": None,
}
else:
unet.config = {
"act_fn": "silu",
"attention_head_dim": 8,
"block_out_channels": [320, 640, 1280, 1280],
"center_input_sample": False,
"cross_attention_dim": 768,
"down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"],
"downsample_padding": 1,
"flip_sin_to_cos": True,
"freq_shift": 0,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": 8,
"out_channels": 4,
"sample_size": 64,
"up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
"only_cross_attention": False,
"downsample_padding": 1,
"use_linear_projection": False,
"class_embed_type": None,
"num_class_embeds": None,
"upcast_attention": False,
"resnet_time_scale_shift": "default",
"projection_class_embeddings_input_dim": None,
}
# unet.config = OmegaConf.create(unet.config)
# make unet.config iterable and accessible by attribute
class CustomConfig:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def __getattr__(self, name):
if name in self.__dict__:
return self.__dict__[name]
else:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def __contains__(self, name):
return name in self.__dict__
unet.config = CustomConfig(**unet.config)
controlnet = ControlNetModel.from_unet(unet)
if args.controlnet_model_name_or_path:
filename = args.controlnet_model_name_or_path
if os.path.isfile(filename):
if os.path.splitext(filename)[1] == ".safetensors":
state_dict = load_file(filename)
else:
state_dict = torch.load(filename)
state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict)
controlnet.load_state_dict(state_dict)
elif os.path.isdir(filename):
controlnet = ControlNetModel.from_pretrained(filename)
# Integrate xformers or memory-efficient attention into the model
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
# Prepare for training
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(
vae,
args.vae_batch_size,
args.cache_latents_to_disk,
accelerator.is_main_process,
)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
controlnet.enable_gradient_checkpointing()
# Prepare the classes needed for training
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = list(controlnet.parameters())
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# Prepare the dataloader
# Note: persistent_workers cannot be used when the number of DataLoader workers is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# Compute the number of training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# Also pass the number of training steps to the dataset
train_dataset_group.set_max_train_steps(args.max_train_steps)
# Prepare the lr scheduler
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# Experimental feature: train in fp16 including gradients; cast the whole model to fp16
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
controlnet.to(weight_dtype)
# The accelerator apparently takes care of the rest for us
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
controlnet, optimizer, train_dataloader, lr_scheduler
)
if args.fused_backward_pass:
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group in optimizer.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
def __grad_hook(tensor: torch.Tensor, param_group=param_group):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None
parameter.register_post_accumulate_grad_hook(__grad_hook)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.to(accelerator.device)
text_encoder.to(accelerator.device)
# transform DDP after prepare
controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet
controlnet.train()
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# Experimental feature: train in fp16 including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# Resume from a saved state if specified
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# Compute the number of epochs
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# Train
# TODO: find a way to handle total batch size when there are multiple datasets
accelerator.print("running training / 学習開始")
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(
range(args.max_train_steps),
smoothing=0,
disable=not accelerator.is_local_main_process,
desc="steps",
)
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
clip_sample=False,
)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
loss_recorder = train_util.LossRecorder()
del train_dataset_group
# function for saving/removing
def save_model(ckpt_name, model, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict())
if save_dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
if os.path.splitext(ckpt_file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, ckpt_file)
else:
torch.save(state_dict, ckpt_file)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
def remove_model(old_ckpt_name):
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# For --sample_at_first
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet
)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
# training loop
for epoch in range(num_train_epochs):
if is_main_process:
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(controlnet):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# Convert to latents
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
elif args.multires_noise_iterations:
noise = pyramid_noise_like(
noise,
latents.device,
args.multires_noise_iterations,
args.multires_noise_discount,
)
# Sample a random timestep for each image
timesteps = train_util.get_timesteps(0, noise_scheduler.config.num_train_timesteps, b_size, latents.device)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
with accelerator.autocast():
down_block_res_samples, mid_block_res_sample = controlnet(
noisy_latents,
timesteps,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=controlnet_image,
return_dict=False,
)
# Predict the noise residual
noise_pred = unet(
noisy_latents,
timesteps,
encoder_hidden_states,
down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples],
mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
).sample
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # per-sample weights
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
loss = loss.mean() # this is a mean, so there is no need to divide by batch_size
accelerator.backward(loss)
if not args.fused_backward_pass:
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = controlnet.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
train_util.sample_images(
accelerator,
args,
None,
global_step,
accelerator.device,
vae,
tokenizer,
text_encoder,
unet,
controlnet=controlnet,
)
# Save the model every specified number of steps
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(
ckpt_name,
accelerator.unwrap_model(controlnet),
)
if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
remove_step_no = train_util.get_remove_step_no(args, global_step)
if remove_step_no is not None:
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if len(accelerator.trackers) > 0:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
# Save the model every specified number of epochs
if args.save_every_n_epochs is not None:
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, accelerator.unwrap_model(controlnet))
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
remove_model(remove_ckpt_name)
if args.save_state:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
train_util.sample_images(
accelerator,
args,
epoch + 1,
global_step,
accelerator.device,
vae,
tokenizer,
text_encoder,
unet,
controlnet=controlnet,
)
# end of epoch
if is_main_process:
controlnet = accelerator.unwrap_model(controlnet)
accelerator.end_training()
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
# del accelerator # memory is needed after this, so delete it -> kept after all, because it is still used for printing
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, controlnet, force_sync_upload=True)
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument(
"--save_model_as",
type=str,
default="safetensors",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument(
"--controlnet_model_name_or_path",
type=str,
default=None,
help="controlnet model name or path / controlnetのモデル名またはパス",
)
parser.add_argument(
"--conditioning_data_dir",
type=str,
default=None,
help="conditioning data directory / 条件付けデータのディレクトリ",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)
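
Note on the step and epoch arithmetic in train() above: when max_train_epochs is given, max_train_steps is overridden with epochs * ceil(batches / processes / accumulation steps), and the epoch count is later recomputed from max_train_steps. The sketch below reproduces only that arithmetic in isolation. The helper names and the sample numbers are hypothetical and are not part of the script; it also assumes the dataloader length seen after accelerator.prepare is already the per-process length.

```python
import math


def steps_for_epochs(num_batches, num_processes, grad_accum_steps, max_train_epochs):
    """Optimizer steps needed to cover max_train_epochs (hypothetical helper).

    Mirrors: max_train_epochs * ceil(len(dataloader) / num_processes / grad_accum_steps)
    """
    steps_per_epoch = math.ceil(num_batches / num_processes / grad_accum_steps)
    return max_train_epochs * steps_per_epoch


def epochs_for_steps(num_batches_per_process, grad_accum_steps, max_train_steps):
    """Epochs needed to reach max_train_steps (hypothetical helper).

    Mirrors: ceil(max_train_steps / ceil(len(dataloader) / grad_accum_steps)),
    where the dataloader length is assumed to be per process.
    """
    updates_per_epoch = math.ceil(num_batches_per_process / grad_accum_steps)
    return math.ceil(max_train_steps / updates_per_epoch)


# Illustrative numbers only: 1000 batches, 2 processes, accumulation of 4, 3 epochs.
total_steps = steps_for_epochs(1000, 2, 4, 3)
total_epochs = epochs_for_steps(500, 4, total_steps)
print(total_steps, total_epochs)
```

Both formulas round up, so a final partial accumulation window still counts as one optimizer step.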

View File

@@ -1,4 +1,42 @@
from library.utils import setup_logging
import argparse
import json
import math
import os
import random
import time
from multiprocessing import Value
# from omegaconf import OmegaConf
import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel
from safetensors.torch import load_file
import library.model_util as model_util
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
pyramid_noise_like,
apply_noise_offset,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
@@ -6,14 +44,601 @@ import logging
logger = logging.getLogger(__name__)
from library import train_util
from train_control_net import setup_parser, train
# TODO: share this with the other scripts
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
logs = {
"loss/current": current_loss,
"loss/average": avr_loss,
"lr": lr_scheduler.get_last_lr()[0],
}
if args.optimizer_type.lower().startswith("DAdapt".lower()):
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
return logs
def train(args):
# session_id = random.randint(0, 2**32)
# training_started_at = time.time()
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
use_user_config = args.dataset_config is not None
if args.seed is None:
args.seed = random.randint(0, 2**32 - 1)  # randint is inclusive; NumPy seeds must be below 2**32
set_seed(args.seed)
tokenizer = train_util.load_tokenizer(args)
# Prepare the dataset
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
user_config = {
"datasets": [
{
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
args.train_data_dir,
args.conditioning_data_dir,
args.caption_extension,
)
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(64)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# Prepare the accelerator
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
# Prepare the dtypes for mixed precision and cast as needed
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# Load the models
text_encoder, vae, unet, _ = train_util.load_target_model(
args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True
)
# Prepare the data used by the Diffusers ControlNet
if args.v2:
unet.config = {
"act_fn": "silu",
"attention_head_dim": [5, 10, 20, 20],
"block_out_channels": [320, 640, 1280, 1280],
"center_input_sample": False,
"cross_attention_dim": 1024,
"down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"],
"downsample_padding": 1,
"dual_cross_attention": False,
"flip_sin_to_cos": True,
"freq_shift": 0,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": [5, 10, 20, 20],
"num_class_embeds": None,
"only_cross_attention": False,
"out_channels": 4,
"sample_size": 96,
"up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
"use_linear_projection": True,
"upcast_attention": True,
"only_cross_attention": False,
"downsample_padding": 1,
"use_linear_projection": True,
"class_embed_type": None,
"num_class_embeds": None,
"resnet_time_scale_shift": "default",
"projection_class_embeddings_input_dim": None,
}
else:
unet.config = {
"act_fn": "silu",
"attention_head_dim": 8,
"block_out_channels": [320, 640, 1280, 1280],
"center_input_sample": False,
"cross_attention_dim": 768,
"down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"],
"downsample_padding": 1,
"flip_sin_to_cos": True,
"freq_shift": 0,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": 8,
"out_channels": 4,
"sample_size": 64,
"up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
"only_cross_attention": False,
"downsample_padding": 1,
"use_linear_projection": False,
"class_embed_type": None,
"num_class_embeds": None,
"upcast_attention": False,
"resnet_time_scale_shift": "default",
"projection_class_embeddings_input_dim": None,
}
# unet.config = OmegaConf.create(unet.config)
# make unet.config iterable and accessible by attribute
class CustomConfig:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def __getattr__(self, name):
if name in self.__dict__:
return self.__dict__[name]
else:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def __contains__(self, name):
return name in self.__dict__
unet.config = CustomConfig(**unet.config)
controlnet = ControlNetModel.from_unet(unet)
if args.controlnet_model_name_or_path:
filename = args.controlnet_model_name_or_path
if os.path.isfile(filename):
if os.path.splitext(filename)[1] == ".safetensors":
state_dict = load_file(filename)
else:
state_dict = torch.load(filename)
state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict)
controlnet.load_state_dict(state_dict)
elif os.path.isdir(filename):
controlnet = ControlNetModel.from_pretrained(filename)
# Integrate xformers or memory-efficient attention into the model
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
# Prepare for training
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(
vae,
args.vae_batch_size,
args.cache_latents_to_disk,
accelerator.is_main_process,
)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
if args.gradient_checkpointing:
controlnet.enable_gradient_checkpointing()
# Prepare the classes needed for training
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = list(controlnet.parameters())
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# Prepare the dataloader
# Note: persistent_workers cannot be used when the number of DataLoader workers is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# Compute the number of training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# Also pass the number of training steps to the dataset
train_dataset_group.set_max_train_steps(args.max_train_steps)
# Prepare the lr scheduler
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# Experimental feature: train in fp16 including gradients; cast the whole model to fp16
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
controlnet.to(weight_dtype)
# The accelerator apparently takes care of the rest for us
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
controlnet, optimizer, train_dataloader, lr_scheduler
)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.to(accelerator.device)
text_encoder.to(accelerator.device)
# transform DDP after prepare
controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet
controlnet.train()
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# Experimental feature: train in fp16 including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# Resume from a saved state if specified
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# Compute the number of epochs
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# Train
# TODO: find a way to handle total batch size when there are multiple datasets
accelerator.print("running training / 学習開始")
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(
range(args.max_train_steps),
smoothing=0,
disable=not accelerator.is_local_main_process,
desc="steps",
)
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
clip_sample=False,
)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
loss_recorder = train_util.LossRecorder()
del train_dataset_group
# function for saving/removing
def save_model(ckpt_name, model, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict())
if save_dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
if os.path.splitext(ckpt_file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, ckpt_file)
else:
torch.save(state_dict, ckpt_file)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
def remove_model(old_ckpt_name):
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# For --sample_at_first
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet
)
# training loop
for epoch in range(num_train_epochs):
if is_main_process:
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(controlnet):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# Convert to latents
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
elif args.multires_noise_iterations:
noise = pyramid_noise_like(
noise,
latents.device,
args.multires_noise_iterations,
args.multires_noise_discount,
)
# Sample a random timestep for each image
timesteps, huber_c = train_util.get_timesteps_and_huber_c(
args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device
)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
with accelerator.autocast():
down_block_res_samples, mid_block_res_sample = controlnet(
noisy_latents,
timesteps,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=controlnet_image,
return_dict=False,
)
# Predict the noise residual
noise_pred = unet(
noisy_latents,
timesteps,
encoder_hidden_states,
down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples],
mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
).sample
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # per-sample weight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
loss = loss.mean() # already a mean, so there is no need to divide by batch_size
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = controlnet.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
train_util.sample_images(
accelerator,
args,
None,
global_step,
accelerator.device,
vae,
tokenizer,
text_encoder,
unet,
controlnet=controlnet,
)
# save the model every specified number of steps
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(
ckpt_name,
accelerator.unwrap_model(controlnet),
)
if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
remove_step_no = train_util.get_remove_step_no(args, global_step)
if remove_step_no is not None:
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
# save the model every specified number of epochs
if args.save_every_n_epochs is not None:
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, accelerator.unwrap_model(controlnet))
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
remove_model(remove_ckpt_name)
if args.save_state:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
train_util.sample_images(
accelerator,
args,
epoch + 1,
global_step,
accelerator.device,
vae,
tokenizer,
text_encoder,
unet,
controlnet=controlnet,
)
# end of epoch
if is_main_process:
controlnet = accelerator.unwrap_model(controlnet)
accelerator.end_training()
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
# del accelerator # memory is needed after this point, so this was meant to be deleted -> kept because it is still used for printing
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, controlnet, force_sync_upload=True)
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument(
"--save_model_as",
type=str,
default="safetensors",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument(
"--controlnet_model_name_or_path",
type=str,
default=None,
help="controlnet model name or path / controlnetのモデル名またはパス",
)
parser.add_argument(
"--conditioning_data_dir",
type=str,
default=None,
help="conditioning data directory / 条件付けデータのディレクトリ",
)
return parser
if __name__ == "__main__":
logger.warning(
"The module 'train_controlnet.py' is deprecated. Please use 'train_control_net.py' instead"
" / 'train_controlnet.py'は非推奨です。代わりに'train_control_net.py'を使用してください。"
)
parser = setup_parser()
args = parser.parse_args()
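
The training loop above reduces the loss in four stages: a per-sample mean over the non-batch dimensions, multiplication by `batch["loss_weights"]`, optional Min-SNR weighting, and a final batch mean. The following is a minimal pure-Python sketch of that order of operations. The helper names `min_snr_weight` and `reduce_loss` are hypothetical, and the weighting formula is the commonly used Min-SNR-gamma form, assumed here because the body of `apply_snr_weight` is not shown in this diff.

```python
from typing import List, Optional


def min_snr_weight(snr: float, gamma: float, v_prediction: bool = False) -> float:
    """Min-SNR-gamma weight for one timestep (assumed form, not the library's code)."""
    clipped = min(snr, gamma)
    # For v-prediction the denominator is commonly snr + 1 instead of snr.
    return clipped / (snr + 1.0) if v_prediction else clipped / snr


def reduce_loss(
    per_sample_loss: List[float],
    loss_weights: List[float],
    snrs: Optional[List[float]] = None,
    gamma: Optional[float] = None,
    v_prediction: bool = False,
) -> float:
    """Apply per-sample weights, optional Min-SNR weights, then average over the batch."""
    weighted = [loss * w for loss, w in zip(per_sample_loss, loss_weights)]
    if gamma and snrs is not None:
        weighted = [
            loss * min_snr_weight(snr, gamma, v_prediction)
            for loss, snr in zip(weighted, snrs)
        ]
    # The final value is a mean, so no extra division by batch size is needed.
    return sum(weighted) / len(weighted)
```

Here `per_sample_loss` stands in for the result of `loss.mean([1, 2, 3])`, and `snrs` for the per-timestep signal-to-noise ratios that the real code derives from `noise_scheduler`.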

View File

@@ -11,7 +11,7 @@ import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils, strategy_base
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
@@ -38,7 +38,6 @@ from library.custom_train_functions import (
apply_masked_loss,
)
from library.utils import setup_logging, add_logging_arguments
import library.strategy_sd as strategy_sd
setup_logging()
import logging
@@ -59,14 +58,7 @@ def train(args):
if args.seed is not None:
set_seed(args.seed) # initialize the random number sequence
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
tokenizer = train_util.load_tokenizer(args)
# prepare the dataset
if args.dataset_class is None:
@@ -88,11 +80,10 @@ def train(args):
]
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -156,17 +147,13 @@ def train(args):
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# prepare for training: put the models into the appropriate state
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
unet.requires_grad_(True) # added just in case
@@ -199,11 +186,8 @@ def train(args):
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# prepare the dataloader
# note: persistent_workers cannot be used when the number of DataLoader worker processes is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
@@ -308,19 +292,10 @@ def train(args):
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"dreambooth" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
# For --sample_at_first
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -357,21 +332,23 @@ def train(args):
# Get the text embedding for conditioning
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
if args.weighted_captions:
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, [text_encoder], input_ids_list, weights_list
)[0]
encoder_hidden_states = get_weighted_text_embeddings(
tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
input_ids = batch["input_ids_list"][0].to(accelerator.device)
encoder_hidden_states = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder], [input_ids]
)[0]
if args.full_fp16:
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
# Predict the noise residual
with accelerator.autocast():
@@ -383,8 +360,7 @@ def train(args):
else:
target = noise
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
@@ -419,7 +395,7 @@ def train(args):
global_step += 1
train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
)
# save the model every specified number of steps
@@ -444,7 +420,7 @@ def train(args):
)
current_loss = loss.detach().item()
if len(accelerator.trackers) > 0:
if args.logging_dir is not None:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
@@ -457,7 +433,7 @@ def train(args):
if global_step >= args.max_train_steps:
break
if len(accelerator.trackers) > 0:
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
@@ -483,9 +459,7 @@ def train(args):
vae,
)
train_util.sample_images(
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
is_main_process = accelerator.is_main_process
if is_main_process:
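
The hunks above switch `train_util.conditional_loss` between a positional call (`args.loss_type, "none", huber_c`) and a keyword call (`reduction="none", loss_type=..., huber_c=...`), which implies that the parameter order differs between the two revisions. The sketch below shows why keyword-only parameters make such a reordering harmless. The function `conditional_loss_sketch` is hypothetical and uses the textbook Huber loss on plain lists; it is not the library's implementation.

```python
from typing import List, Optional, Union


def conditional_loss_sketch(
    pred: List[float],
    target: List[float],
    *,  # everything after this must be passed by keyword
    reduction: str = "mean",
    loss_type: str = "l2",
    huber_c: Optional[float] = None,
) -> Union[float, List[float]]:
    """Element-wise L2 or Huber loss with an optional mean reduction."""
    if loss_type == "l2":
        losses = [(p - t) ** 2 for p, t in zip(pred, target)]
    elif loss_type == "huber":
        if huber_c is None:
            raise ValueError("huber_c is required for loss_type='huber'")
        losses = []
        for p, t in zip(pred, target):
            d = abs(p - t)
            # Quadratic near zero, linear beyond the threshold huber_c.
            losses.append(0.5 * d * d if d <= huber_c else huber_c * (d - 0.5 * huber_c))
    else:
        raise NotImplementedError(f"unsupported loss_type: {loss_type}")

    if reduction == "none":
        return losses
    if reduction == "mean":
        return sum(losses) / len(losses)
    raise NotImplementedError(f"unsupported reduction: {reduction}")
```

With this signature a caller that passes `"l2", "none", None` positionally fails immediately with a `TypeError` instead of silently swapping `loss_type` and `reduction`.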

File diff suppressed because it is too large

View File

@@ -2,7 +2,6 @@ import argparse
import math
import os
from multiprocessing import Value
from typing import Any, List, Optional, Union
import toml
from tqdm import tqdm
@@ -16,7 +15,7 @@ init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
from library import deepspeed_utils, model_util
import library.train_util as train_util
import library.huggingface_util as huggingface_util
@@ -99,46 +98,33 @@ class TextualInversionTrainer:
self.vae_scale_factor = 0.18215
self.is_sdxl = False
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
def assert_extra_args(self, args, train_dataset_group):
train_dataset_group.verify_bucket_reso_steps(64)
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(64)
def load_target_model(self, args, weight_dtype, accelerator):
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
def get_tokenize_strategy(self, args):
return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]:
return [tokenize_strategy.tokenizer]
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return latents_caching_strategy
def load_tokenizer(self, args):
tokenizer = train_util.load_tokenizer(args)
return tokenizer
def assert_token_string(self, token_string, tokenizers: CLIPTokenizer):
pass
def get_text_encoding_strategy(self, args):
return strategy_sd.SdTextEncodingStrategy(args.clip_skip)
def get_models_for_text_encoding(self, args, accelerator, text_encoders) -> List[Any]:
return text_encoders
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
with torch.enable_grad():
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None)
return encoder_hidden_states
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
return noise_pred
def sample_images(
self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
):
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
train_util.sample_images(
accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoders[0], unet, prompt_replacement
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
)
def save_weights(self, file, updated_embs, save_dtype, metadata):
@@ -196,13 +182,8 @@ class TextualInversionTrainer:
if args.seed is not None:
set_seed(args.seed)
tokenize_strategy = self.get_tokenize_strategy(args)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
latents_caching_strategy = self.get_latents_caching_strategy(args)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
tokenizer_or_list = self.load_tokenizer(args) # list of tokenizer or tokenizer
tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
# prepare the accelerator
logger.info("prepare accelerator")
@@ -213,7 +194,14 @@ class TextualInversionTrainer:
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# load the models
model_version, text_encoders, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list
if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1:
accelerator.print(
"accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / "
+ "accelerateでは複数のモデルテキストエンコーダーのgradient_accumulation_stepsはサポートされていないようです"
)
# Convert the init_word to token_id
init_token_ids_list = []
@@ -322,13 +310,12 @@ class TextualInversionTrainer:
]
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list)
self.assert_extra_args(args, train_dataset_group, val_dataset_group)
self.assert_extra_args(args, train_dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -381,10 +368,11 @@ class TextualInversionTrainer:
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
if args.gradient_checkpointing:
@@ -399,11 +387,7 @@ class TextualInversionTrainer:
trainable_params += text_encoder.get_input_embeddings().parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# prepare the dataloader
# note: persistent_workers cannot be used when the number of DataLoader worker processes is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
@@ -431,8 +415,20 @@ class TextualInversionTrainer:
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# the accelerator apparently takes care of this for us
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
text_encoders = [accelerator.prepare(text_encoder) for text_encoder in text_encoders]
if len(text_encoders) == 1:
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
)
elif len(text_encoders) == 2:
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
)
text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
else:
raise NotImplementedError()
index_no_updates_list = []
orig_embeds_params_list = []
@@ -460,9 +456,6 @@ class TextualInversionTrainer:
else:
unet.eval()
text_encoding_strategy = self.get_text_encoding_strategy(args)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
if not cache_latents: # when latents are not cached the VAE is used, so prepare it
vae.requires_grad_(False)
vae.eval()
@@ -517,9 +510,7 @@ class TextualInversionTrainer:
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
)
# function for saving/removing
@@ -549,14 +540,11 @@ class TextualInversionTrainer:
global_step,
accelerator.device,
vae,
tokenizers,
text_encoders,
tokenizer_or_list,
text_encoder_or_list,
unet,
prompt_replacement,
)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
# training loop
for epoch in range(num_train_epochs):
@@ -580,16 +568,11 @@ class TextualInversionTrainer:
latents = latents * self.vae_scale_factor
# Get the text embedding for conditioning
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
text_encoder_conds = text_encoding_strategy.encode_tokens(
tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids
)
if args.full_fp16:
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
@@ -605,8 +588,7 @@ class TextualInversionTrainer:
else:
target = noise
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
@@ -657,8 +639,8 @@ class TextualInversionTrainer:
global_step,
accelerator.device,
vae,
tokenizers,
text_encoders,
tokenizer_or_list,
text_encoder_or_list,
unet,
prompt_replacement,
)
@@ -690,7 +672,7 @@ class TextualInversionTrainer:
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if len(accelerator.trackers) > 0:
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
@@ -708,7 +690,7 @@ class TextualInversionTrainer:
if global_step >= args.max_train_steps:
break
if len(accelerator.trackers) > 0:
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch + 1)
@@ -740,12 +722,11 @@ class TextualInversionTrainer:
global_step,
accelerator.device,
vae,
tokenizers,
text_encoders,
tokenizer_or_list,
text_encoder_or_list,
unet,
prompt_replacement,
)
accelerator.log({})
# end of epoch

View File

@@ -239,7 +239,7 @@ def train(args):
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -407,9 +407,7 @@ def train(args):
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
)
# function for saving/removing
@@ -463,7 +461,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
# Predict the noise residual
with accelerator.autocast():
@@ -475,8 +473,7 @@ def train(args):
else:
target = noise
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
@@ -541,7 +538,7 @@ def train(args):
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if len(accelerator.trackers) > 0:
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
@@ -559,7 +556,7 @@ def train(args):
if global_step >= args.max_train_steps:
break
if len(accelerator.trackers) > 0:
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch + 1)