Merge pull request #1719 from kohya-ss/sd3_5_support

SD3.5 Large support
This commit is contained in:
Kohya S.
2024-11-01 21:55:48 +09:00
committed by GitHub
19 changed files with 3368 additions and 2316 deletions

205
README.md
View File

@@ -1,6 +1,6 @@
This repository contains training, generation and utility scripts for Stable Diffusion.
## FLUX.1 training (WIP)
## FLUX.1 and SD3 training (WIP)
This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training.
@@ -9,8 +9,15 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
- [FLUX.1 training](#flux1-training)
- [SD3 training](#sd3-training)
### Recent Updates
Oct 31, 2024:
- Added support for SD3.5L/M training. See [SD3 training](#sd3-training) for details.
Oct 19, 2024:
- Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training. SD1/2 is not tested yet. This is an experimental feature.
@@ -139,7 +146,7 @@ Sep 1, 2024:
Aug 29, 2024:
Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `requirements.txt` is updated.
### Contents
## FLUX.1 training
- [FLUX.1 LoRA training](#flux1-lora-training)
- [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training)
@@ -586,53 +593,177 @@ python tools/convert_diffusers_to_flux.py --diffusers_path path/to/diffusers_fol
## SD3 training
SD3 training is done with `sd3_train.py`.
SD3.5L/M training is now available.
__Sep 1, 2024__:
- `--num_last_block_to_freeze` is added to `sd3_train.py`. This option is to freeze the last n blocks of the MMDiT. See [#1417](https://github.com/kohya-ss/sd-scripts/pull/1417) for details. Thanks to sdbds!
### SD3 LoRA training
__Jul 27, 2024__:
- Latents and text encoder outputs caching mechanism is refactored significantly.
- Existing cache files for SD3 need to be recreated. Please delete the previous cache files.
- With this change, dataset initialization is significantly faster, especially for large datasets.
The script is `sd3_train_network.py`. See `--help` for options.
- Architecture-dependent parts are extracted from the dataset (`train_util.py`). This is expected to make it easier to add future architectures.
SD3 model, CLIP-L, CLIP-G, and T5XXL models are recommended to be in float/fp16 format. If you specify `--fp8_base`, you can use fp8 models for SD3. The fp8 model is only compatible with `float8_e4m3fn` format.
- Architecture-dependent parts including the cache mechanism for SD1/2/SDXL are also extracted. The basic operation of SD1/2/SDXL training on the sd3 branch has been confirmed, but there may be bugs. Please use the main or dev branch for SD1/2/SDXL training.
Sample command is below. It will work with 16GB VRAM GPUs (SD3.5L).
---
```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 sd3_train_network.py
--pretrained_model_name_or_path path/to/sd3.5_large.safetensors --clip_l sd3/clip_l.safetensors --clip_g sd3/clip_g.safetensors --t5xxl sd3/t5xxl_fp16.safetensors
--cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers
--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
--network_module networks.lora_sd3 --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4
--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base
--highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml
--output_dir path/to/output/dir --output_name sd3-lora-name
```
(The command is multi-line for readability. Please combine it into one line.)
`fp16` and `bf16` are available for mixed precision training. We are not sure which is better.
The training can be done with 12GB VRAM GPUs with Adafactor optimizer. Please use settings like below:
`optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are necessary currently.
`clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them.
t5xxl works with `fp16` now.
There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype.
`text_encoder_batch_size` is added experimentally for caching faster.
```toml
learning_rate = 1e-6 # seems to depend on the batch size
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
cache_text_encoder_outputs = true
cache_text_encoder_outputs_to_disk = true
vae_batch_size = 1
text_encoder_batch_size = 4
cache_latents = true
cache_latents_to_disk = true
```
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0
```
__2024/7/27:__
`--cpu_offload_checkpointing` and `--split_mode` are not available for SD3 LoRA training.
Latents およびテキストエンコーダ出力のキャッシュの仕組みを大きくリファクタリングしました。SD3 用の既存のキャッシュファイルの再作成が必要になりますが、ご了承ください(以前のキャッシュファイルは削除してください)。これにより、特にデータセットの規模が大きい場合のデータセット初期化が大幅に高速化されます。
We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted.
データセット (`train_util.py`) からアーキテクチャ依存の部分を切り出しました。これにより将来的なアーキテクチャ追加が容易になると期待しています。
The trained LoRA model can be used with ComfyUI.
SD1/2/SDXL のキャッシュ機構を含むアーキテクチャ依存の部分も切り出しました。sd3 ブランチの SD1/2/SDXL 学習について、基本的な動作は確認していますが、不具合があるかもしれません。SD1/2/SDXL の学習には main または dev ブランチをお使いください。
#### Key Options for SD3 LoRA training
Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome.
- `--network_module` is the module for LoRA training. Specify `networks.lora_sd3` for SD3 LoRA training.
- `--pretrained_model_name_or_path` is the path to the pretrained model (SD3/3.5). If you specify `--fp8_base`, you can use fp8 models for SD3/3.5. The fp8 model is only compatible with `float8_e4m3fn` format.
- `--clip_l` is the path to the CLIP-L model.
- `--clip_g` is the path to the CLIP-G model.
- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (float8_e4m3fn) models for T5XXL. However, it is recommended to use fp16 models for caching.
- `--vae` is the path to the autoencoder model. __This option is not necessary for SD3.__ VAE is included in the standard SD3 model.
- `--disable_mmap_load_safetensors` is to disable memory mapping when loading safetensors. __This option significantly reduces the memory usage when loading models for Windows users.__
- `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are the dropout rates for the embeddings of CLIP-L, CLIP-G, and T5XXL, described in [SAI research papre](http://arxiv.org/pdf/2403.03206). The default is 0.0. For LoRA training, it is seems to be better to set 0.0.
- `--pos_emb_random_crop_rate` is the rate of random cropping of positional embeddings, described in [SD3.5M model card](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium). The default is 0. It is seems to be better to set 0.0 for LoRA training.
- `--enable_scaled_pos_embed` is to enable the scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M. Details are described below.
Other options are described below.
#### Key Features for SD3 LoRA training
1. CLIP-L, G and T5XXL LoRA Support:
- SD3 LoRA training now supports CLIP-L, CLIP-G and T5XXL LoRA training.
- Remove `--network_train_unet_only` from your command.
- Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L and G is also trained at the same time.
- T5XXL output can be cached for CLIP-L and G LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- The learning rates for CLIP-L, CLIP-G and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5 5e-6`. The first value is the learning rate for CLIP-L, the second value is for CLIP-G, and the third value is for T5XXL. If you specify only one, the learning rates for CLIP-L, CLIP-G and T5XXL will be the same. If the third value is not specified, the second value is used for T5XXL. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL.
- The trained LoRA can be used with ComfyUI.
| trained LoRA|option|network_args|cache_text_encoder_outputs (*1)|
|---|---|---|---|
|MMDiT|`--network_train_unet_only`|-|o|
|MMDiT + CLIP-L + CLIP-G|-|-|o (*2)|
|MMDiT + CLIP-L + CLIP-G + T5XXL|-|`train_t5xxl=True`|-|
|CLIP-L + CLIP-G (*3)|`--network_train_text_encoder_only`|-|o (*2)|
|CLIP-L + CLIP-G + T5XXL (*3)|`--network_train_text_encoder_only`|`train_t5xxl=True`|-|
- *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- *2: T5XXL output can be cached for CLIP-L and G LoRA training.
- *3: Not tested yet.
2. Experimental FP8/FP16 mixed training:
- `--fp8_base_unet` enables training with fp8 for MMDiT and bf16/fp16 for CLIP-L/G/T5XXL.
- When specifying this option, the `--fp8_base` option is automatically enabled.
3. Split Q/K/V Projection Layers (Experimental):
- Same as FLUX.1.
4. CLIP-L/G and T5 Attention Mask Application:
- This function is planned to be implemented in the future.
5. Multi-resolution Training Support:
- Only for SD3.5M.
- Same as FLUX.1 for data preparation.
- If you train with multiple resolutions, specify `--enable_scaled_pos_embed` to enable the scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M.
Technical details of multi-resolution training for SD3.5M:
The values of the positional embeddings must be the same for each resolution. That is, the same value must be in the same position for 512x512, 768x768, and 1024x1024. To achieve this, the positional embeddings for each resolution are calculated in advance and switched according to the resolution of the training data. This feature is enabled by `--enable_scaled_pos_embed`.
This idea and the code for calculating scaled positional embeddings are contributed by KohakuBlueleaf. Thanks to KohakuBlueleaf!
#### Specify rank for each layer in SD3 LoRA
You can specify the rank for each layer in SD3 by specifying the following network_args. If you specify `0`, LoRA will not be applied to that layer.
When network_args is not specified, the default value (`network_dim`) is applied, same as before.
|network_args|target layer|
|---|---|
|context_attn_dim|attn in context_block|
|context_mlp_dim|mlp in context_block|
|context_mod_dim|adaLN_modulation in context_block|
|x_attn_dim|attn in x_block|
|x_mlp_dim|mlp in x_block|
|x_mod_dim|adaLN_modulation in x_block|
`"verbose=True"` is also available for debugging. It shows the rank of each layer.
example:
```
--network_args "context_attn_dim=2" "context_mlp_dim=3" "context_mod_dim=4" "x_attn_dim=5" "x_mlp_dim=6" "x_mod_dim=7" "verbose=True"
```
You can apply LoRA to the conditioning layers of SD3 by specifying `emb_dims` in network_args. When specifying, be sure to specify 6 numbers in `[]` as a comma-separated list.
example:
```
--network_args "emb_dims=[2,3,4,5,6,7]"
```
Each number corresponds to `context_embedder`, `t_embedder`, `x_embedder`, `y_embedder`, `final_layer_adaLN_modulation`, `final_layer_linear`. The above example applies LoRA to all conditioning layers, with rank 2 for `context_embedder`, 3 for `t_embedder`, 4 for `context_embedder`, 5 for `y_embedder`, 6 for `final_layer_adaLN_modulation`, and 7 for `final_layer_linear`.
If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,4,0,0]` applies LoRA only to `context_embedder` and `y_embedder`.
#### Specify blocks to train in SD3 LoRA training
You can specify the blocks to train in SD3 LoRA training by specifying `train_block_indices` in network_args. The indices are 0-based. The default (when omitted) is to train all blocks. The indices are specified as a list of integers or a range of integers, like `0,1,5,8` or `0,1,4-5,7`.
The number of blocks depends on the model. The valid range is 0-(the number of blocks - 1). `all` is also available to train all blocks, `none` is also available to train no blocks.
example:
```
--network_args "train_block_indices=1,2,6-8"
```
### Inference for SD3 with LoRA model
The inference script is also available. The script is `sd3_minimal_inference.py`. See `--help` for options.
### SD3 fine-tuning
Documentation is not available yet. Please refer to the FLUX.1 fine-tuning guide for now. The major difference are following:
- `--clip_g` is also available for SD3 fine-tuning.
- `--timestep_sampling` `--discrete_flow_shift``--model_prediction_type` --guidance_scale` are not necessary for SD3 fine-tuning.
- Use `--vae` instead of `--ae` if necessary. __This option is not necessary for SD3.__ VAE is included in the standard SD3 model.
- `--disable_mmap_load_safetensors` is available. __This option significantly reduces the memory usage when loading models for Windows users.__
- `--cpu_offload_checkpointing` is not available for SD3 fine-tuning.
- `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are available same as LoRA training.
- `--pos_emb_random_crop_rate` and `--enable_scaled_pos_embed` are available for SD3.5M fine-tuning.
- Training text encoders is available with `--train_text_encoder` option, similar to SDXL training.
- CLIP-L and G can be trained with `--train_text_encoder` option. Training T5XXL needs `--train_t5xxl` option.
- If you use the cached text encoder outputs for T5XXL with training CLIP-L and G, specify `--use_t5xxl_cache_only`. This option enables to use the cached text encoder outputs for T5XXL only.
- The learning rates for CLIP-L, CLIP-G and T5XXL can be specified separately. `--text_encoder_lr1`, `--text_encoder_lr2` and `--text_encoder_lr3` are available.
### Extract LoRA from SD3 Models
Not available yet.
### Convert SD3 LoRA
Not available yet.
### Merge LoRA to SD3 checkpoint
Not available yet.
---

View File

@@ -29,7 +29,7 @@ init_ipex()
from accelerate.utils import set_seed
from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux
from library.sd3_train_utils import load_prompts, FlowMatchEulerDiscreteScheduler
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
import library.train_util as train_util
@@ -241,7 +241,7 @@ def train(args):
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
prompts = load_prompts(args.sample_prompts)
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:

View File

@@ -231,7 +231,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
prompts = sd3_train_utils.load_prompts(args.sample_prompts)
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
@@ -363,7 +363,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)
for t in text_encoder_conds:
if t.dtype.is_floating_point:
if t is not None and t.dtype.is_floating_point:
t.requires_grad_(True)
img_ids.requires_grad_(True)
guidance_vec.requires_grad_(True)

View File

@@ -15,7 +15,6 @@ from PIL import Image
from safetensors.torch import save_file
from library import flux_models, flux_utils, strategy_base, train_util
from library.sd3_train_utils import load_prompts
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
@@ -70,7 +69,7 @@ def sample_images(
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
prompts = load_prompts(args.sample_prompts)
prompts = train_util.load_prompts(args.sample_prompts)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)

View File

@@ -10,40 +10,21 @@ from safetensors import safe_open
from accelerate import init_empty_weights
from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
from library import flux_models
from library.utils import setup_logging, MemoryEfficientSafeOpen
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import flux_models
from library.utils import load_safetensors
MODEL_VERSION_FLUX_V1 = "flux1"
MODEL_NAME_DEV = "dev"
MODEL_NAME_SCHNELL = "schnell"
# temporary copy from sd3_utils TODO refactor
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
):
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
# use experimental loader
logger.info(f"Loading without mmap (experimental)")
state_dict = {}
with MemoryEfficientSafeOpen(path) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
return state_dict
else:
try:
return load_file(path, device=device)
except:
return load_file(path) # prevent device invalid Error
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
"""
チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。
@@ -172,8 +153,14 @@ def load_ae(
return ae
def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> CLIPTextModel:
logger.info("Building CLIP")
def load_clip_l(
ckpt_path: Optional[str],
dtype: torch.dtype,
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[dict] = None,
) -> CLIPTextModel:
logger.info("Building CLIP-L")
CLIPL_CONFIG = {
"_name_or_path": "clip-vit-large-patch14/",
"architectures": ["CLIPModel"],
@@ -266,15 +253,22 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev
with init_empty_weights():
clip = CLIPTextModel._from_config(config)
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
if state_dict is not None:
sd = state_dict
else:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = clip.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded CLIP: {info}")
logger.info(f"Loaded CLIP-L: {info}")
return clip
def load_t5xxl(
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
ckpt_path: str,
dtype: Optional[torch.dtype],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[dict] = None,
) -> T5EncoderModel:
T5_CONFIG_JSON = """
{
@@ -314,8 +308,11 @@ def load_t5xxl(
with init_empty_weights():
t5xxl = T5EncoderModel._from_config(config)
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
if state_dict is not None:
sd = state_dict
else:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = t5xxl.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded T5xxl: {info}")
return t5xxl

View File

@@ -57,8 +57,8 @@ ARCH_SD_V1 = "stable-diffusion-v1"
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ARCH_SD3_M = "stable-diffusion-3-medium"
ARCH_SD3_UNKNOWN = "stable-diffusion-3"
ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc.
# ARCH_SD3_UNKNOWN = "stable-diffusion-3"
ARCH_FLUX_1_DEV = "flux-1-dev"
ARCH_FLUX_1_UNKNOWN = "flux-1"
@@ -140,10 +140,7 @@ def build_metadata(
if sdxl:
arch = ARCH_SD_XL_V1_BASE
elif sd3 is not None:
if sd3 == "m":
arch = ARCH_SD3_M
else:
arch = ARCH_SD3_UNKNOWN
arch = ARCH_SD3_M + "-" + sd3
elif flux is not None:
if flux == "dev":
arch = ARCH_FLUX_1_DEV

File diff suppressed because it is too large Load Diff

View File

@@ -11,8 +11,8 @@ from safetensors.torch import save_file
from accelerate import Accelerator, PartialState
from tqdm import tqdm
from PIL import Image
from transformers import CLIPTextModelWithProjection, T5EncoderModel
from library import sd3_models, sd3_utils, strategy_base, train_util
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
@@ -28,60 +28,16 @@ import logging
logger = logging.getLogger(__name__)
from .sdxl_train_util import match_mixed_precision
def load_target_model(
model_type: str,
args: argparse.Namespace,
state_dict: dict,
accelerator: Accelerator,
attn_mode: str,
model_dtype: Optional[torch.dtype],
device: Optional[torch.device],
) -> Union[
sd3_models.MMDiT,
Optional[sd3_models.SDClipModel],
Optional[sd3_models.SDXLClipG],
Optional[sd3_models.T5XXLModel],
sd3_models.SDVAE,
]:
loading_device = device if device is not None else (accelerator.device if args.lowram else "cpu")
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
if model_type == "mmdit":
model = sd3_utils.load_mmdit(state_dict, attn_mode, model_dtype, loading_device)
elif model_type == "clip_l":
model = sd3_utils.load_clip_l(state_dict, args.clip_l, attn_mode, model_dtype, loading_device)
elif model_type == "clip_g":
model = sd3_utils.load_clip_g(state_dict, args.clip_g, attn_mode, model_dtype, loading_device)
elif model_type == "t5xxl":
model = sd3_utils.load_t5xxl(state_dict, args.t5xxl, attn_mode, model_dtype, loading_device)
elif model_type == "vae":
model = sd3_utils.load_vae(state_dict, args.vae, model_dtype, loading_device)
else:
raise ValueError(f"Unknown model type: {model_type}")
# work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device
if args.lowram:
model = model.to(accelerator.device)
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
return model
from library import sd3_models, sd3_utils, strategy_base, train_util
def save_models(
ckpt_path: str,
mmdit: sd3_models.MMDiT,
vae: sd3_models.SDVAE,
clip_l: sd3_models.SDClipModel,
clip_g: sd3_models.SDXLClipG,
t5xxl: Optional[sd3_models.T5XXLModel],
mmdit: Optional[sd3_models.MMDiT],
vae: Optional[sd3_models.SDVAE],
clip_l: Optional[CLIPTextModelWithProjection],
clip_g: Optional[CLIPTextModelWithProjection],
t5xxl: Optional[T5EncoderModel],
sai_metadata: Optional[dict],
save_dtype: Optional[torch.dtype] = None,
):
@@ -101,24 +57,42 @@ def save_models(
update_sd("model.diffusion_model.", mmdit.state_dict())
update_sd("first_stage_model.", vae.state_dict())
if clip_l is not None:
update_sd("text_encoders.clip_l.", clip_l.state_dict())
if clip_g is not None:
update_sd("text_encoders.clip_g.", clip_g.state_dict())
if t5xxl is not None:
update_sd("text_encoders.t5xxl.", t5xxl.state_dict())
# do not support unified checkpoint format for now
# if clip_l is not None:
# update_sd("text_encoders.clip_l.", clip_l.state_dict())
# if clip_g is not None:
# update_sd("text_encoders.clip_g.", clip_g.state_dict())
# if t5xxl is not None:
# update_sd("text_encoders.t5xxl.", t5xxl.state_dict())
save_file(state_dict, ckpt_path, metadata=sai_metadata)
if clip_l is not None:
clip_l_path = ckpt_path.replace(".safetensors", "_clip_l.safetensors")
save_file(clip_l.state_dict(), clip_l_path)
if clip_g is not None:
clip_g_path = ckpt_path.replace(".safetensors", "_clip_g.safetensors")
save_file(clip_g.state_dict(), clip_g_path)
if t5xxl is not None:
t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors")
t5xxl_state_dict = t5xxl.state_dict()
# replace "shared.weight" with copy of it to avoid annoying shared tensor error on safetensors.save_file
shared_weight = t5xxl_state_dict["shared.weight"]
shared_weight_copy = shared_weight.detach().clone()
t5xxl_state_dict["shared.weight"] = shared_weight_copy
save_file(t5xxl_state_dict, t5xxl_path)
def save_sd3_model_on_train_end(
args: argparse.Namespace,
save_dtype: torch.dtype,
epoch: int,
global_step: int,
clip_l: sd3_models.SDClipModel,
clip_g: sd3_models.SDXLClipG,
t5xxl: Optional[sd3_models.T5XXLModel],
clip_l: Optional[CLIPTextModelWithProjection],
clip_g: Optional[CLIPTextModelWithProjection],
t5xxl: Optional[T5EncoderModel],
mmdit: sd3_models.MMDiT,
vae: sd3_models.SDVAE,
):
@@ -141,9 +115,9 @@ def save_sd3_model_on_epoch_end_or_stepwise(
epoch: int,
num_train_epochs: int,
global_step: int,
clip_l: sd3_models.SDClipModel,
clip_g: sd3_models.SDXLClipG,
t5xxl: Optional[sd3_models.T5XXLModel],
clip_l: Optional[CLIPTextModelWithProjection],
clip_g: Optional[CLIPTextModelWithProjection],
t5xxl: Optional[T5EncoderModel],
mmdit: sd3_models.MMDiT,
vae: sd3_models.SDVAE,
):
@@ -208,23 +182,75 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用",
)
parser.add_argument(
"--save_clip", action="store_true", help="save CLIP models to checkpoint / CLIPモデルをチェックポイントに保存する"
"--save_clip",
action="store_true",
help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
)
parser.add_argument(
"--save_t5xxl", action="store_true", help="save T5-XXL model to checkpoint / T5-XXLモデルをチェックポイントに保存する"
"--save_t5xxl",
action="store_true",
help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
)
parser.add_argument(
"--t5xxl_device",
type=str,
default=None,
help="T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用",
help="[DOES NOT WORK] not supported yet. T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用",
)
parser.add_argument(
"--t5xxl_dtype",
type=str,
default=None,
help="T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtypemixed precisionからを使用",
help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtypemixed precisionからを使用",
)
parser.add_argument(
"--t5xxl_max_token_length",
type=int,
default=256,
help="maximum token length for T5-XXL. 256 is the default value / T5-XXLの最大トークン長。デフォルトは256",
)
parser.add_argument(
"--apply_lg_attn_mask",
action="store_true",
help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスクゼロ埋めを適用する",
)
parser.add_argument(
"--apply_t5_attn_mask",
action="store_true",
help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスクゼロ埋めを適用する",
)
parser.add_argument(
"--clip_l_dropout_rate",
type=float,
default=0.0,
help="Dropout rate for CLIP-L encoder, default is 0.0 / CLIP-Lエンコーダのドロップアウト率、デフォルトは0.0",
)
parser.add_argument(
"--clip_g_dropout_rate",
type=float,
default=0.0,
help="Dropout rate for CLIP-G encoder, default is 0.0 / CLIP-Gエンコーダのドロップアウト率、デフォルトは0.0",
)
parser.add_argument(
"--t5_dropout_rate",
type=float,
default=0.0,
help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0",
)
parser.add_argument(
"--pos_emb_random_crop_rate",
type=float,
default=0.0,
help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M"
" / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります",
)
parser.add_argument(
"--enable_scaled_pos_embed",
action="store_true",
help="Scale position embeddings for each resolution during multi-resolution training. Only for SD3.5M"
" / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります",
)
# copy from Diffusers
@@ -233,16 +259,25 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
type=str,
default="logit_normal",
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム",
)
parser.add_argument(
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
"--logit_mean",
type=float,
default=0.0,
help="mean to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合の平均",
)
parser.add_argument(
"--logit_std",
type=float,
default=1.0,
help="std to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合のstd",
)
parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.")
parser.add_argument(
"--mode_scale",
type=float,
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. / モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効",
)
@@ -283,7 +318,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
# temporary copied from sd3_minimal_inferece.py
def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
def get_all_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
start = sampling.timestep(sampling.sigma_max)
end = sampling.timestep(sampling.sigma_min)
timesteps = torch.linspace(start, end, steps)
@@ -319,6 +354,8 @@ def do_sample(
# noise = get_noise(seed, latent).to(device)
if seed is not None:
generator = torch.manual_seed(seed)
else:
generator = None
noise = (
torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu")
.to(latent.dtype)
@@ -327,7 +364,7 @@ def do_sample(
model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3
sigmas = get_sigmas(model_sampling, steps).to(device)
sigmas = get_all_sigmas(model_sampling, steps).to(device)
noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas))
@@ -337,71 +374,42 @@ def do_sample(
x = noise_scaled.to(device).to(dtype)
# print(x.shape)
with torch.no_grad():
for i in tqdm(range(len(sigmas) - 1)):
sigma_hat = sigmas[i]
# with torch.no_grad():
for i in tqdm(range(len(sigmas) - 1)):
sigma_hat = sigmas[i]
timestep = model_sampling.timestep(sigma_hat).float()
timestep = torch.FloatTensor([timestep, timestep]).to(device)
timestep = model_sampling.timestep(sigma_hat).float()
timestep = torch.FloatTensor([timestep, timestep]).to(device)
x_c_nc = torch.cat([x, x], dim=0)
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
x_c_nc = torch.cat([x, x], dim=0)
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
model_output = model_output.float()
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
mmdit.prepare_block_swap_before_forward()
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
model_output = model_output.float()
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
pos_out, neg_out = batched.chunk(2)
denoised = neg_out + (pos_out - neg_out) * guidance_scale
# print(denoised.shape)
pos_out, neg_out = batched.chunk(2)
denoised = neg_out + (pos_out - neg_out) * guidance_scale
# print(denoised.shape)
# d = to_d(x, sigma_hat, denoised)
dims_to_append = x.ndim - sigma_hat.ndim
sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append]
# print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape)
"""Converts a denoiser output to a Karras ODE derivative."""
d = (x - denoised) / sigma_hat_dims
# d = to_d(x, sigma_hat, denoised)
dims_to_append = x.ndim - sigma_hat.ndim
sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append]
# print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape)
"""Converts a denoiser output to a Karras ODE derivative."""
d = (x - denoised) / sigma_hat_dims
dt = sigmas[i + 1] - sigma_hat
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
x = x.to(dtype)
# Euler method
x = x + d * dt
x = x.to(dtype)
mmdit.prepare_block_swap_before_forward()
return x
def load_prompts(prompt_file: str) -> List[Dict]:
# read prompts
if prompt_file.endswith(".txt"):
with open(prompt_file, "r", encoding="utf-8") as f:
lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
elif prompt_file.endswith(".toml"):
with open(prompt_file, "r", encoding="utf-8") as f:
data = toml.load(f)
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
elif prompt_file.endswith(".json"):
with open(prompt_file, "r", encoding="utf-8") as f:
prompts = json.load(f)
# preprocess prompts
for i in range(len(prompts)):
prompt_dict = prompts[i]
if isinstance(prompt_dict, str):
from library.train_util import line_to_prompt_dict
prompt_dict = line_to_prompt_dict(prompt_dict)
prompts[i] = prompt_dict
assert isinstance(prompt_dict, dict)
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
prompt_dict["enum"] = i
prompt_dict.pop("subset", None)
return prompts
def sample_images(
accelerator: Accelerator,
args: argparse.Namespace,
@@ -429,7 +437,7 @@ def sample_images(
logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts):
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
@@ -437,10 +445,10 @@ def sample_images(
# unwrap unet and text_encoder(s)
mmdit = accelerator.unwrap_model(mmdit)
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
text_encoders = None if text_encoders is None else [accelerator.unwrap_model(te) for te in text_encoders]
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
prompts = load_prompts(args.sample_prompts)
prompts = train_util.load_prompts(args.sample_prompts)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
@@ -453,12 +461,9 @@ def sample_images(
except Exception:
pass
org_vae_device = vae.device # will be on cpu
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad():
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
sample_image_inference(
accelerator,
@@ -501,8 +506,6 @@ def sample_images(
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
@@ -510,7 +513,7 @@ def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
mmdit: sd3_models.MMDiT,
text_encoders: List[Union[sd3_models.SDClipModel, sd3_models.SDXLClipG, sd3_models.T5XXLModel]],
text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
vae: sd3_models.SDVAE,
save_dir,
prompt_dict,
@@ -562,32 +565,49 @@ def sample_image_inference(
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
te_outputs = sample_prompts_te_outputs[prompt]
else:
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(prompt)
te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens])
def encode_prompt(prpt):
text_encoder_conds = []
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
text_encoder_conds = sample_prompts_te_outputs[prpt]
print(f"Using cached text encoder outputs for prompt: {prpt}")
if text_encoders is not None:
print(f"Encoding prompt: {prpt}")
tokens_and_masks = tokenize_strategy.tokenize(prpt)
# strategy has apply_t5_attn_mask option
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
lg_out, t5_out, pooled = te_outputs
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
return text_encoder_conds
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(prompt)
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
# encode negative prompts
if sample_prompts_te_outputs and negative_prompt in sample_prompts_te_outputs:
neg_te_outputs = sample_prompts_te_outputs[negative_prompt]
else:
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(negative_prompt)
neg_te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens])
lg_out, t5_out, pooled = neg_te_outputs
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(negative_prompt)
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
# sample image
latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device)
latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype))
clean_memory_on_device(accelerator.device)
with accelerator.autocast(), torch.no_grad():
# mmdit may be fp8, so we need weight_dtype here. vae is always in that dtype.
latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, vae.dtype, accelerator.device)
# latent to image
with torch.no_grad():
image = vae.decode(latents)
clean_memory_on_device(accelerator.device)
org_vae_device = vae.device # will be on cpu
vae.to(accelerator.device)
latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype))
image = vae.decode(latents)
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
image = image.float()
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
@@ -609,14 +629,9 @@ def sample_image_inference(
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log(
{f"sample_{i}": wandb.Image(
image,
caption=prompt # positive prompt as a caption
)},
commit=False
)
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
# region Diffusers
@@ -886,4 +901,78 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
return self.config.num_train_timesteps
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def compute_density_for_timestep_sampling(
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
"""Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
u = torch.rand(size=(batch_size,), device="cpu")
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(batch_size,), device="cpu")
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)
return weighting
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz = latents.shape[0]
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
# Add noise according to flow matching.
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
return noisy_model_input, timesteps, sigmas
# endregion

View File

@@ -1,9 +1,12 @@
from dataclasses import dataclass
import math
from typing import Dict, Optional, Union
import re
from typing import Dict, List, Optional, Union
import torch
import safetensors
from safetensors.torch import load_file
from accelerate import init_empty_weights
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPConfig, CLIPTextConfig
from .utils import setup_logging
@@ -19,18 +22,62 @@ from library import sdxl_model_util
# region models
# TODO remove dependency on flux_utils
from library.utils import load_safetensors
from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl
def load_safetensors(path: str, dvc: Union[str, torch.device], disable_mmap: bool = False):
if disable_mmap:
return safetensors.torch.load(open(path, "rb").read())
def analyze_state_dict_state(state_dict: Dict, prefix: str = ""):
logger.info(f"Analyzing state dict state...")
# analyze configs
patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
pos_embed_max_size = round(math.sqrt(num_patches))
adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys() else None
# x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1]))
x_block_self_attn_layers = []
re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight")
for key in list(state_dict.keys()):
m = re_attn.search(key)
if m:
x_block_self_attn_layers.append(int(m.group(1)))
context_embedder_in_features = context_shape[1]
context_embedder_out_features = context_shape[0]
# only supports 3-5-large, medium or 3-medium
if qk_norm is not None:
if len(x_block_self_attn_layers) == 0:
model_type = "3-5-large"
else:
model_type = "3-5-medium"
else:
try:
return load_file(path, device=dvc)
except:
return load_file(path) # prevent device invalid Error
model_type = "3-medium"
params = sd3_models.SD3Params(
patch_size=patch_size,
depth=depth,
num_patches=num_patches,
pos_embed_max_size=pos_embed_max_size,
adm_in_channels=adm_in_channels,
qk_norm=qk_norm,
x_block_self_attn_layers=x_block_self_attn_layers,
context_embedder_in_features=context_embedder_in_features,
context_embedder_out_features=context_embedder_out_features,
model_type=model_type,
)
logger.info(f"Analyzed state dict state: {params}")
return params
def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device]):
def load_mmdit(
state_dict: Dict, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], attn_mode: str = "torch"
) -> sd3_models.MMDiT:
mmdit_sd = {}
mmdit_prefix = "model.diffusion_model."
@@ -40,30 +87,25 @@ def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torc
# load MMDiT
logger.info("Building MMDit")
params = analyze_state_dict_state(mmdit_sd)
with init_empty_weights():
mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)
mmdit = sd3_models.create_sd3_mmdit(params, attn_mode)
logger.info("Loading state dict...")
info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype)
info = mmdit.load_state_dict(mmdit_sd, strict=False, assign=True)
logger.info(f"Loaded MMDiT: {info}")
return mmdit
def load_clip_l(
state_dict: Dict,
clip_l_path: Optional[str],
attn_mode: str,
clip_dtype: Optional[Union[str, torch.dtype]],
dtype: Optional[Union[str, torch.dtype]],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
clip_l_sd = None
if clip_l_path:
logger.info(f"Loading clip_l from {clip_l_path}...")
clip_l_sd = load_safetensors(clip_l_path, device, disable_mmap)
for key in list(clip_l_sd.keys()):
clip_l_sd["transformer." + key] = clip_l_sd.pop(key)
else:
if clip_l_path is None:
if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_l: remove prefix "text_encoders.clip_l."
logger.info("clip_l is included in the checkpoint")
@@ -72,34 +114,58 @@ def load_clip_l(
for k in list(state_dict.keys()):
if k.startswith(prefix):
clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
elif clip_l_path is None:
logger.info("clip_l is not included in the checkpoint and clip_l_path is not provided")
return None
# load clip_l
logger.info("Building CLIP-L")
config = CLIPTextConfig(
vocab_size=49408,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
max_position_embeddings=77,
hidden_act="quick_gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=768,
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
with init_empty_weights():
clip = CLIPTextModelWithProjection(config)
if clip_l_sd is None:
clip_l = None
else:
logger.info("Building ClipL")
clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd)
logger.info("Loading state dict...")
info = clip_l.load_state_dict(clip_l_sd)
logger.info(f"Loaded ClipL: {info}")
clip_l.set_attn_mode(attn_mode)
return clip_l
logger.info(f"Loading state dict from {clip_l_path}")
clip_l_sd = load_safetensors(clip_l_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
if "text_projection.weight" not in clip_l_sd:
logger.info("Adding text_projection.weight to clip_l_sd")
clip_l_sd["text_projection.weight"] = torch.eye(768, dtype=dtype, device=device)
info = clip.load_state_dict(clip_l_sd, strict=False, assign=True)
logger.info(f"Loaded CLIP-L: {info}")
return clip
def load_clip_g(
state_dict: Dict,
clip_g_path: Optional[str],
attn_mode: str,
clip_dtype: Optional[Union[str, torch.dtype]],
dtype: Optional[Union[str, torch.dtype]],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
clip_g_sd = None
if clip_g_path:
logger.info(f"Loading clip_g from {clip_g_path}...")
clip_g_sd = load_safetensors(clip_g_path, device, disable_mmap)
for key in list(clip_g_sd.keys()):
clip_g_sd["transformer." + key] = clip_g_sd.pop(key)
else:
if state_dict is not None:
if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_g: remove prefix "text_encoders.clip_g."
logger.info("clip_g is included in the checkpoint")
@@ -108,34 +174,53 @@ def load_clip_g(
for k in list(state_dict.keys()):
if k.startswith(prefix):
clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
elif clip_g_path is None:
logger.info("clip_g is not included in the checkpoint and clip_g_path is not provided")
return None
# load clip_g
logger.info("Building CLIP-G")
config = CLIPTextConfig(
vocab_size=49408,
hidden_size=1280,
intermediate_size=5120,
num_hidden_layers=32,
num_attention_heads=20,
max_position_embeddings=77,
hidden_act="gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=1280,
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
with init_empty_weights():
clip = CLIPTextModelWithProjection(config)
if clip_g_sd is None:
clip_g = None
else:
logger.info("Building ClipG")
clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd)
logger.info("Loading state dict...")
info = clip_g.load_state_dict(clip_g_sd)
logger.info(f"Loaded ClipG: {info}")
clip_g.set_attn_mode(attn_mode)
return clip_g
logger.info(f"Loading state dict from {clip_g_path}")
clip_g_sd = load_safetensors(clip_g_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = clip.load_state_dict(clip_g_sd, strict=False, assign=True)
logger.info(f"Loaded CLIP-G: {info}")
return clip
def load_t5xxl(
state_dict: Dict,
t5xxl_path: Optional[str],
attn_mode: str,
dtype: Optional[Union[str, torch.dtype]],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
t5xxl_sd = None
if t5xxl_path:
logger.info(f"Loading t5xxl from {t5xxl_path}...")
t5xxl_sd = load_safetensors(t5xxl_path, device, disable_mmap)
for key in list(t5xxl_sd.keys()):
t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key)
else:
if state_dict is not None:
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
# found t5xxl: remove prefix "text_encoders.t5xxl."
logger.info("t5xxl is included in the checkpoint")
@@ -144,29 +229,19 @@ def load_t5xxl(
for k in list(state_dict.keys()):
if k.startswith(prefix):
t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
elif t5xxl_path is None:
logger.info("t5xxl is not included in the checkpoint and t5xxl_path is not provided")
return None
if t5xxl_sd is None:
t5xxl = None
else:
logger.info("Building T5XXL")
# workaround for T5XXL model creation: create with fp16 takes too long TODO support virtual device
t5xxl = sd3_models.create_t5xxl(device, torch.float32, t5xxl_sd)
t5xxl.to(dtype=dtype)
logger.info("Loading state dict...")
info = t5xxl.load_state_dict(t5xxl_sd)
logger.info(f"Loaded T5XXL: {info}")
t5xxl.set_attn_mode(attn_mode)
return t5xxl
return flux_utils_load_t5xxl(t5xxl_path, dtype, device, disable_mmap, state_dict=t5xxl_sd)
def load_vae(
state_dict: Dict,
vae_path: Optional[str],
vae_dtype: Optional[Union[str, torch.dtype]],
device: Optional[Union[str, torch.device]],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
vae_sd = {}
if vae_path:
@@ -181,299 +256,15 @@ def load_vae(
vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
logger.info("Building VAE")
vae = sd3_models.SDVAE()
vae = sd3_models.SDVAE(vae_dtype, device)
logger.info("Loading state dict...")
info = vae.load_state_dict(vae_sd)
logger.info(f"Loaded VAE: {info}")
vae.to(device=device, dtype=vae_dtype)
vae.to(device=device, dtype=vae_dtype) # make sure it's in the right device and dtype
return vae
def load_models(
ckpt_path: str,
clip_l_path: str,
clip_g_path: str,
t5xxl_path: str,
vae_path: str,
attn_mode: str,
device: Union[str, torch.device],
weight_dtype: Optional[Union[str, torch.dtype]] = None,
disable_mmap: bool = False,
clip_dtype: Optional[Union[str, torch.dtype]] = None,
t5xxl_device: Optional[Union[str, torch.device]] = None,
t5xxl_dtype: Optional[Union[str, torch.dtype]] = None,
vae_dtype: Optional[Union[str, torch.dtype]] = None,
):
"""
Load SD3 models from checkpoint files.
Args:
ckpt_path: Path to the SD3 checkpoint file.
clip_l_path: Path to the clip_l checkpoint file.
clip_g_path: Path to the clip_g checkpoint file.
t5xxl_path: Path to the t5xxl checkpoint file.
vae_path: Path to the VAE checkpoint file.
attn_mode: Attention mode for MMDiT model.
device: Device for MMDiT model.
weight_dtype: Default dtype of weights for all models. This is weight dtype, so the model dtype may be different.
disable_mmap: Disable memory mapping when loading state dict.
clip_dtype: Dtype for Clip models, or None to use default dtype.
t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device.
t5xxl_dtype: Dtype for T5XXL model, or None to use default dtype.
vae_dtype: Dtype for VAE model, or None to use default dtype.
Returns:
Tuple of MMDiT, ClipL, ClipG, T5XXL, and VAE models.
"""
# In SD1/2 and SDXL, the model is created with empty weights and then loaded with state dict.
# However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict.
# Therefore, we need clip_dtype and t5xxl_dtype.
def load_state_dict(path: str, dvc: Union[str, torch.device] = device):
if disable_mmap:
return safetensors.torch.load(open(path, "rb").read())
else:
try:
return load_file(path, device=dvc)
except:
return load_file(path) # prevent device invalid Error
t5xxl_device = t5xxl_device or device
clip_dtype = clip_dtype or weight_dtype or torch.float32
t5xxl_dtype = t5xxl_dtype or weight_dtype or torch.float32
vae_dtype = vae_dtype or weight_dtype or torch.float32
logger.info(f"Loading SD3 models from {ckpt_path}...")
state_dict = load_state_dict(ckpt_path)
# load clip_l
clip_l_sd = None
if clip_l_path:
logger.info(f"Loading clip_l from {clip_l_path}...")
clip_l_sd = load_state_dict(clip_l_path)
for key in list(clip_l_sd.keys()):
clip_l_sd["transformer." + key] = clip_l_sd.pop(key)
else:
if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_l: remove prefix "text_encoders.clip_l."
logger.info("clip_l is included in the checkpoint")
clip_l_sd = {}
prefix = "text_encoders.clip_l."
for k in list(state_dict.keys()):
if k.startswith(prefix):
clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
# load clip_g
clip_g_sd = None
if clip_g_path:
logger.info(f"Loading clip_g from {clip_g_path}...")
clip_g_sd = load_state_dict(clip_g_path)
for key in list(clip_g_sd.keys()):
clip_g_sd["transformer." + key] = clip_g_sd.pop(key)
else:
if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_g: remove prefix "text_encoders.clip_g."
logger.info("clip_g is included in the checkpoint")
clip_g_sd = {}
prefix = "text_encoders.clip_g."
for k in list(state_dict.keys()):
if k.startswith(prefix):
clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
# load t5xxl
t5xxl_sd = None
if t5xxl_path:
logger.info(f"Loading t5xxl from {t5xxl_path}...")
t5xxl_sd = load_state_dict(t5xxl_path, t5xxl_device)
for key in list(t5xxl_sd.keys()):
t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key)
else:
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
# found t5xxl: remove prefix "text_encoders.t5xxl."
logger.info("t5xxl is included in the checkpoint")
t5xxl_sd = {}
prefix = "text_encoders.t5xxl."
for k in list(state_dict.keys()):
if k.startswith(prefix):
t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
# MMDiT and VAE
vae_sd = {}
if vae_path:
logger.info(f"Loading VAE from {vae_path}...")
vae_sd = load_state_dict(vae_path)
else:
# remove prefix "first_stage_model."
vae_sd = {}
vae_prefix = "first_stage_model."
for k in list(state_dict.keys()):
if k.startswith(vae_prefix):
vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
mmdit_prefix = "model.diffusion_model."
for k in list(state_dict.keys()):
if k.startswith(mmdit_prefix):
state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k)
else:
state_dict.pop(k) # remove other keys
# load MMDiT
logger.info("Building MMDit")
with init_empty_weights():
mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)
logger.info("Loading state dict...")
info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype)
logger.info(f"Loaded MMDiT: {info}")
# load ClipG and ClipL
if clip_l_sd is None:
clip_l = None
else:
logger.info("Building ClipL")
clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd)
logger.info("Loading state dict...")
info = clip_l.load_state_dict(clip_l_sd)
logger.info(f"Loaded ClipL: {info}")
clip_l.set_attn_mode(attn_mode)
if clip_g_sd is None:
clip_g = None
else:
logger.info("Building ClipG")
clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd)
logger.info("Loading state dict...")
info = clip_g.load_state_dict(clip_g_sd)
logger.info(f"Loaded ClipG: {info}")
clip_g.set_attn_mode(attn_mode)
# load T5XXL
if t5xxl_sd is None:
t5xxl = None
else:
logger.info("Building T5XXL")
t5xxl = sd3_models.create_t5xxl(t5xxl_device, t5xxl_dtype, t5xxl_sd)
logger.info("Loading state dict...")
info = t5xxl.load_state_dict(t5xxl_sd)
logger.info(f"Loaded T5XXL: {info}")
t5xxl.set_attn_mode(attn_mode)
# load VAE
logger.info("Building VAE")
vae = sd3_models.SDVAE()
logger.info("Loading state dict...")
info = vae.load_state_dict(vae_sd)
logger.info(f"Loaded VAE: {info}")
vae.to(device=device, dtype=vae_dtype)
return mmdit, clip_l, clip_g, t5xxl, vae
# endregion
# region utils
def get_cond(
prompt: str,
tokenizer: sd3_models.SD3Tokenizer,
clip_l: sd3_models.SDClipModel,
clip_g: sd3_models.SDXLClipG,
t5xxl: Optional[sd3_models.T5XXLModel] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt)
print(t5_tokens)
return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype)
def get_cond_from_tokens(
l_tokens,
g_tokens,
t5_tokens,
clip_l: sd3_models.SDClipModel,
clip_g: sd3_models.SDXLClipG,
t5xxl: Optional[sd3_models.T5XXLModel] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
l_out, l_pooled = clip_l.encode_token_weights(l_tokens)
g_out, g_pooled = clip_g.encode_token_weights(g_tokens)
lg_out = torch.cat([l_out, g_out], dim=-1)
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
if device is not None:
lg_out = lg_out.to(device=device)
l_pooled = l_pooled.to(device=device)
g_pooled = g_pooled.to(device=device)
if dtype is not None:
lg_out = lg_out.to(dtype=dtype)
l_pooled = l_pooled.to(dtype=dtype)
g_pooled = g_pooled.to(dtype=dtype)
# t5xxl may be in another device (eg. cpu)
if t5_tokens is None:
t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype)
else:
t5_out, _ = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None
if device is not None:
t5_out = t5_out.to(device=device)
if dtype is not None:
t5_out = t5_out.to(dtype=dtype)
# return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1)
return lg_out, t5_out, torch.cat((l_pooled, g_pooled), dim=-1)
# used if other sd3 models is available
r"""
def get_sd3_configs(state_dict: Dict):
# Important configuration values can be quickly determined by checking shapes in the source file
# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
# prefix = "model.diffusion_model."
prefix = ""
patch_size = state_dict[prefix + "x_embedder.proj.weight"].shape[2]
depth = state_dict[prefix + "x_embedder.proj.weight"].shape[0] // 64
num_patches = state_dict[prefix + "pos_embed"].shape[1]
pos_embed_max_size = round(math.sqrt(num_patches))
adm_in_channels = state_dict[prefix + "y_embedder.mlp.0.weight"].shape[1]
context_shape = state_dict[prefix + "context_embedder.weight"].shape
context_embedder_config = {
"target": "torch.nn.Linear",
"params": {"in_features": context_shape[1], "out_features": context_shape[0]},
}
return {
"patch_size": patch_size,
"depth": depth,
"num_patches": num_patches,
"pos_embed_max_size": pos_embed_max_size,
"adm_in_channels": adm_in_channels,
"context_embedder": context_embedder_config,
}
def create_mmdit_from_sd3_checkpoint(state_dict: Dict, attn_mode: str = "xformers"):
""
Doesn't load state dict.
""
sd3_configs = get_sd3_configs(state_dict)
mmdit = sd3_models.MMDiT(
input_size=None,
pos_embed_max_size=sd3_configs["pos_embed_max_size"],
patch_size=sd3_configs["patch_size"],
in_channels=16,
adm_in_channels=sd3_configs["adm_in_channels"],
depth=sd3_configs["depth"],
mlp_ratio=4,
qk_norm=None,
num_patches=sd3_configs["num_patches"],
context_size=4096,
attn_mode=attn_mode,
)
return mmdit
"""
class ModelSamplingDiscreteFlow:
@@ -509,6 +300,3 @@ class ModelSamplingDiscreteFlow:
# assert max_denoise is False, "max_denoise not implemented"
# max_denoise is always True, I'm not sure why it's there
return sigma * noise + (1.0 - sigma) * latent_image
# endregion

View File

@@ -518,7 +518,7 @@ class LatentsCachingStrategy:
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
for SD/SDXL/SD3.0
for SD/SDXL
"""
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)

View File

@@ -190,6 +190,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
apply_t5_attn_mask=apply_t5_attn_mask_i,
)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
@@ -211,7 +212,7 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, True)
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
@@ -225,7 +226,7 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
vae_dtype = vae.dtype
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, True
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
if not train_util.HIGH_VRAM:

View File

@@ -1,9 +1,10 @@
import os
import glob
import random
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel
from library import sd3_utils, train_util
from library import sd3_models
@@ -48,45 +49,200 @@ class Sd3TokenizeStrategy(TokenizeStrategy):
class Sd3TextEncodingStrategy(TextEncodingStrategy):
def __init__(self) -> None:
pass
def __init__(
self,
apply_lg_attn_mask: Optional[bool] = None,
apply_t5_attn_mask: Optional[bool] = None,
l_dropout_rate: float = 0.0,
g_dropout_rate: float = 0.0,
t5_dropout_rate: float = 0.0,
) -> None:
"""
Args:
apply_t5_attn_mask: Default value for apply_t5_attn_mask.
"""
self.apply_lg_attn_mask = apply_lg_attn_mask
self.apply_t5_attn_mask = apply_t5_attn_mask
self.l_dropout_rate = l_dropout_rate
self.g_dropout_rate = g_dropout_rate
self.t5_dropout_rate = t5_dropout_rate
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
apply_lg_attn_mask: bool = False,
apply_t5_attn_mask: bool = False,
apply_lg_attn_mask: Optional[bool] = False,
apply_t5_attn_mask: Optional[bool] = False,
enable_dropout: bool = True,
) -> List[torch.Tensor]:
"""
returned embeddings are not masked
"""
clip_l, clip_g, t5xxl = models
clip_l: Optional[CLIPTextModel]
clip_g: Optional[CLIPTextModelWithProjection]
t5xxl: Optional[T5EncoderModel]
l_tokens, g_tokens, t5_tokens = tokens[:3]
l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] if len(tokens) > 3 else [None, None, None]
if l_tokens is None:
if apply_lg_attn_mask is None:
apply_lg_attn_mask = self.apply_lg_attn_mask
if apply_t5_attn_mask is None:
apply_t5_attn_mask = self.apply_t5_attn_mask
l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens
# dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings
if l_tokens is None or clip_l is None:
assert g_tokens is None, "g_tokens must be None if l_tokens is None"
lg_out = None
lg_pooled = None
l_attn_mask = None
g_attn_mask = None
else:
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
l_out, l_pooled = clip_l(l_tokens)
g_out, g_pooled = clip_g(g_tokens)
if apply_lg_attn_mask:
l_out = l_out * l_attn_mask.to(l_out.device).unsqueeze(-1)
g_out = g_out * g_attn_mask.to(g_out.device).unsqueeze(-1)
# drop some members of the batch: we do not call clip_l and clip_g for dropped members
batch_size, l_seq_len = l_tokens.shape
g_seq_len = g_tokens.shape[1]
non_drop_l_indices = []
non_drop_g_indices = []
for i in range(l_tokens.shape[0]):
drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
if not drop_l:
non_drop_l_indices.append(i)
if not drop_g:
non_drop_g_indices.append(i)
# filter out dropped members
if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size:
l_tokens = l_tokens[non_drop_l_indices]
l_attn_mask = l_attn_mask[non_drop_l_indices]
if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size:
g_tokens = g_tokens[non_drop_g_indices]
g_attn_mask = g_attn_mask[non_drop_g_indices]
# call clip_l for non-dropped members
if len(non_drop_l_indices) > 0:
nd_l_attn_mask = l_attn_mask.to(clip_l.device)
prompt_embeds = clip_l(
l_tokens.to(clip_l.device), nd_l_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
)
nd_l_pooled = prompt_embeds[0]
nd_l_out = prompt_embeds.hidden_states[-2]
if len(non_drop_g_indices) > 0:
nd_g_attn_mask = g_attn_mask.to(clip_g.device)
prompt_embeds = clip_g(
g_tokens.to(clip_g.device), nd_g_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
)
nd_g_pooled = prompt_embeds[0]
nd_g_out = prompt_embeds.hidden_states[-2]
# fill in the dropped members
if len(non_drop_l_indices) == batch_size:
l_pooled = nd_l_pooled
l_out = nd_l_out
else:
# model output is always float32 because of the models are wrapped with Accelerator
l_pooled = torch.zeros((batch_size, 768), device=clip_l.device, dtype=torch.float32)
l_out = torch.zeros((batch_size, l_seq_len, 768), device=clip_l.device, dtype=torch.float32)
l_attn_mask = torch.zeros((batch_size, l_seq_len), device=clip_l.device, dtype=l_attn_mask.dtype)
if len(non_drop_l_indices) > 0:
l_pooled[non_drop_l_indices] = nd_l_pooled
l_out[non_drop_l_indices] = nd_l_out
l_attn_mask[non_drop_l_indices] = nd_l_attn_mask
if len(non_drop_g_indices) == batch_size:
g_pooled = nd_g_pooled
g_out = nd_g_out
else:
g_pooled = torch.zeros((batch_size, 1280), device=clip_g.device, dtype=torch.float32)
g_out = torch.zeros((batch_size, g_seq_len, 1280), device=clip_g.device, dtype=torch.float32)
g_attn_mask = torch.zeros((batch_size, g_seq_len), device=clip_g.device, dtype=g_attn_mask.dtype)
if len(non_drop_g_indices) > 0:
g_pooled[non_drop_g_indices] = nd_g_pooled
g_out[non_drop_g_indices] = nd_g_out
g_attn_mask[non_drop_g_indices] = nd_g_attn_mask
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1)
lg_out = torch.cat([l_out, g_out], dim=-1)
if t5xxl is not None and t5_tokens is not None:
t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096]
if apply_t5_attn_mask:
t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
else:
if t5xxl is None or t5_tokens is None:
t5_out = None
t5_attn_mask = None
else:
# drop some members of the batch: we do not call t5xxl for dropped members
batch_size, t5_seq_len = t5_tokens.shape
non_drop_t5_indices = []
for i in range(t5_tokens.shape[0]):
drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
if not drop_t5:
non_drop_t5_indices.append(i)
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
return [lg_out, t5_out, lg_pooled]
# filter out dropped members
if len(non_drop_t5_indices) > 0 and len(non_drop_t5_indices) < batch_size:
t5_tokens = t5_tokens[non_drop_t5_indices]
t5_attn_mask = t5_attn_mask[non_drop_t5_indices]
# call t5xxl for non-dropped members
if len(non_drop_t5_indices) > 0:
nd_t5_attn_mask = t5_attn_mask.to(t5xxl.device)
nd_t5_out, _ = t5xxl(
t5_tokens.to(t5xxl.device),
nd_t5_attn_mask if apply_t5_attn_mask else None,
return_dict=False,
output_hidden_states=True,
)
# fill in the dropped members
if len(non_drop_t5_indices) == batch_size:
t5_out = nd_t5_out
else:
t5_out = torch.zeros((batch_size, t5_seq_len, 4096), device=t5xxl.device, dtype=torch.float32)
t5_attn_mask = torch.zeros((batch_size, t5_seq_len), device=t5xxl.device, dtype=t5_attn_mask.dtype)
if len(non_drop_t5_indices) > 0:
t5_out[non_drop_t5_indices] = nd_t5_out
t5_attn_mask[non_drop_t5_indices] = nd_t5_attn_mask
# masks are used for attention masking in transformer
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
def drop_cached_text_encoder_outputs(
self,
lg_out: torch.Tensor,
t5_out: torch.Tensor,
lg_pooled: torch.Tensor,
l_attn_mask: torch.Tensor,
g_attn_mask: torch.Tensor,
t5_attn_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# dropout: if enable_dropout is True, dropout is not applied. dropout means zeroing out embeddings
if lg_out is not None:
for i in range(lg_out.shape[0]):
drop_l = self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate
if drop_l:
lg_out[i, :, :768] = torch.zeros_like(lg_out[i, :, :768])
lg_pooled[i, :768] = torch.zeros_like(lg_pooled[i, :768])
if l_attn_mask is not None:
l_attn_mask[i] = torch.zeros_like(l_attn_mask[i])
drop_g = self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate
if drop_g:
lg_out[i, :, 768:] = torch.zeros_like(lg_out[i, :, 768:])
lg_pooled[i, 768:] = torch.zeros_like(lg_pooled[i, 768:])
if g_attn_mask is not None:
g_attn_mask[i] = torch.zeros_like(g_attn_mask[i])
if t5_out is not None:
for i in range(t5_out.shape[0]):
drop_t5 = self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate
if drop_t5:
t5_out[i] = torch.zeros_like(t5_out[i])
if t5_attn_mask is not None:
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
def concat_encodings(
self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
@@ -132,39 +288,38 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return False
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
return False
# t5xxl is optional
if "apply_lg_attn_mask" not in npz:
return False
if "t5_out" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def mask_lg_attn(self, lg_out: np.ndarray, l_attn_mask: np.ndarray, g_attn_mask: np.ndarray) -> np.ndarray:
l_out = lg_out[..., :768]
g_out = lg_out[..., 768:] # 1280
l_out = l_out * np.expand_dims(l_attn_mask, -1) # l_out = l_out * l_attn_mask.
g_out = g_out * np.expand_dims(g_attn_mask, -1) # g_out = g_out * g_attn_mask.
return np.concatenate([l_out, g_out], axis=-1)
def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray:
return t5_out * np.expand_dims(t5_attn_mask, -1)
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
lg_out = data["lg_out"]
lg_pooled = data["lg_pooled"]
t5_out = data["t5_out"] if "t5_out" in data else None
t5_out = data["t5_out"]
if self.apply_lg_attn_mask:
l_attn_mask = data["clip_l_attn_mask"]
g_attn_mask = data["clip_g_attn_mask"]
lg_out = self.mask_lg_attn(lg_out, l_attn_mask, g_attn_mask)
l_attn_mask = data["clip_l_attn_mask"]
g_attn_mask = data["clip_g_attn_mask"]
t5_attn_mask = data["t5_attn_mask"]
if self.apply_t5_attn_mask and t5_out is not None:
t5_attn_mask = data["t5_attn_mask"]
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)
return [lg_out, t5_out, lg_pooled]
# apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
@@ -174,46 +329,56 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
lg_out, t5_out, lg_pooled = sd3_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask
# always disable dropout during caching
lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens(
tokenize_strategy,
models,
tokens_and_masks,
apply_lg_attn_mask=self.apply_lg_attn_mask,
apply_t5_attn_mask=self.apply_t5_attn_mask,
enable_dropout=False,
)
if lg_out.dtype == torch.bfloat16:
lg_out = lg_out.float()
if lg_pooled.dtype == torch.bfloat16:
lg_pooled = lg_pooled.float()
if t5_out is not None and t5_out.dtype == torch.bfloat16:
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
lg_out = lg_out.cpu().numpy()
lg_pooled = lg_pooled.cpu().numpy()
if t5_out is not None:
t5_out = t5_out.cpu().numpy()
t5_out = t5_out.cpu().numpy()
l_attn_mask = tokens_and_masks[3].cpu().numpy()
g_attn_mask = tokens_and_masks[4].cpu().numpy()
t5_attn_mask = tokens_and_masks[5].cpu().numpy()
for i, info in enumerate(infos):
lg_out_i = lg_out[i]
t5_out_i = t5_out[i] if t5_out is not None else None
t5_out_i = t5_out[i]
lg_pooled_i = lg_pooled[i]
l_attn_mask_i = l_attn_mask[i]
g_attn_mask_i = g_attn_mask[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_lg_attn_mask = self.apply_lg_attn_mask
apply_t5_attn_mask = self.apply_t5_attn_mask
if self.cache_to_disk:
clip_l_attn_mask, clip_g_attn_mask, t5_attn_mask = tokens_and_masks[3:6]
clip_l_attn_mask_i = clip_l_attn_mask[i].cpu().numpy()
clip_g_attn_mask_i = clip_g_attn_mask[i].cpu().numpy()
t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() if t5_attn_mask is not None else None # shouldn't be None
kwargs = {}
if t5_out is not None:
kwargs["t5_out"] = t5_out_i
np.savez(
info.text_encoder_outputs_npz,
lg_out=lg_out_i,
lg_pooled=lg_pooled_i,
clip_l_attn_mask=clip_l_attn_mask_i,
clip_g_attn_mask=clip_g_attn_mask_i,
t5_out=t5_out_i,
clip_l_attn_mask=l_attn_mask_i,
clip_g_attn_mask=g_attn_mask_i,
t5_attn_mask=t5_attn_mask_i,
**kwargs,
apply_lg_attn_mask=apply_lg_attn_mask,
apply_t5_attn_mask=apply_t5_attn_mask,
)
else:
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i)
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
@@ -234,7 +399,12 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
@@ -242,45 +412,9 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)
if __name__ == "__main__":
# test code for Sd3TokenizeStrategy
# tokenizer = sd3_models.SD3Tokenizer()
strategy = Sd3TokenizeStrategy(256)
text = "hello world"
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
# print(l_tokens.shape)
print(l_tokens)
print(g_tokens)
print(t5_tokens)
texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens_2 = strategy.t5xxl(
texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
print(l_tokens_2)
print(g_tokens_2)
print(t5_tokens_2)
# compare
print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0]))
print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0]))
print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0]))
text = ",".join(["hello world! this is long text"] * 50)
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
print(l_tokens)
print(g_tokens)
print(t5_tokens)
print(f"model max length l: {strategy.clip_l.model_max_length}")
print(f"model max length g: {strategy.clip_g.model_max_length}")
print(f"model max length t5: {strategy.t5xxl.model_max_length}")

View File

@@ -1082,6 +1082,10 @@ class BaseDataset(torch.utils.data.Dataset):
info.image = info.image.result() # future to image
caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop)
# remove image from memory
for info in batch:
info.image = None
# define ThreadPoolExecutor to load images in parallel
max_workers = min(os.cpu_count(), len(image_infos))
max_workers = max(1, max_workers // num_processes) # consider multi-gpu
@@ -1397,7 +1401,17 @@ class BaseDataset(torch.utils.data.Dataset):
)
def get_image_size(self, image_path):
return imagesize.get(image_path)
# return imagesize.get(image_path)
image_size = imagesize.get(image_path)
if image_size[0] <= 0:
# imagesize doesn't work for some images, so use cv2
img = cv2.imread(image_path)
if img is not None:
image_size = (img.shape[1], img.shape[0])
else:
logger.warning(f"failed to get image size: {image_path}")
image_size = (0, 0)
return image_size
def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False):
img = load_image(image_path, alpha_mask)
@@ -1615,7 +1629,6 @@ class BaseDataset(torch.utils.data.Dataset):
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz(
image_info.text_encoder_outputs_npz
)
text_encoder_outputs = [torch.FloatTensor(x) for x in text_encoder_outputs]
else:
tokenization_required = True
text_encoder_outputs_list.append(text_encoder_outputs)
@@ -2511,6 +2524,9 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets:
dataset.verify_bucket_reso_steps(min_steps)
def get_resolutions(self) -> List[Tuple[int, int]]:
return [(dataset.width, dataset.height) for dataset in self.datasets]
def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
@@ -5963,6 +5979,37 @@ def line_to_prompt_dict(line: str) -> dict:
return prompt_dict
def load_prompts(prompt_file: str) -> List[Dict]:
# read prompts
if prompt_file.endswith(".txt"):
with open(prompt_file, "r", encoding="utf-8") as f:
lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
elif prompt_file.endswith(".toml"):
with open(prompt_file, "r", encoding="utf-8") as f:
data = toml.load(f)
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
elif prompt_file.endswith(".json"):
with open(prompt_file, "r", encoding="utf-8") as f:
prompts = json.load(f)
# preprocess prompts
for i in range(len(prompts)):
prompt_dict = prompts[i]
if isinstance(prompt_dict, str):
from library.train_util import line_to_prompt_dict
prompt_dict = line_to_prompt_dict(prompt_dict)
prompts[i] = prompt_dict
assert isinstance(prompt_dict, dict)
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
prompt_dict["enum"] = i
prompt_dict.pop("subset", None)
return prompts
def sample_images_common(
pipe_class,
accelerator: Accelerator,

View File

@@ -13,12 +13,16 @@ from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncest
import cv2
from PIL import Image
import numpy as np
from safetensors.torch import load_file
def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
# region Logging
def add_logging_arguments(parser):
parser.add_argument(
"--console_log_level",
@@ -85,6 +89,11 @@ def setup_logging(args=None, log_level=None, reset=False):
logger.info(msg_init)
# endregion
# region PyTorch utils
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
"""
Convert a string to a torch.dtype
@@ -304,6 +313,35 @@ class MemoryEfficientSafeOpen:
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
) -> dict[str, torch.Tensor]:
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
# use experimental loader
# logger.info(f"Loading without mmap (experimental)")
state_dict = {}
with MemoryEfficientSafeOpen(path) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
return state_dict
else:
try:
state_dict = load_file(path, device=device)
except:
state_dict = load_file(path) # prevent device invalid Error
if dtype is not None:
for key in state_dict.keys():
state_dict[key] = state_dict[key].to(dtype=dtype)
return state_dict
# endregion
# region Image utils
def pil_resize(image, size, interpolation=Image.LANCZOS):
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
@@ -323,9 +361,9 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
return resized_cv2
# endregion
# TODO make inf_utils.py
# region Gradual Latent hires fix

839
networks/lora_sd3.py Normal file
View File

@@ -0,0 +1,839 @@
# temporary minimum implementation of LoRA
# SD3 doesn't have Conv2d, so we ignore it
# TODO commonize with the original/SD3/FLUX implementation
# LoRA network module
# reference:
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from transformers import CLIPTextModelWithProjection, T5EncoderModel
import numpy as np
import torch
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from networks.lora_flux import LoRAModule, LoRAInfModule
from library import sd3_models
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae: sd3_models.SDVAE,
text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
mmdit,
neuron_dropout: Optional[float] = None,
**kwargs,
):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
network_alpha = 1.0
# extract dim/alpha for conv2d, and block dim
conv_dim = kwargs.get("conv_dim", None)
conv_alpha = kwargs.get("conv_alpha", None)
if conv_dim is not None:
conv_dim = int(conv_dim)
if conv_alpha is None:
conv_alpha = 1.0
else:
conv_alpha = float(conv_alpha)
# attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv
context_attn_dim = kwargs.get("context_attn_dim", None)
context_mlp_dim = kwargs.get("context_mlp_dim", None)
context_mod_dim = kwargs.get("context_mod_dim", None)
x_attn_dim = kwargs.get("x_attn_dim", None)
x_mlp_dim = kwargs.get("x_mlp_dim", None)
x_mod_dim = kwargs.get("x_mod_dim", None)
if context_attn_dim is not None:
context_attn_dim = int(context_attn_dim)
if context_mlp_dim is not None:
context_mlp_dim = int(context_mlp_dim)
if context_mod_dim is not None:
context_mod_dim = int(context_mod_dim)
if x_attn_dim is not None:
x_attn_dim = int(x_attn_dim)
if x_mlp_dim is not None:
x_mlp_dim = int(x_mlp_dim)
if x_mod_dim is not None:
x_mod_dim = int(x_mod_dim)
type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim]
if all([d is None for d in type_dims]):
type_dims = None
# emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear]
emb_dims = kwargs.get("emb_dims", None)
if emb_dims is not None:
emb_dims = emb_dims.strip()
if emb_dims.startswith("[") and emb_dims.endswith("]"):
emb_dims = emb_dims[1:-1]
emb_dims = [int(d) for d in emb_dims.split(",")] # is it better to use ast.literal_eval?
assert len(emb_dims) == 6, f"invalid emb_dims: {emb_dims}, must be 6 dimensions (context, t, x, y, final_mod, final_linear)"
# double/single train blocks
def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
"""
Parse a block selection string and return a list of booleans.
Args:
selection (str): A string specifying which blocks to select.
total_blocks (int): The total number of blocks available.
Returns:
List[bool]: A list of booleans indicating which blocks are selected.
"""
if selection == "all":
return [True] * total_blocks
if selection == "none" or selection == "":
return [False] * total_blocks
selected = [False] * total_blocks
ranges = selection.split(",")
for r in ranges:
if "-" in r:
start, end = map(str.strip, r.split("-"))
start = int(start)
end = int(end)
assert 0 <= start < total_blocks, f"invalid start index: {start}"
assert 0 <= end < total_blocks, f"invalid end index: {end}"
assert start <= end, f"invalid range: {start}-{end}"
for i in range(start, end + 1):
selected[i] = True
else:
index = int(r)
assert 0 <= index < total_blocks, f"invalid index: {index}"
selected[index] = True
return selected
train_block_indices = kwargs.get("train_block_indices", None)
if train_block_indices is not None:
train_block_indices = parse_block_selection(train_block_indices, 999) # 999 is a dummy number
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
if rank_dropout is not None:
rank_dropout = float(rank_dropout)
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
module_dropout = float(module_dropout)
# split qkv
split_qkv = kwargs.get("split_qkv", False)
if split_qkv is not None:
split_qkv = True if split_qkv == "True" else False
# train T5XXL
train_t5xxl = kwargs.get("train_t5xxl", False)
if train_t5xxl is not None:
train_t5xxl = True if train_t5xxl == "True" else False
# verbose
verbose = kwargs.get("verbose", False)
if verbose is not None:
verbose = True if verbose == "True" else False
# すごく引数が多いな ( ^ω^)・・・
network = LoRANetwork(
text_encoders,
mmdit,
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
dropout=neuron_dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
split_qkv=split_qkv,
train_t5xxl=train_t5xxl,
type_dims=type_dims,
emb_dims=emb_dims,
train_block_indices=train_block_indices,
verbose=verbose,
)
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
return network
# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, ae, text_encoders, mmdit, weights_sd=None, for_inference=False, **kwargs):
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# get dim/alpha mapping, and train t5xxl
modules_dim = {}
modules_alpha = {}
train_t5xxl = None
for key, value in weights_sd.items():
if "." not in key:
continue
lora_name = key.split(".")[0]
if "alpha" in key:
modules_alpha[lora_name] = value
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# logger.info(lora_name, value.size(), dim)
if train_t5xxl is None or train_t5xxl is False:
train_t5xxl = "lora_te3" in lora_name
if train_t5xxl is None:
train_t5xxl = False
split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined
module_class = LoRAInfModule if for_inference else LoRAModule
network = LoRANetwork(
text_encoders,
mmdit,
multiplier=multiplier,
modules_dim=modules_dim,
modules_alpha=modules_alpha,
module_class=module_class,
split_qkv=split_qkv,
train_t5xxl=train_t5xxl,
)
return network, weights_sd
class LoRANetwork(torch.nn.Module):
SD3_TARGET_REPLACE_MODULE = ["SingleDiTBlock"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"]
LORA_PREFIX_SD3 = "lora_unet" # make ComfyUI compatible
LORA_PREFIX_TEXT_ENCODER_CLIP_L = "lora_te1"
LORA_PREFIX_TEXT_ENCODER_CLIP_G = "lora_te2"
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible
def __init__(
self,
text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
unet: sd3_models.MMDiT,
multiplier: float = 1.0,
lora_dim: int = 4,
alpha: float = 1,
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
module_class: Type[object] = LoRAModule,
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
split_qkv: bool = False,
train_t5xxl: bool = False,
type_dims: Optional[List[int]] = None,
emb_dims: Optional[List[int]] = None,
train_block_indices: Optional[List[bool]] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
self.multiplier = multiplier
self.lora_dim = lora_dim
self.alpha = alpha
self.conv_lora_dim = conv_lora_dim
self.conv_alpha = conv_alpha
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.split_qkv = split_qkv
self.train_t5xxl = train_t5xxl
self.type_dims = type_dims
self.emb_dims = emb_dims
self.train_block_indices = train_block_indices
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
logger.info(f"create LoRA network from weights")
self.emb_dims = [0] * 6 # create emb_dims
# verbose = True
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
# if self.conv_lora_dim is not None:
# logger.info(
# f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
# )
qkv_dim = 0
if self.split_qkv:
logger.info(f"split qkv for LoRA")
qkv_dim = unet.joint_blocks[0].context_block.attn.qkv.weight.size(0)
if train_t5xxl:
logger.info(f"train T5XXL as well")
# create module instances
def create_modules(
is_mmdit: bool,
text_encoder_idx: Optional[int],
root_module: torch.nn.Module,
target_replace_modules: List[str],
filter: Optional[str] = None,
default_dim: Optional[int] = None,
include_conv2d_if_filter: bool = False,
) -> List[LoRAModule]:
prefix = (
self.LORA_PREFIX_SD3
if is_mmdit
else [self.LORA_PREFIX_TEXT_ENCODER_CLIP_L, self.LORA_PREFIX_TEXT_ENCODER_CLIP_G, self.LORA_PREFIX_TEXT_ENCODER_T5][
text_encoder_idx
]
)
loras = []
skipped = []
for name, module in root_module.named_modules():
if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
if target_replace_modules is None: # dirty hack for all modules
module = root_module # search all modules
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
lora_name = prefix + "." + (name + "." if name else "") + child_name
lora_name = lora_name.replace(".", "_")
force_incl_conv2d = False
if filter is not None:
if not filter in lora_name:
continue
force_incl_conv2d = include_conv2d_if_filter
dim = None
alpha = None
if modules_dim is not None:
# モジュール指定あり
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
else:
# 通常、すべて対象とする
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha
if is_mmdit and type_dims is not None:
# type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim]
identifier = [
("context_block", "attn"),
("context_block", "mlp"),
("context_block", "adaLN_modulation"),
("x_block", "attn"),
("x_block", "mlp"),
("x_block", "adaLN_modulation"),
]
for i, d in enumerate(type_dims):
if d is not None and all([id in lora_name for id in identifier[i]]):
dim = d # may be 0 for skip
break
if is_mmdit and dim and self.train_block_indices is not None and "joint_blocks" in lora_name:
# "lora_unet_joint_blocks_0_x_block_attn_proj..."
block_index = int(lora_name.split("_")[4]) # bit dirty
if self.train_block_indices is not None and not self.train_block_indices[block_index]:
dim = 0
elif self.conv_lora_dim is not None:
dim = self.conv_lora_dim
alpha = self.conv_alpha
elif force_incl_conv2d:
# x_embedder
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha
if dim is None or dim == 0:
# skipした情報を出力
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None):
skipped.append(lora_name)
continue
# qkv split
split_dims = None
if is_mmdit and split_qkv:
if "joint_blocks" in lora_name and "qkv" in lora_name:
split_dims = [qkv_dim // 3] * 3
lora = module_class(
lora_name,
child_module,
self.multiplier,
dim,
alpha,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
split_dims=split_dims,
)
loras.append(lora)
if target_replace_modules is None:
break # all modules are searched
return loras, skipped
# create LoRA for text encoder
# 毎回すべてのモジュールを作るのは無駄なので要検討
self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
skipped_te = []
for i, text_encoder in enumerate(text_encoders):
index = i
if not train_t5xxl and index >= 2: # 0: CLIP-L, 1: CLIP-G, 2: T5XXL, so we skip T5XXL if train_t5xxl is False
break
logger.info(f"create LoRA for Text Encoder {index+1}:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
# create LoRA for U-Net
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.SD3_TARGET_REPLACE_MODULE)
# emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear]
if self.emb_dims:
for filter, in_dim in zip(
[
"context_embedder",
"_t_embedder", # don't use "t_embedder" because it's used in "context_embedder"
"x_embedder",
"y_embedder",
"final_layer_adaLN_modulation",
"final_layer_linear",
],
self.emb_dims,
):
# x_embedder is conv2d, so we need to include it
loras, _ = create_modules(
True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder"
)
# if len(loras) > 0:
# logger.info(f"create LoRA for {filter}: {len(loras)} modules.")
self.unet_loras.extend(loras)
logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.")
if verbose:
for lora in self.unet_loras:
logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")
skipped = skipped_te + skipped_un
if verbose and len(skipped) > 0:
logger.warning(
f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
)
for name in skipped:
logger.info(f"\t{name}")
# assertion
names = set()
for lora in self.text_encoder_loras + self.unet_loras:
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def set_enabled(self, is_enabled):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def load_state_dict(self, state_dict, strict=True):
# override to convert original weight to split qkv
if not self.split_qkv:
return super().load_state_dict(state_dict, strict)
# split qkv
for key in list(state_dict.keys()):
if not ("joint_blocks" in key and "qkv" in key):
continue
weight = state_dict[key]
lora_name = key.split(".")[0]
if "lora_down" in key and "weight" in key:
# dense weight (rank*3, in_dim)
split_weight = torch.chunk(weight, 3, dim=0)
for i, split_w in enumerate(split_weight):
state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w
del state_dict[key]
# print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}")
elif "lora_up" in key and "weight" in key:
# sparse weight (out_dim=sum(split_dims), rank*3)
rank = weight.size(1) // 3
i = 0
split_dim = weight.shape[0] // 3
for j in range(3):
state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dim, j * rank : (j + 1) * rank]
i += split_dim
del state_dict[key]
# alpha is unchanged
return super().load_state_dict(state_dict, strict)
def state_dict(self, destination=None, prefix="", keep_vars=False):
if not self.split_qkv:
return super().state_dict(destination, prefix, keep_vars)
# merge qkv
state_dict = super().state_dict(destination, prefix, keep_vars)
new_state_dict = {}
for key in list(state_dict.keys()):
if not ("joint_blocks" in key and "qkv" in key):
new_state_dict[key] = state_dict[key]
continue
if key not in state_dict:
continue # already merged
lora_name = key.split(".")[0]
# (rank, in_dim) * 3
down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(3)]
# (split dim, rank) * 3
up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(3)]
alpha = state_dict.pop(f"{lora_name}.alpha")
# merge down weight
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
# merge up weight (sum of split_dim, rank*3)
split_dim, rank = up_weights[0].size()
qkv_dim = split_dim * 3
up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
i = 0
for j in range(3):
up_weight[i : i + split_dim, j * rank : (j + 1) * rank] = up_weights[j]
i += split_dim
new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
new_state_dict[f"{lora_name}.alpha"] = alpha
# print(
# f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
# )
print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha")
return new_state_dict
def apply_to(self, text_encoders, mmdit, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
lora.apply_to()
self.add_module(lora.lora_name, lora)
# マージできるかどうかを返す
def is_mergeable(self):
return True
# TODO refactor to common function with apply_to
def merge_to(self, text_encoders, mmdit, weights_sd, dtype=None, device=None):
apply_text_encoder = apply_unet = False
for key in weights_sd.keys():
if (
key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_L)
or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_G)
or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5)
):
apply_text_encoder = True
elif key.startswith(LoRANetwork.LORA_PREFIX_MMDIT):
apply_unet = True
if apply_text_encoder:
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
logger.info(f"weights are merged")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
# make sure text_encoder_lr as list of three elements
# if float, use the same value for all three
if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
text_encoder_lr = [default_lr, default_lr, default_lr]
elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr), float(text_encoder_lr)]
elif len(text_encoder_lr) == 1:
text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0], text_encoder_lr[0]]
elif len(text_encoder_lr) == 2:
text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[1], text_encoder_lr[1]]
self.requires_grad_(True)
all_params = []
lr_descriptions = []
def assemble_params(loras, lr, loraplus_ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
for name, param in lora.named_parameters():
if loraplus_ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
params = []
descriptions = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0:
continue
if lr is not None:
if key == "plus":
param_data["lr"] = lr * loraplus_ratio
else:
param_data["lr"] = lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
logger.info("NO LR skipping!")
continue
params.append(param_data)
descriptions.append("plus" if key == "plus" else "")
return params, descriptions
if self.text_encoder_loras:
loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
# split text encoder loras for te1 and te3
te1_loras = [
lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_L)
]
te2_loras = [
lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_G)
]
te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)]
if len(te1_loras) > 0:
logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio)
all_params.extend(params)
lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions])
if len(te2_loras) > 0:
logger.info(f"Text Encoder 2 (CLIP-G): {len(te2_loras)} modules, LR {text_encoder_lr[1]}")
params, descriptions = assemble_params(te2_loras, text_encoder_lr[1], loraplus_lr_ratio)
all_params.extend(params)
lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions])
if len(te3_loras) > 0:
logger.info(f"Text Encoder 3 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[2]}")
params, descriptions = assemble_params(te3_loras, text_encoder_lr[2], loraplus_lr_ratio)
all_params.extend(params)
lr_descriptions.extend(["textencoder 3 " + (" " + d if d else "") for d in descriptions])
if self.unet_loras:
params, descriptions = assemble_params(
self.unet_loras,
unet_lr if unet_lr is not None else default_lr,
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
return all_params, lr_descriptions
def enable_gradient_checkpointing(self):
# not supported
pass
def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)
def on_epoch_start(self, text_encoder, unet):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
from library import train_util
# Precalculate model hashes to save time on indexing
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def backup_weights(self):
# 重みのバックアップを行う
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not hasattr(org_module, "_lora_org_weight"):
sd = org_module.state_dict()
org_module._lora_org_weight = sd["weight"].detach().clone()
org_module._lora_restored = True
def restore_weights(self):
# 重みのリストアを行う
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not org_module._lora_restored:
sd = org_module.state_dict()
sd["weight"] = org_module._lora_org_weight
org_module.load_state_dict(sd)
org_module._lora_restored = True
def pre_calculation(self):
# 事前計算を行う
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
sd = org_module.state_dict()
org_weight = sd["weight"]
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
sd["weight"] = org_weight + lora_weight
assert sd["weight"].shape == org_weight.shape
org_module.load_state_dict(sd)
org_module._lora_restored = False
lora.enabled = False
def apply_max_norm_regularization(self, max_norm_value, device):
downkeys = []
upkeys = []
alphakeys = []
norms = []
keys_scaled = 0
state_dict = self.state_dict()
for key in state_dict.keys():
if "lora_down" in key and "weight" in key:
downkeys.append(key)
upkeys.append(key.replace("lora_down", "lora_up"))
alphakeys.append(key.replace("lora_down.weight", "alpha"))
for i in range(len(downkeys)):
down = state_dict[downkeys[i]].to(device)
up = state_dict[upkeys[i]].to(device)
alpha = state_dict[alphakeys[i]].to(device)
dim = down.shape[0]
scale = alpha / dim
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
else:
updown = up @ down
updown *= scale
norm = updown.norm().clamp(min=max_norm_value / 2)
desired = torch.clamp(norm, max=max_norm_value)
ratio = desired.cpu() / norm.cpu()
sqrt_ratio = ratio**0.5
if ratio != 1:
keys_scaled += 1
state_dict[upkeys[i]] *= sqrt_ratio
state_dict[downkeys[i]] *= sqrt_ratio
scalednorm = updown.norm() * ratio
norms.append(scalednorm.item())
return keys_scaled, sum(norms) / len(norms), max(norms)

View File

@@ -10,10 +10,13 @@ import numpy as np
import torch
from safetensors.torch import safe_open, load_file
import torch.amp
from tqdm import tqdm
from PIL import Image
from transformers import CLIPTextModelWithProjection, T5EncoderModel
from library.device_utils import init_ipex, get_preferred_device
from networks import lora_sd3
init_ipex()
@@ -25,11 +28,14 @@ import logging
logger = logging.getLogger(__name__)
from library import sd3_models, sd3_utils, strategy_sd3
from library.utils import load_safetensors
def get_noise(seed, latent):
generator = torch.manual_seed(seed)
return torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu").to(latent.dtype)
def get_noise(seed, latent, device="cpu"):
# generator = torch.manual_seed(seed)
generator = torch.Generator(device)
generator.manual_seed(seed)
return torch.randn(latent.size(), dtype=latent.dtype, layout=latent.layout, generator=generator, device=device)
def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
@@ -59,7 +65,7 @@ def do_sample(
neg_cond: Tuple[torch.Tensor, torch.Tensor],
mmdit: sd3_models.MMDiT,
steps: int,
guidance_scale: float,
cfg_scale: float,
dtype: torch.dtype,
device: str,
):
@@ -71,7 +77,7 @@ def do_sample(
latent = latent.to(dtype).to(device)
noise = get_noise(seed, latent).to(device)
noise = get_noise(seed, latent, device)
model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3
@@ -100,12 +106,13 @@ def do_sample(
x_c_nc = torch.cat([x, x], dim=0)
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
with torch.autocast(device_type=device.type, dtype=dtype):
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
model_output = model_output.float()
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
pos_out, neg_out = batched.chunk(2)
denoised = neg_out + (pos_out - neg_out) * guidance_scale
denoised = neg_out + (pos_out - neg_out) * cfg_scale
# print(denoised.shape)
# d = to_d(x, sigma_hat, denoised)
@@ -122,230 +129,68 @@ def do_sample(
x = x.to(dtype)
latent = x
scale_factor = 1.5305
shift_factor = 0.0609
# def process_out(self, latent):
# return (latent / self.scale_factor) + self.shift_factor
latent = (latent / scale_factor) + shift_factor
latent = vae.process_out(latent)
return latent
if __name__ == "__main__":
target_height = 1024
target_width = 1024
# steps = 50 # 28 # 50
guidance_scale = 5
# seed = 1 # None # 1
device = get_preferred_device()
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--clip_g", type=str, required=False)
parser.add_argument("--clip_l", type=str, required=False)
parser.add_argument("--t5xxl", type=str, required=False)
parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77")
parser.add_argument("--apply_lg_attn_mask", action="store_true")
parser.add_argument("--apply_t5_attn_mask", action="store_true")
parser.add_argument("--prompt", type=str, default="A photo of a cat")
# parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--output_dir", type=str, default=".")
parser.add_argument("--do_not_use_t5xxl", action="store_true")
parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--steps", type=int, default=50)
# parser.add_argument(
# "--lora_weights",
# type=str,
# nargs="*",
# default=[],
# help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
# )
# parser.add_argument("--interactive", action="store_true")
args = parser.parse_args()
seed = args.seed
steps = args.steps
sd3_dtype = torch.float32
if args.fp16:
sd3_dtype = torch.float16
elif args.bf16:
sd3_dtype = torch.bfloat16
# TODO test with separated safetenors files for each model
# load state dict
logger.info(f"Loading SD3 models from {args.ckpt_path}...")
state_dict = load_file(args.ckpt_path)
if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_g: remove prefix "text_encoders.clip_g."
logger.info("clip_g is included in the checkpoint")
clip_g_sd = {}
prefix = "text_encoders.clip_g."
for k, v in list(state_dict.items()):
if k.startswith(prefix):
clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
else:
logger.info(f"Lodaing clip_g from {args.clip_g}...")
clip_g_sd = load_file(args.clip_g)
for key in list(clip_g_sd.keys()):
clip_g_sd["transformer." + key] = clip_g_sd.pop(key)
if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_l: remove prefix "text_encoders.clip_l."
logger.info("clip_l is included in the checkpoint")
clip_l_sd = {}
prefix = "text_encoders.clip_l."
for k, v in list(state_dict.items()):
if k.startswith(prefix):
clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
else:
logger.info(f"Lodaing clip_l from {args.clip_l}...")
clip_l_sd = load_file(args.clip_l)
for key in list(clip_l_sd.keys()):
clip_l_sd["transformer." + key] = clip_l_sd.pop(key)
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
# found t5xxl: remove prefix "text_encoders.t5xxl."
logger.info("t5xxl is included in the checkpoint")
if not args.do_not_use_t5xxl:
t5xxl_sd = {}
prefix = "text_encoders.t5xxl."
for k, v in list(state_dict.items()):
if k.startswith(prefix):
t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
else:
logger.info("but not used")
for key in list(state_dict.keys()):
if key.startswith("text_encoders.t5xxl."):
state_dict.pop(key)
t5xxl_sd = None
elif args.t5xxl:
assert not args.do_not_use_t5xxl, "t5xxl is not used but specified"
logger.info(f"Lodaing t5xxl from {args.t5xxl}...")
t5xxl_sd = load_file(args.t5xxl)
for key in list(t5xxl_sd.keys()):
t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key)
else:
logger.info("t5xxl is not used")
t5xxl_sd = None
use_t5xxl = t5xxl_sd is not None
# MMDiT and VAE
vae_sd = {}
vae_prefix = "first_stage_model."
mmdit_prefix = "model.diffusion_model."
for k, v in list(state_dict.items()):
if k.startswith(vae_prefix):
vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
elif k.startswith(mmdit_prefix):
state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k)
# load tokenizers
logger.info("Loading tokenizers...")
tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)
# load models
# logger.info("Create MMDiT from SD3 checkpoint...")
# mmdit = sd3_utils.create_mmdit_from_sd3_checkpoint(state_dict)
logger.info("Create MMDiT")
mmdit = sd3_models.create_mmdit_sd3_medium_configs(args.attn_mode)
logger.info("Loading state dict...")
info = mmdit.load_state_dict(state_dict)
logger.info(f"Loaded MMDiT: {info}")
logger.info(f"Move MMDiT to {device} and {sd3_dtype}...")
mmdit.to(device, dtype=sd3_dtype)
mmdit.eval()
# load VAE
logger.info("Create VAE")
vae = sd3_models.SDVAE()
logger.info("Loading state dict...")
info = vae.load_state_dict(vae_sd)
logger.info(f"Loaded VAE: {info}")
logger.info(f"Move VAE to {device} and {sd3_dtype}...")
vae.to(device, dtype=sd3_dtype)
vae.eval()
# load text encoders
logger.info("Create clip_l")
clip_l = sd3_models.create_clip_l(device, sd3_dtype, clip_l_sd)
logger.info("Loading state dict...")
info = clip_l.load_state_dict(clip_l_sd)
logger.info(f"Loaded clip_l: {info}")
logger.info(f"Move clip_l to {device} and {sd3_dtype}...")
clip_l.to(device, dtype=sd3_dtype)
clip_l.eval()
logger.info(f"Set attn_mode to {args.attn_mode}...")
clip_l.set_attn_mode(args.attn_mode)
logger.info("Create clip_g")
clip_g = sd3_models.create_clip_g(device, sd3_dtype, clip_g_sd)
logger.info("Loading state dict...")
info = clip_g.load_state_dict(clip_g_sd)
logger.info(f"Loaded clip_g: {info}")
logger.info(f"Move clip_g to {device} and {sd3_dtype}...")
clip_g.to(device, dtype=sd3_dtype)
clip_g.eval()
logger.info(f"Set attn_mode to {args.attn_mode}...")
clip_g.set_attn_mode(args.attn_mode)
if use_t5xxl:
logger.info("Create t5xxl")
t5xxl = sd3_models.create_t5xxl(device, sd3_dtype, t5xxl_sd)
logger.info("Loading state dict...")
info = t5xxl.load_state_dict(t5xxl_sd)
logger.info(f"Loaded t5xxl: {info}")
logger.info(f"Move t5xxl to {device} and {sd3_dtype}...")
t5xxl.to(device, dtype=sd3_dtype)
# t5xxl.to("cpu", dtype=torch.float32) # run on CPU
t5xxl.eval()
logger.info(f"Set attn_mode to {args.attn_mode}...")
t5xxl.set_attn_mode(args.attn_mode)
else:
t5xxl = None
def generate_image(
mmdit: sd3_models.MMDiT,
vae: sd3_models.SDVAE,
clip_l: CLIPTextModelWithProjection,
clip_g: CLIPTextModelWithProjection,
t5xxl: T5EncoderModel,
steps: int,
prompt: str,
seed: int,
target_width: int,
target_height: int,
device: str,
negative_prompt: str,
cfg_scale: float,
):
# prepare embeddings
logger.info("Encoding prompts...")
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
tokens_and_masks = tokenize_strategy.tokenize(args.prompt)
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
)
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
# TODO support one-by-one offloading
clip_l.to(device)
clip_g.to(device)
t5xxl.to(device)
tokens_and_masks = tokenize_strategy.tokenize(args.negative_prompt)
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
)
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
with torch.autocast(device_type=device.type, dtype=mmdit.dtype), torch.no_grad():
tokens_and_masks = tokenize_strategy.tokenize(prompt)
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
)
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
lg_out, t5_out, pooled, neg_l_attn_mask, neg_g_attn_mask, neg_t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
)
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
# attn masks are not used currently
if args.offload:
clip_l.to("cpu")
clip_g.to("cpu")
t5xxl.to("cpu")
# generate image
logger.info("Generating image...")
latent_sampled = do_sample(
target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, guidance_scale, sd3_dtype, device
)
mmdit.to(device)
latent_sampled = do_sample(target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, cfg_scale, sd3_dtype, device)
if args.offload:
mmdit.to("cpu")
# latent to image
vae.to(device)
with torch.no_grad():
image = vae.decode(latent_sampled)
if args.offload:
vae.to("cpu")
image = image.float()
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
@@ -359,3 +204,204 @@ if __name__ == "__main__":
out_image.save(output_path)
logger.info(f"Saved image to {output_path}")
if __name__ == "__main__":
target_height = 1024
target_width = 1024
# steps = 50 # 28 # 50
# cfg_scale = 5
# seed = 1 # None # 1
device = get_preferred_device()
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--clip_g", type=str, required=False)
parser.add_argument("--clip_l", type=str, required=False)
parser.add_argument("--t5xxl", type=str, required=False)
parser.add_argument("--t5xxl_token_length", type=int, default=256, help="t5xxl token length, default: 256")
parser.add_argument("--apply_lg_attn_mask", action="store_true")
parser.add_argument("--apply_t5_attn_mask", action="store_true")
parser.add_argument("--prompt", type=str, default="A photo of a cat")
# parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--cfg_scale", type=float, default=5.0)
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
parser.add_argument("--output_dir", type=str, default=".")
# parser.add_argument("--do_not_use_t5xxl", action="store_true")
# parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--steps", type=int, default=50)
parser.add_argument(
"--lora_weights",
type=str,
nargs="*",
default=[],
help="LoRA weights, only supports networks.lora_sd3, each argument is a `path;multiplier` (semi-colon separated)",
)
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
parser.add_argument("--width", type=int, default=target_width)
parser.add_argument("--height", type=int, default=target_height)
parser.add_argument("--interactive", action="store_true")
args = parser.parse_args()
seed = args.seed
steps = args.steps
sd3_dtype = torch.float32
if args.fp16:
sd3_dtype = torch.float16
elif args.bf16:
sd3_dtype = torch.bfloat16
loading_device = "cpu" if args.offload else device
# load state dict
logger.info(f"Loading SD3 models from {args.ckpt_path}...")
# state_dict = load_file(args.ckpt_path)
state_dict = load_safetensors(args.ckpt_path, loading_device, disable_mmap=True, dtype=sd3_dtype)
# load text encoders
clip_l = sd3_utils.load_clip_l(args.clip_l, sd3_dtype, loading_device, state_dict=state_dict)
clip_g = sd3_utils.load_clip_g(args.clip_g, sd3_dtype, loading_device, state_dict=state_dict)
t5xxl = sd3_utils.load_t5xxl(args.t5xxl, sd3_dtype, loading_device, state_dict=state_dict)
# MMDiT and VAE
vae = sd3_utils.load_vae(None, sd3_dtype, loading_device, state_dict=state_dict)
mmdit = sd3_utils.load_mmdit(state_dict, sd3_dtype, loading_device)
clip_l.to(sd3_dtype)
clip_g.to(sd3_dtype)
t5xxl.to(sd3_dtype)
vae.to(sd3_dtype)
mmdit.to(sd3_dtype)
if not args.offload:
# make sure to move to the device: some tensors are created in the constructor on the CPU
clip_l.to(device)
clip_g.to(device)
t5xxl.to(device)
vae.to(device)
mmdit.to(device)
clip_l.eval()
clip_g.eval()
t5xxl.eval()
mmdit.eval()
vae.eval()
# load tokenizers
logger.info("Loading tokenizers...")
tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
# LoRA
lora_models: list[lora_sd3.LoRANetwork] = []
for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")
multiplier = float(multiplier)
else:
multiplier = 1.0
weights_sd = load_file(weights_file)
module = lora_sd3
lora_model, _ = module.create_network_from_weights(multiplier, None, vae, [clip_l, clip_g, t5xxl], mmdit, weights_sd, True)
if args.merge_lora_weights:
lora_model.merge_to([clip_l, clip_g, t5xxl], mmdit, weights_sd)
else:
lora_model.apply_to([clip_l, clip_g, t5xxl], mmdit)
info = lora_model.load_state_dict(weights_sd, strict=True)
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
lora_model.eval()
lora_model.to(device)
lora_models.append(lora_model)
if not args.interactive:
generate_image(
mmdit,
vae,
clip_l,
clip_g,
t5xxl,
args.steps,
args.prompt,
args.seed,
args.width,
args.height,
device,
args.negative_prompt,
args.cfg_scale,
)
else:
# loop for interactive
width = args.width
height = args.height
steps = None
cfg_scale = args.cfg_scale
while True:
print(
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed>"
" --n <negative prompt>, `--n -` for empty negative prompt"
"Options are kept for the next prompt. Current options:"
f" width={width}, height={height}, steps={steps}, seed={seed}, cfg_scale={cfg_scale}"
)
prompt = input()
if prompt == "":
break
# parse options
options = prompt.split("--")
prompt = options[0].strip()
seed = None
negative_prompt = None
for opt in options[1:]:
try:
opt = opt.strip()
if opt.startswith("w"):
width = int(opt[1:].strip())
elif opt.startswith("h"):
height = int(opt[1:].strip())
elif opt.startswith("s"):
steps = int(opt[1:].strip())
elif opt.startswith("d"):
seed = int(opt[1:].strip())
elif opt.startswith("m"):
mutipliers = opt[1:].strip().split(",")
if len(mutipliers) != len(lora_models):
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
continue
for i, lora_model in enumerate(lora_models):
lora_model.set_multiplier(float(mutipliers[i]))
elif opt.startswith("n"):
negative_prompt = opt[1:].strip()
if negative_prompt == "-":
negative_prompt = ""
elif opt.startswith("c"):
cfg_scale = float(opt[1:].strip())
except ValueError as e:
logger.error(f"Invalid option: {opt}, {e}")
generate_image(
mmdit,
vae,
clip_l,
clip_g,
t5xxl,
steps if steps is not None else args.steps,
prompt,
seed if seed is not None else args.seed,
width,
height,
device,
negative_prompt if negative_prompt is not None else args.negative_prompt,
cfg_scale,
)
logger.info("Done!")

File diff suppressed because it is too large Load Diff

451
sd3_train_network.py Normal file
View File

@@ -0,0 +1,451 @@
import argparse
import copy
import math
import random
from typing import Any, Optional
import torch
from accelerate import Accelerator
from library import sd3_models, strategy_sd3, utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3, train_util
import train_network
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class Sd3NetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup):
# super().assert_extra_args(args, train_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)
if args.fp8_base_unet:
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for SD3
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
)
args.cache_text_encoder_outputs = True
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# prepare CLIP-L/CLIP-G/T5XXL training flags
self.train_clip = not args.network_train_unet_only
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
if args.max_token_length is not None:
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
# enumerate resolutions from dataset for positional embeddings
self.resolutions = train_dataset_group.get_resolutions()
def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
loading_dtype = None if args.fp8_base else weight_dtype
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
state_dict = utils.load_safetensors(
args.pretrained_model_name_or_path, "cpu", disable_mmap=args.disable_mmap_load_safetensors, dtype=loading_dtype
)
mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu")
self.model_type = mmdit.model_type
mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)
# set resolutions for positional embeddings
if args.enable_scaled_pos_embed:
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent
latent_sizes = list(set(latent_sizes)) # remove duplicates
logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}")
mmdit.enable_scaled_pos_embed(True, latent_sizes)
if args.fp8_base:
# check dtype of model
if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {mmdit.dtype}")
elif mmdit.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 SD3 model")
clip_l = sd3_utils.load_clip_l(
args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
)
clip_l.eval()
clip_g = sd3_utils.load_clip_g(
args.clip_g, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
)
clip_g.eval()
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
if args.fp8_base and not args.fp8_base_unet:
loading_dtype = None # as is
else:
loading_dtype = weight_dtype
# loading t5xxl to cpu takes a long time, so we should load to gpu in future
t5xxl = sd3_utils.load_t5xxl(
args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
)
t5xxl.eval()
if args.fp8_base and not args.fp8_base_unet:
# check dtype of model
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
elif t5xxl.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 T5XXL model")
vae = sd3_utils.load_vae(
args.vae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
)
return mmdit.model_type, [clip_l, clip_g, t5xxl], vae, mmdit
def get_tokenize_strategy(self, args):
logger.info(f"t5xxl_max_token_length: {args.t5xxl_max_token_length}")
return strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length, args.tokenizer_cache_dir)
def get_tokenizers(self, tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy):
return [tokenize_strategy.clip_l, tokenize_strategy.clip_g, tokenize_strategy.t5xxl]
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return latents_caching_strategy
def get_text_encoding_strategy(self, args):
return strategy_sd3.Sd3TextEncodingStrategy(
args.apply_lg_attn_mask,
args.apply_t5_attn_mask,
args.clip_l_dropout_rate,
args.clip_g_dropout_rate,
args.t5_dropout_rate,
)
def post_process_network(self, args, accelerator, network, text_encoders, unet):
# check t5xxl is trained or not
self.train_t5xxl = network.train_t5xxl
if self.train_t5xxl and args.cache_text_encoder_outputs:
raise ValueError(
"T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
)
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
if args.cache_text_encoder_outputs:
if self.train_clip and not self.train_t5xxl:
return text_encoders[0:2] + [None] # only CLIP-L/CLIP-G is needed for encoding because T5XXL is cached
else:
return None # no text encoders are needed for encoding because both are cached
else:
return text_encoders # CLIP-L, CLIP-G and T5XXL are needed for encoding
def get_text_encoders_train_flags(self, args, text_encoders):
return [self.train_clip, self.train_clip, self.train_t5xxl]
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
# if the text encoders is trained, we need tokenization, so is_partial is True
return strategy_sd3.Sd3TextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=self.train_clip or self.train_t5xxl,
apply_lg_attn_mask=args.apply_lg_attn_mask,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
else:
return None
def cache_text_encoder_outputs_if_needed(
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
):
if args.cache_text_encoder_outputs:
if not args.lowram:
# メモリ消費を減らす
logger.info("move vae and unet to cpu to save memory")
org_vae_device = vae.device
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
clean_memory_on_device(accelerator.device)
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
logger.info("move text encoders to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
text_encoders[1].to(accelerator.device, dtype=weight_dtype) # always not fp8
text_encoders[2].to(accelerator.device) # may be fp8
if text_encoders[2].dtype == torch.float8_e4m3fn:
# if we load fp8 weights, the model is already fp8, so we use it as is
self.prepare_text_encoder_fp8(2, text_encoders[2], text_encoders[2].dtype, weight_dtype)
else:
# otherwise, we need to convert it to target dtype
text_encoders[2].to(weight_dtype)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
# cache sample prompts
if args.sample_prompts is not None:
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
if p not in sample_prompts_te_outputs:
logger.info(f"cache Text Encoder outputs for prompt: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy,
text_encoders,
tokens_and_masks,
args.apply_lg_attn_mask,
args.apply_t5_attn_mask,
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
accelerator.wait_for_everyone()
# move back to cpu
if not self.is_train_text_encoder(args):
logger.info("move CLIP-L back to cpu")
text_encoders[0].to("cpu")
logger.info("move CLIP-G back to cpu")
text_encoders[1].to("cpu")
logger.info("move t5XXL back to cpu")
text_encoders[2].to("cpu")
clean_memory_on_device(accelerator.device)
if not args.lowram:
logger.info("move vae and unet back to original device")
vae.to(org_vae_device)
unet.to(org_unet_device)
else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
text_encoders[2].to(accelerator.device)
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
# # get size embeddings
# orig_size = batch["original_sizes_hw"]
# crop_size = batch["crop_top_lefts"]
# target_size = batch["target_sizes_hw"]
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
# # concat embeddings
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
# return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, mmdit):
text_encoders = text_encoder # for compatibility
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
sd3_train_utils.sample_images(
accelerator, args, epoch, global_step, mmdit, vae, text_encoders, self.sample_prompts_te_outputs
)
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
# shift 3.0 is the default value
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
return noise_scheduler
def encode_images_to_latents(self, args, accelerator, vae, images):
return vae.encode(images)
def shift_scale_latents(self, args, latents):
return sd3_models.SDVAE.process_in(latents)
def get_noise_pred_and_target(
self,
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet: flux_models.Flux,
network,
weight_dtype,
train_unet,
):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps(
args, self.noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
)
# ensure the hidden state will require grad
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)
for t in text_encoder_conds:
if t is not None and t.dtype.is_floating_point:
t.requires_grad_(True)
# Predict the noise residual
lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled)
if not args.apply_lg_attn_mask:
l_attn_mask = None
g_attn_mask = None
if not args.apply_t5_attn_mask:
t5_attn_mask = None
# call model
with accelerator.autocast():
# TODO support attention mask
model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
# Preconditioning of the model outputs.
model_pred = model_pred * (-sigmas) + noisy_model_input
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
# flow matching loss
target = latents
# differential output preservation
if "custom_attributes" in batch:
diff_output_pr_indices = []
for i, custom_attributes in enumerate(batch["custom_attributes"]):
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
diff_output_pr_indices.append(i)
if len(diff_output_pr_indices) > 0:
network.set_multiplier(0.0)
with torch.no_grad(), accelerator.autocast():
model_pred_prior = unet(
noisy_model_input[diff_output_pr_indices],
timesteps[diff_output_pr_indices],
context=context[diff_output_pr_indices],
y=lg_pooled[diff_output_pr_indices],
)
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
model_pred_prior = model_pred_prior * (-sigmas[diff_output_pr_indices]) + noisy_model_input[diff_output_pr_indices]
# weighting for differential output preservation is not needed because it is already applied
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, None, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
def get_sai_model_spec(self, args):
return train_util.get_sai_model_spec(None, args, False, True, False, sd3=self.model_type)
def update_metadata(self, metadata, args):
metadata["ss_apply_lg_attn_mask"] = args.apply_lg_attn_mask
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
metadata["ss_mode_scale"] = args.mode_scale
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
if index == 0 or index == 1: # CLIP-L/CLIP-G
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
else: # T5XXL
text_encoder.encoder.embed_tokens.requires_grad_(True)
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
if index == 0 or index == 1: # CLIP-L/CLIP-G
clip_type = "CLIP-L" if index == 0 else "CLIP-G"
logger.info(f"prepare CLIP-{clip_type} for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
text_encoder.to(te_weight_dtype) # fp8
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
else: # T5XXL
def prepare_fp8(text_encoder, target_dtype):
def forward_hook(module):
def forward(hidden_states):
hidden_gelu = module.act(module.wi_0(hidden_states))
hidden_linear = module.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = module.dropout(hidden_states)
hidden_states = module.wo(hidden_states)
return hidden_states
return forward
for module in text_encoder.modules():
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
# print("set", module.__class__.__name__, "to", target_dtype)
module.to(target_dtype)
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
# print("set", module.__class__.__name__, "hooks")
module.forward = forward_hook(module)
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
logger.info(f"T5XXL already prepared for fp8")
else:
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
# drop cached text encoder outputs
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
sd3_train_utils.add_sd3_training_arguments(parser)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = Sd3NetworkTrainer()
trainer.train(args)

View File

@@ -129,6 +129,7 @@ class NetworkTrainer:
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
"""
Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models.
FLUX.1 and SD3 may cache some outputs of the text encoder, so return the models that will be used for encoding (not cached).
"""
return text_encoders
@@ -271,6 +272,9 @@ class NetworkTrainer:
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
pass
# endregion
def train(self, args):
@@ -591,6 +595,7 @@ class NetworkTrainer:
# unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
# unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}")
unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
unet.requires_grad_(False)
@@ -1028,9 +1033,9 @@ class NetworkTrainer:
# callback for step start
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
on_step_start = accelerator.unwrap_model(network).on_step_start
on_step_start_for_network = accelerator.unwrap_model(network).on_step_start
else:
on_step_start = lambda *args, **kwargs: None
on_step_start_for_network = lambda *args, **kwargs: None
# function for saving/removing
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
@@ -1111,7 +1116,10 @@ class NetworkTrainer:
continue
with accelerator.accumulate(training_model):
on_step_start(text_encoder, unet)
on_step_start_for_network(text_encoder, unet)
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
@@ -1143,7 +1151,9 @@ class NetworkTrainer:
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions: