diff --git a/README.md b/README.md index ad2791e7..aff78b2c 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ This repository contains training, generation and utility scripts for Stable Diffusion. -## FLUX.1 training (WIP) +## FLUX.1 and SD3 training (WIP) This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. @@ -9,8 +9,15 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +- [FLUX.1 training](#flux1-training) +- [SD3 training](#sd3-training) + ### Recent Updates +Oct 31, 2024: + +- Added support for SD3.5L/M training. See [SD3 training](#sd3-training) for details. + Oct 19, 2024: - Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training. SD1/2 is not tested yet. This is an experimental feature. @@ -139,7 +146,7 @@ Sep 1, 2024: Aug 29, 2024: Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `requirements.txt` is updated. -### Contents +## FLUX.1 training - [FLUX.1 LoRA training](#flux1-lora-training) - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) @@ -586,53 +593,177 @@ python tools/convert_diffusers_to_flux.py --diffusers_path path/to/diffusers_fol ## SD3 training -SD3 training is done with `sd3_train.py`. +SD3.5L/M training is now available. -__Sep 1, 2024__: -- `--num_last_block_to_freeze` is added to `sd3_train.py`. This option is to freeze the last n blocks of the MMDiT. See [#1417](https://github.com/kohya-ss/sd-scripts/pull/1417) for details. Thanks to sdbds! +### SD3 LoRA training -__Jul 27, 2024__: -- Latents and text encoder outputs caching mechanism is refactored significantly. - - Existing cache files for SD3 need to be recreated. Please delete the previous cache files. - - With this change, dataset initialization is significantly faster, especially for large datasets. +The script is `sd3_train_network.py`. See `--help` for options. -- Architecture-dependent parts are extracted from the dataset (`train_util.py`). This is expected to make it easier to add future architectures. +SD3 model, CLIP-L, CLIP-G, and T5XXL models are recommended to be in float/fp16 format. If you specify `--fp8_base`, you can use fp8 models for SD3. The fp8 model is only compatible with `float8_e4m3fn` format. -- Architecture-dependent parts including the cache mechanism for SD1/2/SDXL are also extracted. The basic operation of SD1/2/SDXL training on the sd3 branch has been confirmed, but there may be bugs. Please use the main or dev branch for SD1/2/SDXL training. +Sample command is below. It will work with 16GB VRAM GPUs (SD3.5L). 
---- +``` +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 sd3_train_network.py +--pretrained_model_name_or_path path/to/sd3.5_large.safetensors --clip_l sd3/clip_l.safetensors --clip_g sd3/clip_g.safetensors --t5xxl sd3/t5xxl_fp16.safetensors +--cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers +--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 +--network_module networks.lora_sd3 --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 +--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base +--highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml +--output_dir path/to/output/dir --output_name sd3-lora-name +``` +(The command is multi-line for readability. Please combine it into one line.) -`fp16` and `bf16` are available for mixed precision training. We are not sure which is better. +The training can be done with 12GB VRAM GPUs with Adafactor optimizer. Please use settings like below: -`optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are necessary currently. - -`clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them. - -t5xxl works with `fp16` now. - -There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype. - -`text_encoder_batch_size` is added experimentally for caching faster. - -```toml -learning_rate = 1e-6 # seems to depend on the batch size -optimizer_type = "adafactor" -optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] -cache_text_encoder_outputs = true -cache_text_encoder_outputs_to_disk = true -vae_batch_size = 1 -text_encoder_batch_size = 4 -cache_latents = true -cache_latents_to_disk = true +``` +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 ``` -__2024/7/27:__ +`--cpu_offload_checkpointing` and `--split_mode` are not available for SD3 LoRA training. -Latents およびテキストエンコーダ出力のキャッシュの仕組みを大きくリファクタリングしました。SD3 用の既存のキャッシュファイルの再作成が必要になりますが、ご了承ください(以前のキャッシュファイルは削除してください)。これにより、特にデータセットの規模が大きい場合のデータセット初期化が大幅に高速化されます。 +We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. -データセット (`train_util.py`) からアーキテクチャ依存の部分を切り出しました。これにより将来的なアーキテクチャ追加が容易になると期待しています。 +The trained LoRA model can be used with ComfyUI. -SD1/2/SDXL のキャッシュ機構を含むアーキテクチャ依存の部分も切り出しました。sd3 ブランチの SD1/2/SDXL 学習について、基本的な動作は確認していますが、不具合があるかもしれません。SD1/2/SDXL の学習には main または dev ブランチをお使いください。 +#### Key Options for SD3 LoRA training + +Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome. + +- `--network_module` is the module for LoRA training. Specify `networks.lora_sd3` for SD3 LoRA training. +- `--pretrained_model_name_or_path` is the path to the pretrained model (SD3/3.5). If you specify `--fp8_base`, you can use fp8 models for SD3/3.5. The fp8 model is only compatible with `float8_e4m3fn` format. +- `--clip_l` is the path to the CLIP-L model. +- `--clip_g` is the path to the CLIP-G model. +- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (float8_e4m3fn) models for T5XXL. However, it is recommended to use fp16 models for caching. 
+- `--vae` is the path to the autoencoder model. __This option is not necessary for SD3.__ The VAE is included in the standard SD3 checkpoint.
+- `--disable_mmap_load_safetensors` disables memory mapping when loading safetensors. __This option significantly reduces the memory usage when loading models for Windows users.__
+- `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are the dropout rates for the embeddings of CLIP-L, CLIP-G, and T5XXL, described in the [SAI research paper](http://arxiv.org/pdf/2403.03206). The default is 0.0. For LoRA training, it seems better to leave them at 0.0.
+- `--pos_emb_random_crop_rate` is the rate of random cropping of positional embeddings, described in the [SD3.5M model card](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium). The default is 0.0. It seems better to keep 0.0 for LoRA training.
+- `--enable_scaled_pos_embed` enables scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M. Details are described below.
+
+Other options are described below.
+
+#### Key Features for SD3 LoRA training
+
+1. CLIP-L, G and T5XXL LoRA Support:
+   - SD3 LoRA training now supports CLIP-L, CLIP-G and T5XXL LoRA training.
+   - Remove `--network_train_unet_only` from your command.
+   - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L and G are also trained at the same time.
+   - T5XXL output can be cached for CLIP-L and G LoRA training, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
+   - The learning rates for CLIP-L, CLIP-G and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5 5e-6`. The first value is the learning rate for CLIP-L, the second for CLIP-G, and the third for T5XXL. If you specify only one value, it is used for CLIP-L, CLIP-G and T5XXL. If the third value is not specified, the second value is used for T5XXL. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for all text encoders.
+   - The trained LoRA can be used with ComfyUI.
+
+   | trained LoRA|option|network_args|cache_text_encoder_outputs (*1)|
+   |---|---|---|---|
+   |MMDiT|`--network_train_unet_only`|-|o|
+   |MMDiT + CLIP-L + CLIP-G|-|-|o (*2)|
+   |MMDiT + CLIP-L + CLIP-G + T5XXL|-|`train_t5xxl=True`|-|
+   |CLIP-L + CLIP-G (*3)|`--network_train_text_encoder_only`|-|o (*2)|
+   |CLIP-L + CLIP-G + T5XXL (*3)|`--network_train_text_encoder_only`|`train_t5xxl=True`|-|
+
+   - *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
+   - *2: T5XXL output can be cached for CLIP-L and G LoRA training.
+   - *3: Not tested yet.
+
+2. Experimental FP8/FP16 mixed training:
+   - `--fp8_base_unet` enables training with fp8 for MMDiT and bf16/fp16 for CLIP-L/G/T5XXL.
+   - When specifying this option, the `--fp8_base` option is automatically enabled.
+
+3. Split Q/K/V Projection Layers (Experimental):
+   - Same as FLUX.1.
+
+4. CLIP-L/G and T5 Attention Mask Application:
+   - This function is planned to be implemented in the future.
+
+5. Multi-resolution Training Support:
+   - Only for SD3.5M.
+   - Data preparation is the same as for FLUX.1.
+   - If you train with multiple resolutions, specify `--enable_scaled_pos_embed` to enable scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M; see the sketch after this list and the technical details below.
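+The snippet below is a minimal, self-contained sketch of the multi-resolution idea from item 5: one scaled sinusoidal positional-embedding table is precomputed per training resolution, and at run time the table whose area best matches the current latent is selected and center-cropped. The helper names (`build`/`pick_pos_embed`) and the nearest-area selection are illustrative simplifications, not the training script's API; the actual implementation in this PR is `get_scaled_2d_sincos_pos_embed` and `MMDiT.cropped_scaled_pos_embed` in `library/sd3_models.py`.
+
+```python
+# Illustrative sketch only (not the training script's API): precompute one
+# scaled 2D sin/cos positional-embedding table per training resolution and
+# pick/crop one at run time based on the latent size.
+import numpy as np
+
+
+def sincos_1d(embed_dim: int, pos: np.ndarray) -> np.ndarray:
+    # standard 1D sinusoidal embedding: half sin, half cos
+    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim // 2))
+    out = np.einsum("p,d->pd", pos.reshape(-1), omega)
+    return np.concatenate([np.sin(out), np.cos(out)], axis=1)
+
+
+def scaled_pos_embed(embed_dim: int, grid: int, sample_size: int, base_size: int = 16) -> np.ndarray:
+    # Scale/shift the grid so the central `sample_size` region matches the
+    # embeddings used at the reference (training) resolution.
+    coords = np.arange(grid, dtype=np.float64) / grid
+    scale = base_size * grid / sample_size
+    shift = scale * (grid - sample_size) / (2 * grid)
+    coords = coords * scale - shift
+    grid_w, grid_h = np.meshgrid(coords, coords)  # each (grid, grid)
+    emb_h = sincos_1d(embed_dim // 2, grid_h)
+    emb_w = sincos_1d(embed_dim // 2, grid_w)
+    return np.concatenate([emb_h, emb_w], axis=1).reshape(grid, grid, embed_dim)
+
+
+# Patched latent sizes for e.g. 512px / 768px / 1024px training (latent size // patch size 2).
+PATCHED_SIZES = [32, 48, 64]
+# Tables get 1.5x headroom (POS_EMBED_MAX_RATIO in this PR) so tall/wide crops still fit.
+TABLES = {s: scaled_pos_embed(1536, int(s * 1.5), sample_size=s) for s in PATCHED_SIZES}
+
+
+def pick_pos_embed(h: int, w: int) -> np.ndarray:
+    # choose the table whose training area is closest to this latent's patched area
+    size = min(PATCHED_SIZES, key=lambda s: abs(s * s - h * w))
+    table = TABLES[size]
+    grid = table.shape[0]
+    top, left = (grid - h) // 2, (grid - w) // 2
+    return table[top : top + h, left : left + w].reshape(h * w, -1)
+
+
+print(pick_pos_embed(64, 48).shape)  # (3072, 1536) for a roughly 1024x768 image
+```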
+
+
+Technical details of multi-resolution training for SD3.5M:
+
+The values of the positional embeddings must be consistent across resolutions: the same value must appear at the same spatial position for 512x512, 768x768, and 1024x1024. To achieve this, the positional embeddings for each resolution are calculated in advance and switched according to the resolution of the training data. This feature is enabled by `--enable_scaled_pos_embed`.
+
+This idea and the code for calculating scaled positional embeddings are contributed by KohakuBlueleaf. Thanks to KohakuBlueleaf!
+
+
+#### Specify rank for each layer in SD3 LoRA
+
+You can specify the rank for each layer in SD3 with the following network_args. If you specify `0`, LoRA will not be applied to that layer.
+
+When network_args is not specified, the default value (`network_dim`) is applied, same as before.
+
+|network_args|target layer|
+|---|---|
+|context_attn_dim|attn in context_block|
+|context_mlp_dim|mlp in context_block|
+|context_mod_dim|adaLN_modulation in context_block|
+|x_attn_dim|attn in x_block|
+|x_mlp_dim|mlp in x_block|
+|x_mod_dim|adaLN_modulation in x_block|
+
+`"verbose=True"` is also available for debugging. It shows the rank of each layer.
+
+example:
+```
+--network_args "context_attn_dim=2" "context_mlp_dim=3" "context_mod_dim=4" "x_attn_dim=5" "x_mlp_dim=6" "x_mod_dim=7" "verbose=True"
+```
+
+You can apply LoRA to the conditioning layers of SD3 by specifying `emb_dims` in network_args. When specifying, be sure to specify 6 numbers in `[]` as a comma-separated list.
+
+example:
+```
+--network_args "emb_dims=[2,3,4,5,6,7]"
+```
+
+Each number corresponds to `context_embedder`, `t_embedder`, `x_embedder`, `y_embedder`, `final_layer_adaLN_modulation`, `final_layer_linear`. The above example applies LoRA to all conditioning layers, with rank 2 for `context_embedder`, 3 for `t_embedder`, 4 for `x_embedder`, 5 for `y_embedder`, 6 for `final_layer_adaLN_modulation`, and 7 for `final_layer_linear`.
+
+If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,4,0,0]` applies LoRA only to `context_embedder` and `y_embedder`.
+
+#### Specify blocks to train in SD3 LoRA training
+
+You can specify the blocks to train in SD3 LoRA training by specifying `train_block_indices` in network_args. The indices are 0-based. The default (when omitted) is to train all blocks. The indices are specified as a list of integers or ranges of integers, like `0,1,5,8` or `0,1,4-5,7`.
+
+The number of blocks depends on the model. The valid range is 0 to (the number of blocks - 1). `all` is also available to train all blocks, and `none` to train no blocks.
+
+example:
+```
+--network_args "train_block_indices=1,2,6-8"
+```
+
+### Inference for SD3 with LoRA model
+
+The inference script is also available. The script is `sd3_minimal_inference.py`. See `--help` for options.
+
+### SD3 fine-tuning
+
+Documentation is not available yet. Please refer to the FLUX.1 fine-tuning guide for now. The major differences are as follows:
+
+- `--clip_g` is also available for SD3 fine-tuning.
+- `--timestep_sampling`, `--discrete_flow_shift`, `--model_prediction_type` and `--guidance_scale` are not necessary for SD3 fine-tuning.
+- Use `--vae` instead of `--ae` if necessary. __This option is not necessary for SD3.__ The VAE is included in the standard SD3 checkpoint.
+- `--disable_mmap_load_safetensors` is available.
__This option significantly reduces the memory usage when loading models for Windows users.__ +- `--cpu_offload_checkpointing` is not available for SD3 fine-tuning. +- `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are available same as LoRA training. +- `--pos_emb_random_crop_rate` and `--enable_scaled_pos_embed` are available for SD3.5M fine-tuning. +- Training text encoders is available with `--train_text_encoder` option, similar to SDXL training. + - CLIP-L and G can be trained with `--train_text_encoder` option. Training T5XXL needs `--train_t5xxl` option. + - If you use the cached text encoder outputs for T5XXL with training CLIP-L and G, specify `--use_t5xxl_cache_only`. This option enables to use the cached text encoder outputs for T5XXL only. + - The learning rates for CLIP-L, CLIP-G and T5XXL can be specified separately. `--text_encoder_lr1`, `--text_encoder_lr2` and `--text_encoder_lr3` are available. + +### Extract LoRA from SD3 Models + +Not available yet. + +### Convert SD3 LoRA + +Not available yet. + +### Merge LoRA to SD3 checkpoint + +Not available yet. --- diff --git a/flux_train.py b/flux_train.py index 91ae3af5..79c44d7b 100644 --- a/flux_train.py +++ b/flux_train.py @@ -29,7 +29,7 @@ init_ipex() from accelerate.utils import set_seed from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux -from library.sd3_train_utils import load_prompts, FlowMatchEulerDiscreteScheduler +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler import library.train_util as train_util @@ -241,7 +241,7 @@ def train(args): text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - prompts = load_prompts(args.sample_prompts) + prompts = train_util.load_prompts(args.sample_prompts) sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in prompts: diff --git a/flux_train_network.py b/flux_train_network.py index 9cc8811b..2b71a897 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -231,7 +231,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - prompts = sd3_train_utils.load_prompts(args.sample_prompts) + prompts = train_util.load_prompts(args.sample_prompts) sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in prompts: @@ -363,7 +363,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): if args.gradient_checkpointing: noisy_model_input.requires_grad_(True) for t in text_encoder_conds: - if t.dtype.is_floating_point: + if t is not None and t.dtype.is_floating_point: t.requires_grad_(True) img_ids.requires_grad_(True) guidance_vec.requires_grad_(True) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index b3c9184f..fa673a2f 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -15,7 +15,6 @@ from PIL import Image from safetensors.torch import save_file from library import flux_models, flux_utils, strategy_base, train_util -from library.sd3_train_utils import load_prompts from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -70,7 +69,7 @@ def sample_images( text_encoders = 
[accelerator.unwrap_model(te) for te in text_encoders] # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) - prompts = load_prompts(args.sample_prompts) + prompts = train_util.load_prompts(args.sample_prompts) save_dir = args.output_dir + "/sample" os.makedirs(save_dir, exist_ok=True) diff --git a/library/flux_utils.py b/library/flux_utils.py index 4403835f..f3093615 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -10,40 +10,21 @@ from safetensors import safe_open from accelerate import init_empty_weights from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config -from library import flux_models - -from library.utils import setup_logging, MemoryEfficientSafeOpen +from library.utils import setup_logging setup_logging() import logging logger = logging.getLogger(__name__) +from library import flux_models +from library.utils import load_safetensors + MODEL_VERSION_FLUX_V1 = "flux1" MODEL_NAME_DEV = "dev" MODEL_NAME_SCHNELL = "schnell" -# temporary copy from sd3_utils TODO refactor -def load_safetensors( - path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32 -): - if disable_mmap: - # return safetensors.torch.load(open(path, "rb").read()) - # use experimental loader - logger.info(f"Loading without mmap (experimental)") - state_dict = {} - with MemoryEfficientSafeOpen(path) as f: - for key in f.keys(): - state_dict[key] = f.get_tensor(key).to(device, dtype=dtype) - return state_dict - else: - try: - return load_file(path, device=device) - except: - return load_file(path) # prevent device invalid Error - - def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: """ チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。 @@ -172,8 +153,14 @@ def load_ae( return ae -def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> CLIPTextModel: - logger.info("Building CLIP") +def load_clip_l( + ckpt_path: Optional[str], + dtype: torch.dtype, + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> CLIPTextModel: + logger.info("Building CLIP-L") CLIPL_CONFIG = { "_name_or_path": "clip-vit-large-patch14/", "architectures": ["CLIPModel"], @@ -266,15 +253,22 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev with init_empty_weights(): clip = CLIPTextModel._from_config(config) - logger.info(f"Loading state dict from {ckpt_path}") - sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = clip.load_state_dict(sd, strict=False, assign=True) - logger.info(f"Loaded CLIP: {info}") + logger.info(f"Loaded CLIP-L: {info}") return clip def load_t5xxl( - ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False + ckpt_path: str, + dtype: Optional[torch.dtype], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, ) -> T5EncoderModel: T5_CONFIG_JSON = """ { @@ -314,8 +308,11 @@ def load_t5xxl( with init_empty_weights(): t5xxl = T5EncoderModel._from_config(config) - logger.info(f"Loading state dict from {ckpt_path}") - sd = load_safetensors(ckpt_path, 
device=str(device), disable_mmap=disable_mmap, dtype=dtype) + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = t5xxl.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded T5xxl: {info}") return t5xxl diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index ad72ec00..8896c047 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -57,8 +57,8 @@ ARCH_SD_V1 = "stable-diffusion-v1" ARCH_SD_V2_512 = "stable-diffusion-v2-512" ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v" ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" -ARCH_SD3_M = "stable-diffusion-3-medium" -ARCH_SD3_UNKNOWN = "stable-diffusion-3" +ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc. +# ARCH_SD3_UNKNOWN = "stable-diffusion-3" ARCH_FLUX_1_DEV = "flux-1-dev" ARCH_FLUX_1_UNKNOWN = "flux-1" @@ -140,10 +140,7 @@ def build_metadata( if sdxl: arch = ARCH_SD_XL_V1_BASE elif sd3 is not None: - if sd3 == "m": - arch = ARCH_SD3_M - else: - arch = ARCH_SD3_UNKNOWN + arch = ARCH_SD3_M + "-" + sd3 elif flux is not None: if flux == "dev": arch = ARCH_FLUX_1_DEV diff --git a/library/sd3_models.py b/library/sd3_models.py index ec704dcb..b09a57db 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -4,6 +4,8 @@ # and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution! from ast import Tuple +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass from functools import partial import math from types import SimpleNamespace @@ -15,6 +17,9 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from transformers import CLIPTokenizer, T5TokenizerFast + +from library.device_utils import clean_memory_on_device + from .utils import setup_logging setup_logging() @@ -35,141 +40,23 @@ except: memory_efficient_attention = None -# region tokenizer -class SDTokenizer: - def __init__( - self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None - ): - """ - サブクラスで各種の設定を行ってる。このクラスはその設定に基づき重み付きのトークン化を行うようだ。 - Some settings are done in subclasses. This class seems to perform tokenization with weights based on those settings. - """ - self.tokenizer: CLIPTokenizer = tokenizer - self.max_length = max_length - self.min_length = min_length - empty = self.tokenizer("")["input_ids"] - if has_start_token: - self.tokens_start = 1 - self.start_token = empty[0] - self.end_token = empty[1] - else: - self.tokens_start = 0 - self.start_token = None - self.end_token = empty[0] - self.pad_with_end = pad_with_end - self.pad_to_max_length = pad_to_max_length - vocab = self.tokenizer.get_vocab() - self.inv_vocab = {v: k for k, v in vocab.items()} - self.max_word_length = 8 - - def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: - """ - Tokenize the text without weights. 
- """ - if type(text) == str: - text = [text] - batch_tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt") - # return tokens["input_ids"] - - pad_token = self.end_token if self.pad_with_end else 0 - for tokens in batch_tokens["input_ids"]: - assert tokens[0] == self.start_token, f"tokens[0]: {tokens[0]}, start_token: {self.start_token}" - - def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None): - """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. - The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" - """ - ja: テキストをトークン化し、重み値を持ちます - すべての値に1.0を仮定し、他の機能を無視します。 - 詳細は参考実装には関係なく、重み自体はSD3に対して弱い影響しかありません。へぇ~ - """ - if self.pad_with_end: - pad_token = self.end_token - else: - pad_token = 0 - batch = [] - if self.start_token is not None: - batch.append((self.start_token, 1.0)) - to_tokenize = text.replace("\n", " ").split(" ") - to_tokenize = [x for x in to_tokenize if x != ""] - for word in to_tokenize: - batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]]) - batch.append((self.end_token, 1.0)) - print(len(batch), self.max_length, self.min_length) - if self.pad_to_max_length: - batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) - if self.min_length is not None and len(batch) < self.min_length: - batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) - - # truncate to max_length - print( - f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}" - ) - if truncate_to_max_length and len(batch) > self.max_length: - batch = batch[: self.max_length] - if truncate_length is not None and len(batch) > truncate_length: - batch = batch[:truncate_length] - - return [batch] - - -class T5XXLTokenizer(SDTokenizer): - """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" - - def __init__(self): - super().__init__( - pad_with_end=False, - tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), - has_start_token=False, - pad_to_max_length=False, - max_length=99999999, - min_length=77, - ) - - -class SDXLClipGTokenizer(SDTokenizer): - def __init__(self, tokenizer): - super().__init__(pad_with_end=False, tokenizer=tokenizer) - - -class SD3Tokenizer: - def __init__(self, t5xxl=True, t5xxl_max_length: Optional[int] = 256): - if t5xxl_max_length is None: - t5xxl_max_length = 256 - - # TODO cache tokenizer settings locally or hold them in the repo like ComfyUI - clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) - self.clip_g = SDXLClipGTokenizer(clip_tokenizer) - # self.clip_l = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - # self.clip_g = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") - self.t5xxl = T5XXLTokenizer() if t5xxl else None - # t5xxl has 99999999 max length, clip has 77 - self.t5xxl_max_length = t5xxl_max_length - - def tokenize_with_weights(self, text: str): - return ( - self.clip_l.tokenize_with_weights(text), - self.clip_g.tokenize_with_weights(text), - ( - self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.t5xxl_max_length) - if self.t5xxl is not None - else None - ), - ) - - def tokenize(self, text: str): - return ( - self.clip_l.tokenize(text), - self.clip_g.tokenize(text), - 
(self.t5xxl.tokenize(text) if self.t5xxl is not None else None), - ) - - -# endregion - # region mmdit +@dataclass +class SD3Params: + patch_size: int + depth: int + num_patches: int + pos_embed_max_size: int + adm_in_channels: int + qk_norm: Optional[str] + x_block_self_attn_layers: list[int] + context_embedder_in_features: int + context_embedder_out_features: int + model_type: str + + def get_2d_sincos_pos_embed( embed_dim, grid_size, @@ -201,6 +88,78 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): return emb +def get_scaled_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, sample_size=64, base_size=16): + """ + This function is contributed by KohakuBlueleaf. Thanks for the contribution! + + Creates scaled 2D sinusoidal positional embeddings that maintain consistent relative positions + when the resolution differs from the training resolution. + + Args: + embed_dim (int): Dimension of the positional embedding. + grid_size (int or tuple): Size of the position grid (H, W). If int, assumes square grid. + cls_token (bool): Whether to include class token. Defaults to False. + extra_tokens (int): Number of extra tokens (e.g., cls_token). Defaults to 0. + sample_size (int): Reference resolution (typically training resolution). Defaults to 64. + base_size (int): Base grid size used during training. Defaults to 16. + + Returns: + numpy.ndarray: Positional embeddings of shape (H*W, embed_dim) or + (H*W + extra_tokens, embed_dim) if cls_token is True. + """ + # Convert grid_size to tuple if it's an integer + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + # Create normalized grid coordinates (0 to 1) + grid_h = np.arange(grid_size[0], dtype=np.float32) / grid_size[0] + grid_w = np.arange(grid_size[1], dtype=np.float32) / grid_size[1] + + # Calculate scaling factors for height and width + # This ensures that the central region matches the original resolution's embeddings + scale_h = base_size * grid_size[0] / (sample_size) + scale_w = base_size * grid_size[1] / (sample_size) + + # Calculate shift values to center the original resolution's embedding region + # This ensures that the central sample_size x sample_size region has similar + # positional embeddings to the original resolution + shift_h = 1 * scale_h * (grid_size[0] - sample_size) / (2 * grid_size[0]) + shift_w = 1 * scale_w * (grid_size[1] - sample_size) / (2 * grid_size[1]) + + # Apply scaling and shifting to create the final grid coordinates + grid_h = grid_h * scale_h - shift_h + grid_w = grid_w * scale_w - shift_w + + # Create 2D grid using meshgrid (note: w goes first) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0) + + # # Calculate the starting indices for the central region + # # This is used for debugging/visualization of the central region + # st_h = (grid_size[0] - sample_size) // 2 + # st_w = (grid_size[1] - sample_size) // 2 + # print(grid[:, st_h : st_h + sample_size, st_w : st_w + sample_size]) + + # Reshape grid for positional embedding calculation + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + + # Generate the sinusoidal positional embeddings + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + + # Add zeros for extra tokens (e.g., [CLS] token) if required + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + + return pos_embed + + +# if __name__ == "__main__": +# # This is what you get when you load SD3.5 state dict +# pos_emb = 
torch.from_numpy(get_scaled_2d_sincos_pos_embed( +# 1536, [384, 384], sample_size=64, base_size=16 +# )).float().unsqueeze(0) + + def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ embed_dim: output dimension for each position @@ -286,10 +245,6 @@ def timestep_embedding(t, dim, max_period=10000): return embedding -def rmsnorm(x, eps=1e-6): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) - - class PatchEmbed(nn.Module): def __init__( self, @@ -301,8 +256,9 @@ class PatchEmbed(nn.Module): flatten=True, bias=True, strict_img_size=True, - dynamic_img_pad=True, + dynamic_img_pad=False, ): + # dynamic_img_pad and norm is omitted in SD3.5 super().__init__() self.patch_size = patch_size self.flatten = flatten @@ -432,6 +388,10 @@ class Embedder(nn.Module): return self.mlp(x) +def rmsnorm(x, eps=1e-6): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + class RMSNorm(torch.nn.Module): def __init__( self, @@ -604,53 +564,6 @@ def attention(q, k, v, head_dim, mask=None, scale=None, mode="xformers"): return scores -class SelfAttention(AttentionLinears): - def __init__(self, dim, num_heads=8, mode="xformers"): - super().__init__(dim, num_heads, qkv_bias=True, pre_only=False) - assert mode in MEMORY_LAYOUTS - self.head_dim = dim // num_heads - self.attn_mode = mode - - def set_attn_mode(self, mode): - self.attn_mode = mode - - def forward(self, x): - q, k, v = self.pre_attention(x) - attn_score = attention(q, k, v, self.head_dim, mode=self.attn_mode) - return self.post_attention(attn_score) - - -class TransformerBlock(nn.Module): - def __init__(self, context_size, mode="xformers"): - super().__init__() - self.context_size = context_size - self.norm1 = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) - self.attn = SelfAttention(context_size, mode=mode) - self.norm2 = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) - self.mlp = MLP( - in_features=context_size, - hidden_features=context_size * 4, - act_layer=lambda: nn.GELU(approximate="tanh"), - ) - - def forward(self, x): - x = x + self.attn(self.norm1(x)) - x = x + self.mlp(self.norm2(x)) - return x - - -class Transformer(nn.Module): - def __init__(self, context_size, num_layers, mode="xformers"): - super().__init__() - self.layers = nn.ModuleList([TransformerBlock(context_size, mode) for _ in range(num_layers)]) - self.norm = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) - - def forward(self, x): - for layer in self.layers: - x = layer(x) - return self.norm(x) - - # DismantledBlock in mmdit.py class SingleDiTBlock(nn.Module): """ @@ -669,6 +582,7 @@ class SingleDiTBlock(nn.Module): scale_mod_only: bool = False, swiglu: bool = False, qk_norm: Optional[str] = None, + x_block_self_attn: bool = False, **block_kwargs, ): super().__init__() @@ -678,13 +592,14 @@ class SingleDiTBlock(nn.Module): self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) else: self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.attn = AttentionLinears( - dim=hidden_size, - num_heads=num_heads, - qkv_bias=qkv_bias, - pre_only=pre_only, - qk_norm=qk_norm, - ) + self.attn = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=pre_only, qk_norm=qk_norm) + + self.x_block_self_attn = x_block_self_attn + if self.x_block_self_attn: + assert not pre_only + assert not scale_mod_only + self.attn2 = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=False, qk_norm=qk_norm) + if not pre_only: if not 
rmsnorm: self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) @@ -705,7 +620,9 @@ class SingleDiTBlock(nn.Module): multiple_of=256, ) self.scale_mod_only = scale_mod_only - if not scale_mod_only: + if self.x_block_self_attn: + n_mods = 9 + elif not scale_mod_only: n_mods = 6 if not pre_only else 2 else: n_mods = 4 if not pre_only else 1 @@ -715,63 +632,64 @@ class SingleDiTBlock(nn.Module): def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: if not self.pre_only: if not self.scale_mod_only: - ( - shift_msa, - scale_msa, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, - ) = self.adaLN_modulation( - c - ).chunk(6, dim=-1) + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(6, dim=-1) else: shift_msa = None shift_mlp = None - ( - scale_msa, - gate_msa, - scale_mlp, - gate_mlp, - ) = self.adaLN_modulation( - c - ).chunk(4, dim=-1) + (scale_msa, gate_msa, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(4, dim=-1) qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) - return qkv, ( - x, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, - ) + return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp) else: if not self.scale_mod_only: - ( - shift_msa, - scale_msa, - ) = self.adaLN_modulation( - c - ).chunk(2, dim=-1) + (shift_msa, scale_msa) = self.adaLN_modulation(c).chunk(2, dim=-1) else: shift_msa = None scale_msa = self.adaLN_modulation(c) qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) return qkv, None + def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + assert self.x_block_self_attn + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2) = self.adaLN_modulation( + c + ).chunk(9, dim=1) + x_norm = self.norm1(x) + qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa)) + qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2)) + return qkv, qkv2, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2) + def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp): assert not self.pre_only x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn) x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) return x + def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2, attn1_dropout: float = 0.0): + assert not self.pre_only + if attn1_dropout > 0.0: + # Use torch.bernoulli to implement dropout, only dropout the batch dimension + attn1_dropout = torch.bernoulli(torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device)) + attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout + else: + attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + attn_ + attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2) + x = x + attn2_ + mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + x = x + mlp_ + return x + # JointBlock + block_mixing in mmdit.py class MMDiTBlock(nn.Module): def __init__(self, *args, **kwargs): super().__init__() pre_only = kwargs.pop("pre_only") + x_block_self_attn = kwargs.pop("x_block_self_attn") + self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs) - self.x_block = SingleDiTBlock(*args, pre_only=False, **kwargs) + self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs) + 
self.head_dim = self.x_block.attn.head_dim self.mode = self.x_block.attn_mode self.gradient_checkpointing = False @@ -781,7 +699,11 @@ class MMDiTBlock(nn.Module): def _forward(self, context, x, c): ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c) - x_qkv, x_intermediate = self.x_block.pre_attention(x, c) + + if self.x_block.x_block_self_attn: + x_qkv, x_qkv2, x_intermediates = self.x_block.pre_attention_x(x, c) + else: + x_qkv, x_intermediates = self.x_block.pre_attention(x, c) ctx_len = ctx_qkv[0].size(1) @@ -793,11 +715,18 @@ class MMDiTBlock(nn.Module): ctx_attn_out = attn[:, :ctx_len] x_attn_out = attn[:, ctx_len:] - x = self.x_block.post_attention(x_attn_out, *x_intermediate) + if self.x_block.x_block_self_attn: + x_q2, x_k2, x_v2 = x_qkv2 + attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads, mode=self.mode) + x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates) + else: + x = self.x_block.post_attention(x_attn_out, *x_intermediates) + if not self.context_block.pre_only: context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate) else: context = None + return context, x def forward(self, *args, **kwargs): @@ -812,6 +741,9 @@ class MMDiT(nn.Module): Diffusion model with a Transformer backbone. """ + # prepare pos_embed for latent size * 2 + POS_EMBED_MAX_RATIO = 1.5 + def __init__( self, input_size: int = 32, @@ -823,7 +755,8 @@ class MMDiT(nn.Module): mlp_ratio: float = 4.0, learn_sigma: bool = False, adm_in_channels: Optional[int] = None, - context_embedder_config: Optional[Dict] = None, + context_embedder_in_features: Optional[int] = None, + context_embedder_out_features: Optional[int] = None, use_checkpoint: bool = False, register_length: int = 0, attn_mode: str = "torch", @@ -836,11 +769,15 @@ class MMDiT(nn.Module): pos_embed_max_size: Optional[int] = None, num_patches=None, qk_norm: Optional[str] = None, + x_block_self_attn_layers: Optional[list[int]] = [], qkv_bias: bool = True, - context_processor_layers=None, - context_size=4096, + pos_emb_random_crop_rate: float = 0.0, + use_scaled_pos_embed: bool = False, + pos_embed_latent_sizes: Optional[list[int]] = None, + model_type: str = "sd3m", ): super().__init__() + self._model_type = model_type self.learn_sigma = learn_sigma self.in_channels = in_channels default_out_channels = in_channels * 2 if learn_sigma else in_channels @@ -849,6 +786,8 @@ class MMDiT(nn.Module): self.pos_embed_scaling_factor = pos_embed_scaling_factor self.pos_embed_offset = pos_embed_offset self.pos_embed_max_size = pos_embed_max_size + self.x_block_self_attn_layers = x_block_self_attn_layers + self.pos_emb_random_crop_rate = pos_emb_random_crop_rate self.gradient_checkpointing = use_checkpoint # hidden_size = default(hidden_size, 64 * depth) @@ -860,6 +799,8 @@ class MMDiT(nn.Module): self.num_heads = num_heads + self.enable_scaled_pos_embed(use_scaled_pos_embed, pos_embed_latent_sizes) + self.x_embedder = PatchEmbed( input_size, patch_size, @@ -875,12 +816,11 @@ class MMDiT(nn.Module): assert isinstance(adm_in_channels, int) self.y_embedder = Embedder(adm_in_channels, self.hidden_size) - if context_processor_layers is not None: - self.context_processor = Transformer(context_size, context_processor_layers, attn_mode) + if context_embedder_in_features is not None: + self.context_embedder = nn.Linear(context_embedder_in_features, context_embedder_out_features) else: - self.context_processor = None + self.context_embedder = nn.Identity() - self.context_embedder = nn.Linear(context_size, 
self.hidden_size) self.register_length = register_length if self.register_length > 0: self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size)) @@ -910,6 +850,7 @@ class MMDiT(nn.Module): scale_mod_only=scale_mod_only, swiglu=swiglu, qk_norm=qk_norm, + x_block_self_attn=(i in self.x_block_self_attn_layers), ) for i in range(depth) ] @@ -920,9 +861,50 @@ class MMDiT(nn.Module): self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels) # self.initialize_weights() + self.blocks_to_swap = None + self.thread_pool: Optional[ThreadPoolExecutor] = None + + def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]): + self.use_scaled_pos_embed = use_scaled_pos_embed + + if self.use_scaled_pos_embed: + # remove pos_embed to free up memory up to 0.4 GB + self.pos_embed = None + + # remove duplcates and sort latent sizes in ascending order + latent_sizes = list(set(latent_sizes)) + latent_sizes = sorted(latent_sizes) + + patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes] + + # calculate value range for each latent area: this is used to determine the pos_emb size from the latent shape + max_areas = [] + for i in range(1, len(patched_sizes)): + prev_area = patched_sizes[i - 1] ** 2 + area = patched_sizes[i] ** 2 + max_areas.append((prev_area + area) // 2) + + # area of the last latent size, if the latent size exceeds this, error will be raised + max_areas.append(int((patched_sizes[-1] * MMDiT.POS_EMBED_MAX_RATIO) ** 2)) + # print("max_areas", max_areas) + + self.resolution_area_to_latent_size = [(area, latent_size) for area, latent_size in zip(max_areas, patched_sizes)] + + self.resolution_pos_embeds = {} + for patched_size in patched_sizes: + grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO) + pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) + self.resolution_pos_embeds[patched_size] = pos_embed + # print(f"pos_embed for {patched_size}x{patched_size} latent size: {pos_embed.shape}") + + else: + self.resolution_area_to_latent_size = None + self.resolution_pos_embeds = None + @property def model_type(self): - return "m" # only support medium + return self._model_type @property def device(self): @@ -988,18 +970,27 @@ class MMDiT(nn.Module): nn.init.constant_(self.final_layer.linear.weight, 0) nn.init.constant_(self.final_layer.linear.bias, 0) - def cropped_pos_embed(self, h, w, device=None): + def set_pos_emb_random_crop_rate(self, rate: float): + self.pos_emb_random_crop_rate = rate + + def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False): p = self.x_embedder.patch_size # patched size h = (h + 1) // p w = (w + 1) // p - if self.pos_embed is None: + if self.pos_embed is None: # should not happen return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device) assert self.pos_embed_max_size is not None assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size) assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size) - top = (self.pos_embed_max_size - h) // 2 - left = (self.pos_embed_max_size - w) // 2 + + if not random_crop: + top = (self.pos_embed_max_size - h) // 2 + left = (self.pos_embed_max_size - w) // 2 + else: + top = torch.randint(0, self.pos_embed_max_size - h + 1, (1,)).item() + left = torch.randint(0, self.pos_embed_max_size - w + 1, (1,)).item() + spatial_pos_embed = self.pos_embed.reshape( 1, self.pos_embed_max_size, @@ 
-1010,6 +1001,89 @@ class MMDiT(nn.Module): spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) return spatial_pos_embed + def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: bool = False): + p = self.x_embedder.patch_size + # patched size + h = (h + 1) // p + w = (w + 1) // p + + # select pos_embed size based on area + area = h * w + patched_size = None + for area_, patched_size_ in self.resolution_area_to_latent_size: + if area <= area_: + patched_size = patched_size_ + break + if patched_size is None: + raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.") + + pos_embed = self.resolution_pos_embeds[patched_size] + pos_embed_size = round(math.sqrt(pos_embed.shape[1])) + if h > pos_embed_size or w > pos_embed_size: + # # fallback to normal pos_embed + # return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop) + # extend pos_embed size + logger.warning( + f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide." + ) + pos_embed_size = max(h, w) + pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) + self.resolution_pos_embeds[patched_size] = pos_embed + logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}") + + if not random_crop: + top = (pos_embed_size - h) // 2 + left = (pos_embed_size - w) // 2 + else: + top = torch.randint(0, pos_embed_size - h + 1, (1,)).item() + left = torch.randint(0, pos_embed_size - w + 1, (1,)).item() + + if pos_embed.device != device: + pos_embed = pos_embed.to(device) + # which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device. + self.resolution_pos_embeds[patched_size] = pos_embed # update device + if pos_embed.dtype != dtype: + pos_embed = pos_embed.to(dtype) + self.resolution_pos_embeds[patched_size] = pos_embed # update dtype + + spatial_pos_embed = pos_embed.reshape(1, pos_embed_size, pos_embed_size, pos_embed.shape[-1]) + spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + # print( + # f"patched size: {h}x{w}, pos_embed size: {pos_embed_size}, pos_embed shape: {pos_embed.shape}, top: {top}, left: {left}" + # ) + return spatial_pos_embed + + def enable_block_swap(self, num_blocks: int): + self.blocks_to_swap = num_blocks + + n = 1 # async block swap. 
1 is enough + self.thread_pool = ThreadPoolExecutor(max_workers=n) + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu + if self.blocks_to_swap: + save_blocks = self.joint_blocks + self.joint_blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.joint_blocks = save_blocks + + def prepare_block_swap_before_forward(self): + # make: first n blocks are on cuda, and last n blocks are on cpu + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + # raise ValueError("Block swap is not enabled.") + return + num_blocks = len(self.joint_blocks) + for i in range(num_blocks - self.blocks_to_swap): + self.joint_blocks[i].to(self.device) + for i in range(num_blocks - self.blocks_to_swap, num_blocks): + self.joint_blocks[i].to("cpu") + clean_memory_on_device(self.device) + def forward( self, x: torch.Tensor, @@ -1023,12 +1097,21 @@ class MMDiT(nn.Module): t: (N,) tensor of diffusion timesteps y: (N, D) tensor of class labels """ - - if self.context_processor is not None: - context = self.context_processor(context) + pos_emb_random_crop = ( + False if self.pos_emb_random_crop_rate == 0.0 else torch.rand(1).item() < self.pos_emb_random_crop_rate + ) B, C, H, W = x.shape - x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device).to(dtype=x.dtype) + + # x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype) + if not self.use_scaled_pos_embed: + pos_embed = self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype) + else: + # print(f"Using scaled pos_embed for size {H}x{W}") + pos_embed = self.cropped_scaled_pos_embed(H, W, device=x.device, dtype=x.dtype, random_crop=pos_emb_random_crop) + x = self.x_embedder(x) + pos_embed + del pos_embed + c = self.t_embedder(t, dtype=x.dtype) # (N, D) if y is not None and self.y_embedder is not None: y = self.y_embedder(y) # (N, D) @@ -1046,28 +1129,71 @@ class MMDiT(nn.Module): 1, ) - for block in self.joint_blocks: - context, x = block(context, x, c) + if not self.blocks_to_swap: + for block in self.joint_blocks: + context, x = block(context, x, c) + else: + futures = {} + + def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): + # print(f"Moving {bidx_to_cpu} to cpu.") + block_to_cpu.to("cpu", non_blocking=True) + torch.cuda.empty_cache() + + # print(f"Moving {bidx_to_cuda} to cuda.") + block_to_cuda.to(self.device, non_blocking=True) + + torch.cuda.synchronize() + # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") + return block_idx_to_cpu, block_idx_to_cuda + + block_to_cpu = self.joint_blocks[block_idx_to_cpu] + block_to_cuda = self.joint_blocks[block_idx_to_cuda] + # print(f"Submit move blocks. 
{block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.") + return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) + + def wait_for_blocks_move(block_idx, ftrs): + if block_idx not in ftrs: + return + # print(f"Waiting for move blocks: {block_idx}") + # start_time = time.perf_counter() + ftr = ftrs.pop(block_idx) + ftr.result() + # torch.cuda.synchronize() + # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") + + for block_idx, block in enumerate(self.joint_blocks): + wait_for_blocks_move(block_idx, futures) + + context, x = block(context, x, c) + + if block_idx < self.blocks_to_swap: + block_idx_to_cpu = block_idx + block_idx_to_cuda = len(self.joint_blocks) - self.blocks_to_swap + block_idx + future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) + futures[block_idx_to_cuda] = future + x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify return x[:, :, :H, :W] -def create_mmdit_sd3_medium_configs(attn_mode: str): - # {'patch_size': 2, 'depth': 24, 'num_patches': 36864, - # 'pos_embed_max_size': 192, 'adm_in_channels': 2048, 'context_embedder': - # {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}}} +def create_sd3_mmdit(params: SD3Params, attn_mode: str = "torch") -> MMDiT: mmdit = MMDiT( input_size=None, - pos_embed_max_size=192, - patch_size=2, + pos_embed_max_size=params.pos_embed_max_size, + patch_size=params.patch_size, in_channels=16, - adm_in_channels=2048, - depth=24, + adm_in_channels=params.adm_in_channels, + context_embedder_in_features=params.context_embedder_in_features, + context_embedder_out_features=params.context_embedder_out_features, + depth=params.depth, mlp_ratio=4, - qk_norm=None, - num_patches=36864, - context_size=4096, + qk_norm=params.qk_norm, + x_block_self_attn_layers=params.x_block_self_attn_layers, + num_patches=params.num_patches, attn_mode=attn_mode, + model_type=params.model_type, ) return mmdit @@ -1075,7 +1201,6 @@ def create_mmdit_sd3_medium_configs(attn_mode: str): # endregion # region VAE -# TODO support xformers VAE_SCALE_FACTOR = 1.5305 VAE_SHIFT_FACTOR = 0.0609 @@ -1322,759 +1447,4 @@ class SDVAE(torch.nn.Module): return (latent / VAE_SCALE_FACTOR) + VAE_SHIFT_FACTOR -class VAEOutput: - def __init__(self, latent): - self.latent = latent - - @property - def latent_dist(self): - return self - - def sample(self): - return self.latent - - -class VAEWrapper: - def __init__(self, vae): - self.vae = vae - - @property - def device(self): - return self.vae.device - - @property - def dtype(self): - return self.vae.dtype - - # latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") - def encode(self, image): - return VAEOutput(self.vae.encode(image)) - - -# endregion - - -# region Text Encoder -class CLIPAttention(torch.nn.Module): - def __init__(self, embed_dim, heads, dtype, device, mode="xformers"): - super().__init__() - self.heads = heads - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) - self.attn_mode = mode - - def set_attn_mode(self, mode): - self.attn_mode = mode - - def forward(self, x, mask=None): - q = self.q_proj(x) - k = self.k_proj(x) - v = self.v_proj(x) - out = attention(q, k, v, self.heads, mask, 
mode=self.attn_mode) - return self.out_proj(out) - - -ACTIVATIONS = { - "quick_gelu": lambda: (lambda a: a * torch.sigmoid(1.702 * a)), - # "gelu": torch.nn.functional.gelu, - "gelu": lambda: nn.GELU(), -} - - -class CLIPLayer(torch.nn.Module): - def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): - super().__init__() - self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) - self.self_attn = CLIPAttention(embed_dim, heads, dtype, device) - self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) - # # self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device) - # self.mlp = Mlp( - # embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device - # ) - self.mlp = MLP(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation]) - self.mlp.to(device=device, dtype=dtype) - - def forward(self, x, mask=None): - x += self.self_attn(self.layer_norm1(x), mask) - x += self.mlp(self.layer_norm2(x)) - return x - - -class CLIPEncoder(torch.nn.Module): - def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): - super().__init__() - self.layers = torch.nn.ModuleList( - [CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)] - ) - - def forward(self, x, mask=None, intermediate_output=None): - if intermediate_output is not None: - if intermediate_output < 0: - intermediate_output = len(self.layers) + intermediate_output - intermediate = None - for i, l in enumerate(self.layers): - x = l(x, mask) - if i == intermediate_output: - intermediate = x.clone() - return x, intermediate - - -class CLIPEmbeddings(torch.nn.Module): - def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None): - super().__init__() - self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device) - self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) - - def forward(self, input_tokens): - return self.token_embedding(input_tokens) + self.position_embedding.weight - - -class CLIPTextModel_(torch.nn.Module): - def __init__(self, config_dict, dtype, device): - num_layers = config_dict["num_hidden_layers"] - embed_dim = config_dict["hidden_size"] - heads = config_dict["num_attention_heads"] - intermediate_size = config_dict["intermediate_size"] - intermediate_activation = config_dict["hidden_act"] - super().__init__() - self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) - self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) - self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) - - def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True): - x = self.embeddings(input_tokens) - - if x.dtype == torch.bfloat16: - causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=torch.float32, device=x.device).fill_(float("-inf")).triu_(1) - causal_mask = causal_mask.to(dtype=x.dtype) - else: - causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) - - x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output) - x = self.final_layer_norm(x) - if i is not None and final_layer_norm_intermediate: - i = self.final_layer_norm(i) - 
pooled_output = x[ - torch.arange(x.shape[0], device=x.device), - input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1), - ] - return x, i, pooled_output - - -class CLIPTextModel(torch.nn.Module): - def __init__(self, config_dict, dtype, device): - super().__init__() - self.num_layers = config_dict["num_hidden_layers"] - self.text_model = CLIPTextModel_(config_dict, dtype, device) - embed_dim = config_dict["hidden_size"] - self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) - self.text_projection.weight.copy_(torch.eye(embed_dim)) - self.dtype = dtype - - def get_input_embeddings(self): - return self.text_model.embeddings.token_embedding - - def set_input_embeddings(self, embeddings): - self.text_model.embeddings.token_embedding = embeddings - - def forward(self, *args, **kwargs): - x = self.text_model(*args, **kwargs) - out = self.text_projection(x[2]) - return (x[0], x[1], out, x[2]) - - -class ClipTokenWeightEncoder: - # def encode_token_weights(self, token_weight_pairs): - # tokens = list(map(lambda a: a[0], token_weight_pairs[0])) - # out, pooled = self([tokens]) - # if pooled is not None: - # first_pooled = pooled[0:1] - # else: - # first_pooled = pooled - # output = [out[0:1]] - # return torch.cat(output, dim=-2), first_pooled - - # fix to support batched inputs - # : Union[List[Tuple[torch.Tensor, torch.Tensor]], List[List[Tuple[torch.Tensor, torch.Tensor]]]] - def encode_token_weights(self, list_of_token_weight_pairs): - has_batch = isinstance(list_of_token_weight_pairs[0][0], list) - - if has_batch: - list_of_tokens = [] - for pairs in list_of_token_weight_pairs: - tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0] - list_of_tokens.append(tokens) - else: - if isinstance(list_of_token_weight_pairs[0], torch.Tensor): - list_of_tokens = [list(list_of_token_weight_pairs[0])] - else: - list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] - - out, pooled = self(list_of_tokens) - if has_batch: - return out, pooled - else: - if pooled is not None: - first_pooled = pooled[0:1] - else: - first_pooled = pooled - output = [out[0:1]] - return torch.cat(output, dim=-2), first_pooled - - -class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): - """Uses the CLIP transformer encoder for text (from huggingface)""" - - LAYERS = ["last", "pooled", "hidden"] - - def __init__( - self, - device="cpu", - max_length=77, - layer="last", - layer_idx=None, - textmodel_json_config=None, - dtype=None, - model_class=CLIPTextModel, - special_tokens={"start": 49406, "end": 49407, "pad": 49407}, - layer_norm_hidden_state=True, - return_projected_pooled=True, - ): - super().__init__() - assert layer in self.LAYERS - self.transformer = model_class(textmodel_json_config, dtype, device) - self.num_layers = self.transformer.num_layers - self.max_length = max_length - self.transformer = self.transformer.eval() - for param in self.parameters(): - param.requires_grad = False - self.layer = layer - self.layer_idx = None - self.special_tokens = special_tokens - self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) - self.layer_norm_hidden_state = layer_norm_hidden_state - self.return_projected_pooled = return_projected_pooled - if layer == "hidden": - assert layer_idx is not None - assert abs(layer_idx) < self.num_layers - self.set_clip_options({"layer": layer_idx}) - self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) - - @property - def device(self): - return next(self.parameters()).device - - 
@property - def dtype(self): - return next(self.parameters()).dtype - - def gradient_checkpointing_enable(self): - logger.warning("Gradient checkpointing is not supported for this model") - - def set_attn_mode(self, mode): - raise NotImplementedError("This model does not support setting the attention mode") - - def set_clip_options(self, options): - layer_idx = options.get("layer", self.layer_idx) - self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) - if layer_idx is None or abs(layer_idx) > self.num_layers: - self.layer = "last" - else: - self.layer = "hidden" - self.layer_idx = layer_idx - - def forward(self, tokens): - backup_embeds = self.transformer.get_input_embeddings() - device = backup_embeds.weight.device - tokens = torch.LongTensor(tokens).to(device) - outputs = self.transformer( - tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state - ) - self.transformer.set_input_embeddings(backup_embeds) - if self.layer == "last": - z = outputs[0] - else: - z = outputs[1] - pooled_output = None - if len(outputs) >= 3: - if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None: - pooled_output = outputs[3].float() - elif outputs[2] is not None: - pooled_output = outputs[2].float() - return z.float(), pooled_output - - def set_attn_mode(self, mode): - clip_text_model = self.transformer.text_model - for layer in clip_text_model.encoder.layers: - layer.self_attn.set_attn_mode(mode) - - -class SDXLClipG(SDClipModel): - """Wraps the CLIP-G model into the SD-CLIP-Model interface""" - - def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None): - if layer == "penultimate": - layer = "hidden" - layer_idx = -2 - super().__init__( - device=device, - layer=layer, - layer_idx=layer_idx, - textmodel_json_config=config, - dtype=dtype, - special_tokens={"start": 49406, "end": 49407, "pad": 0}, - layer_norm_hidden_state=False, - ) - - def set_attn_mode(self, mode): - clip_text_model = self.transformer.text_model - for layer in clip_text_model.encoder.layers: - layer.self_attn.set_attn_mode(mode) - - -class T5XXLModel(SDClipModel): - """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience""" - - def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None): - super().__init__( - device=device, - layer=layer, - layer_idx=layer_idx, - textmodel_json_config=config, - dtype=dtype, - special_tokens={"end": 1, "pad": 0}, - model_class=T5, - ) - - def set_attn_mode(self, mode): - t5: T5 = self.transformer - for t5block in t5.encoder.block: - t5block: T5Block - t5layer: T5LayerSelfAttention = t5block.layer[0] - t5SaSa: T5Attention = t5layer.SelfAttention - t5SaSa.set_attn_mode(mode) - - -################################################################################################# -### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl -################################################################################################# - -""" -class T5XXLTokenizer(SDTokenizer): - ""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"" - - def __init__(self): - super().__init__( - pad_with_end=False, - tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), - has_start_token=False, - pad_to_max_length=False, - max_length=99999999, - min_length=77, - ) -""" - - -class T5LayerNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None): - 
super().__init__() - self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device)) - self.variance_epsilon = eps - - # def forward(self, x): - # variance = x.pow(2).mean(-1, keepdim=True) - # x = x * torch.rsqrt(variance + self.variance_epsilon) - # return self.weight.to(device=x.device, dtype=x.dtype) * x - - # copy from transformers' T5LayerNorm - def forward(self, hidden_states): - # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated - # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for - # half-precision inputs is done in fp32 - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states - - -class T5DenseGatedActDense(torch.nn.Module): - def __init__(self, model_dim, ff_dim, dtype, device): - super().__init__() - self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) - self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) - self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device) - - def forward(self, x): - hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh") - hidden_linear = self.wi_1(x) - x = hidden_gelu * hidden_linear - x = self.wo(x) - return x - - -class T5LayerFF(torch.nn.Module): - def __init__(self, model_dim, ff_dim, dtype, device): - super().__init__() - self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device) - self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) - - def forward(self, x): - forwarded_states = self.layer_norm(x) - forwarded_states = self.DenseReluDense(forwarded_states) - x += forwarded_states - return x - - -class T5Attention(torch.nn.Module): - def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device): - super().__init__() - # Mesh TensorFlow initialization to avoid scaling before softmax - self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device) - self.num_heads = num_heads - self.relative_attention_bias = None - if relative_attention_bias: - self.relative_attention_num_buckets = 32 - self.relative_attention_max_distance = 128 - self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device) - - self.attn_mode = "xformers" # TODO 何とかする - - def set_attn_mode(self, mode): - self.attn_mode = mode - - @staticmethod - def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. 
the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. - This should allow for more graceful generalization to longer sequences than the model has been trained on - - Args: - relative_position: an int32 Tensor - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer - - Returns: - a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) - """ - relative_buckets = 0 - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0).to(torch.long) * num_buckets - relative_position = torch.abs(relative_position) - else: - relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) - # now relative_position is in the range [0, inf) - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) - ).to(torch.long) - relative_position_if_large = torch.min( - relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) - ) - relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) - return relative_buckets - - def compute_bias(self, query_length, key_length, device): - """Compute binned relative position bias""" - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] - memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] - relative_position = memory_position - context_position # shape (query_length, key_length) - relative_position_bucket = self._relative_position_bucket( - relative_position, # shape (query_length, key_length) - bidirectional=True, - num_buckets=self.relative_attention_num_buckets, - max_distance=self.relative_attention_max_distance, - ) - values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) - values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) - return values - - def forward(self, x, past_bias=None): - q = self.q(x) - k = self.k(x) - v = self.v(x) - if self.relative_attention_bias is not None: - past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device) - if past_bias is not None: - mask = past_bias - out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask, mode=self.attn_mode) - return self.o(out), past_bias - - -class T5LayerSelfAttention(torch.nn.Module): - def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): - super().__init__() - self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device) - self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) - - def forward(self, x, past_bias=None): - output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias) - x += output - 
return x, past_bias - - -class T5Block(torch.nn.Module): - def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): - super().__init__() - self.layer = torch.nn.ModuleList() - self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device)) - self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device)) - - def forward(self, x, past_bias=None): - x, past_bias = self.layer[0](x, past_bias) - - # copy from transformers' T5Block - # clamp inf values to enable fp16 training - if x.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(x).any(), - torch.finfo(x.dtype).max - 1000, - torch.finfo(x.dtype).max, - ) - x = torch.clamp(x, min=-clamp_value, max=clamp_value) - - x = self.layer[-1](x) - # clamp inf values to enable fp16 training - if x.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(x).any(), - torch.finfo(x.dtype).max - 1000, - torch.finfo(x.dtype).max, - ) - x = torch.clamp(x, min=-clamp_value, max=clamp_value) - - return x, past_bias - - -class T5Stack(torch.nn.Module): - def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device): - super().__init__() - self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device) - self.block = torch.nn.ModuleList( - [ - T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) - for i in range(num_layers) - ] - ) - self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) - - def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True): - intermediate = None - x = self.embed_tokens(input_ids) - past_bias = None - for i, l in enumerate(self.block): - # uncomment to debug layerwise output: fp16 may cause issues - # print(i, x.mean(), x.std()) - x, past_bias = l(x, past_bias) - if i == intermediate_output: - intermediate = x.clone() - # print(x.mean(), x.std()) - x = self.final_layer_norm(x) - if intermediate is not None and final_layer_norm_intermediate: - intermediate = self.final_layer_norm(intermediate) - # print(x.mean(), x.std()) - return x, intermediate - - -class T5(torch.nn.Module): - def __init__(self, config_dict, dtype, device): - super().__init__() - self.num_layers = config_dict["num_layers"] - self.encoder = T5Stack( - self.num_layers, - config_dict["d_model"], - config_dict["d_model"], - config_dict["d_ff"], - config_dict["num_heads"], - config_dict["vocab_size"], - dtype, - device, - ) - self.dtype = dtype - - def get_input_embeddings(self): - return self.encoder.embed_tokens - - def set_input_embeddings(self, embeddings): - self.encoder.embed_tokens = embeddings - - def forward(self, *args, **kwargs): - return self.encoder(*args, **kwargs) - - -def create_clip_l(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None): - r""" - state_dict is not loaded, but updated with missing keys - """ - CLIPL_CONFIG = { - "hidden_act": "quick_gelu", - "hidden_size": 768, - "intermediate_size": 3072, - "num_attention_heads": 12, - "num_hidden_layers": 12, - } - with torch.no_grad(): - clip_l = SDClipModel( - layer="hidden", - layer_idx=-2, - device=device, - dtype=dtype, - layer_norm_hidden_state=False, - return_projected_pooled=False, - textmodel_json_config=CLIPL_CONFIG, - ) - clip_l.gradient_checkpointing_enable() - if state_dict is not None: - # update state_dict if provided to include logit_scale and text_projection.weight avoid 
errors - if "logit_scale" not in state_dict: - state_dict["logit_scale"] = clip_l.logit_scale - if "transformer.text_projection.weight" not in state_dict: - state_dict["transformer.text_projection.weight"] = clip_l.transformer.text_projection.weight - return clip_l - - -def create_clip_g(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None): - r""" - state_dict is not loaded, but updated with missing keys - """ - CLIPG_CONFIG = { - "hidden_act": "gelu", - "hidden_size": 1280, - "intermediate_size": 5120, - "num_attention_heads": 20, - "num_hidden_layers": 32, - } - with torch.no_grad(): - clip_g = SDXLClipG(CLIPG_CONFIG, device=device, dtype=dtype) - if state_dict is not None: - if "logit_scale" not in state_dict: - state_dict["logit_scale"] = clip_g.logit_scale - return clip_g - - -def create_t5xxl(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> T5XXLModel: - T5_CONFIG = {"d_ff": 10240, "d_model": 4096, "num_heads": 64, "num_layers": 24, "vocab_size": 32128} - with torch.no_grad(): - t5 = T5XXLModel(T5_CONFIG, dtype=dtype, device=device) - if state_dict is not None: - if "logit_scale" not in state_dict: - state_dict["logit_scale"] = t5.logit_scale - if "transformer.shared.weight" in state_dict: - state_dict.pop("transformer.shared.weight") - return t5 - - -""" - # snippet for using the T5 model from transformers - - from transformers import T5EncoderModel, T5Config - import accelerate - import json - - T5_CONFIG_JSON = "" -{ - "architectures": [ - "T5EncoderModel" - ], - "classifier_dropout": 0.0, - "d_ff": 10240, - "d_kv": 64, - "d_model": 4096, - "decoder_start_token_id": 0, - "dense_act_fn": "gelu_new", - "dropout_rate": 0.1, - "eos_token_id": 1, - "feed_forward_proj": "gated-gelu", - "initializer_factor": 1.0, - "is_encoder_decoder": true, - "is_gated_act": true, - "layer_norm_epsilon": 1e-06, - "model_type": "t5", - "num_decoder_layers": 24, - "num_heads": 64, - "num_layers": 24, - "output_past": true, - "pad_token_id": 0, - "relative_attention_max_distance": 128, - "relative_attention_num_buckets": 32, - "tie_word_embeddings": false, - "torch_dtype": "float16", - "transformers_version": "4.41.2", - "use_cache": true, - "vocab_size": 32128 -} -"" - config = json.loads(T5_CONFIG_JSON) - config = T5Config(**config) - - # model = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3") - # print(model.config) - # # model(**load_model.config) - - # with accelerate.init_empty_weights(): - model = T5EncoderModel._from_config(config) # , torch_dtype=dtype) - for key in list(state_dict.keys()): - if key.startswith("transformer."): - new_key = key[len("transformer.") :] - state_dict[new_key] = state_dict.pop(key) - - info = model.load_state_dict(state_dict) - print(info) - model.set_attn_mode = lambda x: None - # model.to("cpu") - - _self = model - - def enc(list_of_token_weight_pairs): - has_batch = isinstance(list_of_token_weight_pairs[0][0], list) - - if has_batch: - list_of_tokens = [] - for pairs in list_of_token_weight_pairs: - tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0] - list_of_tokens.append(tokens) - else: - list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] - - list_of_tokens = np.array(list_of_tokens) - list_of_tokens = torch.from_numpy(list_of_tokens).to("cuda", dtype=torch.long) - out = _self(list_of_tokens) - pooled = None - if has_batch: - return out, pooled - else: - if pooled is not None: - first_pooled = 
pooled[0:1] - else: - first_pooled = pooled - return out[0], first_pooled - # output = [out[0:1]] - # return torch.cat(output, dim=-2), first_pooled - - model.encode_token_weights = enc - - return model -""" - # endregion diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index e819d440..69878750 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -11,8 +11,8 @@ from safetensors.torch import save_file from accelerate import Accelerator, PartialState from tqdm import tqdm from PIL import Image +from transformers import CLIPTextModelWithProjection, T5EncoderModel -from library import sd3_models, sd3_utils, strategy_base, train_util from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -28,60 +28,16 @@ import logging logger = logging.getLogger(__name__) -from .sdxl_train_util import match_mixed_precision - - -def load_target_model( - model_type: str, - args: argparse.Namespace, - state_dict: dict, - accelerator: Accelerator, - attn_mode: str, - model_dtype: Optional[torch.dtype], - device: Optional[torch.device], -) -> Union[ - sd3_models.MMDiT, - Optional[sd3_models.SDClipModel], - Optional[sd3_models.SDXLClipG], - Optional[sd3_models.T5XXLModel], - sd3_models.SDVAE, -]: - loading_device = device if device is not None else (accelerator.device if args.lowram else "cpu") - - for pi in range(accelerator.state.num_processes): - if pi == accelerator.state.local_process_index: - logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") - - if model_type == "mmdit": - model = sd3_utils.load_mmdit(state_dict, attn_mode, model_dtype, loading_device) - elif model_type == "clip_l": - model = sd3_utils.load_clip_l(state_dict, args.clip_l, attn_mode, model_dtype, loading_device) - elif model_type == "clip_g": - model = sd3_utils.load_clip_g(state_dict, args.clip_g, attn_mode, model_dtype, loading_device) - elif model_type == "t5xxl": - model = sd3_utils.load_t5xxl(state_dict, args.t5xxl, attn_mode, model_dtype, loading_device) - elif model_type == "vae": - model = sd3_utils.load_vae(state_dict, args.vae, model_dtype, loading_device) - else: - raise ValueError(f"Unknown model type: {model_type}") - - # work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device - if args.lowram: - model = model.to(accelerator.device) - - clean_memory_on_device(accelerator.device) - accelerator.wait_for_everyone() - - return model +from library import sd3_models, sd3_utils, strategy_base, train_util def save_models( ckpt_path: str, - mmdit: sd3_models.MMDiT, - vae: sd3_models.SDVAE, - clip_l: sd3_models.SDClipModel, - clip_g: sd3_models.SDXLClipG, - t5xxl: Optional[sd3_models.T5XXLModel], + mmdit: Optional[sd3_models.MMDiT], + vae: Optional[sd3_models.SDVAE], + clip_l: Optional[CLIPTextModelWithProjection], + clip_g: Optional[CLIPTextModelWithProjection], + t5xxl: Optional[T5EncoderModel], sai_metadata: Optional[dict], save_dtype: Optional[torch.dtype] = None, ): @@ -101,24 +57,42 @@ def save_models( update_sd("model.diffusion_model.", mmdit.state_dict()) update_sd("first_stage_model.", vae.state_dict()) - if clip_l is not None: - update_sd("text_encoders.clip_l.", clip_l.state_dict()) - if clip_g is not None: - update_sd("text_encoders.clip_g.", clip_g.state_dict()) - if t5xxl is not None: - update_sd("text_encoders.t5xxl.", t5xxl.state_dict()) + # do not support unified checkpoint format for now + # if clip_l is not None: + # 
update_sd("text_encoders.clip_l.", clip_l.state_dict()) + # if clip_g is not None: + # update_sd("text_encoders.clip_g.", clip_g.state_dict()) + # if t5xxl is not None: + # update_sd("text_encoders.t5xxl.", t5xxl.state_dict()) save_file(state_dict, ckpt_path, metadata=sai_metadata) + if clip_l is not None: + clip_l_path = ckpt_path.replace(".safetensors", "_clip_l.safetensors") + save_file(clip_l.state_dict(), clip_l_path) + if clip_g is not None: + clip_g_path = ckpt_path.replace(".safetensors", "_clip_g.safetensors") + save_file(clip_g.state_dict(), clip_g_path) + if t5xxl is not None: + t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors") + t5xxl_state_dict = t5xxl.state_dict() + + # replace "shared.weight" with copy of it to avoid annoying shared tensor error on safetensors.save_file + shared_weight = t5xxl_state_dict["shared.weight"] + shared_weight_copy = shared_weight.detach().clone() + t5xxl_state_dict["shared.weight"] = shared_weight_copy + + save_file(t5xxl_state_dict, t5xxl_path) + def save_sd3_model_on_train_end( args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, - clip_l: sd3_models.SDClipModel, - clip_g: sd3_models.SDXLClipG, - t5xxl: Optional[sd3_models.T5XXLModel], + clip_l: Optional[CLIPTextModelWithProjection], + clip_g: Optional[CLIPTextModelWithProjection], + t5xxl: Optional[T5EncoderModel], mmdit: sd3_models.MMDiT, vae: sd3_models.SDVAE, ): @@ -141,9 +115,9 @@ def save_sd3_model_on_epoch_end_or_stepwise( epoch: int, num_train_epochs: int, global_step: int, - clip_l: sd3_models.SDClipModel, - clip_g: sd3_models.SDXLClipG, - t5xxl: Optional[sd3_models.T5XXLModel], + clip_l: Optional[CLIPTextModelWithProjection], + clip_g: Optional[CLIPTextModelWithProjection], + t5xxl: Optional[T5EncoderModel], mmdit: sd3_models.MMDiT, vae: sd3_models.SDVAE, ): @@ -208,23 +182,75 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用", ) parser.add_argument( - "--save_clip", action="store_true", help="save CLIP models to checkpoint / CLIPモデルをチェックポイントに保存する" + "--save_clip", + action="store_true", + help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません", ) parser.add_argument( - "--save_t5xxl", action="store_true", help="save T5-XXL model to checkpoint / T5-XXLモデルをチェックポイントに保存する" + "--save_t5xxl", + action="store_true", + help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません", ) parser.add_argument( "--t5xxl_device", type=str, default=None, - help="T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用", + help="[DOES NOT WORK] not supported yet. T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用", ) parser.add_argument( "--t5xxl_dtype", type=str, default=None, - help="T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用", + help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用", + ) + + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=256, + help="maximum token length for T5-XXL. 
256 is the default value / T5-XXLの最大トークン長。デフォルトは256", + ) + parser.add_argument( + "--apply_lg_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する", + ) + parser.add_argument( + "--apply_t5_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", + ) + parser.add_argument( + "--clip_l_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for CLIP-L encoder, default is 0.0 / CLIP-Lエンコーダのドロップアウト率、デフォルトは0.0", + ) + parser.add_argument( + "--clip_g_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for CLIP-G encoder, default is 0.0 / CLIP-Gエンコーダのドロップアウト率、デフォルトは0.0", + ) + parser.add_argument( + "--t5_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0", + ) + parser.add_argument( + "--pos_emb_random_crop_rate", + type=float, + default=0.0, + help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M" + " / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります", + ) + parser.add_argument( + "--enable_scaled_pos_embed", + action="store_true", + help="Scale position embeddings for each resolution during multi-resolution training. Only for SD3.5M" + " / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります", ) # copy from Diffusers @@ -233,16 +259,25 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], + help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム", ) parser.add_argument( - "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + "--logit_mean", + type=float, + default=0.0, + help="mean to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合の平均", + ) + parser.add_argument( + "--logit_std", + type=float, + default=1.0, + help="std to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合のstd", ) - parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") parser.add_argument( "--mode_scale", type=float, default=1.29, - help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. 
/ モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効", ) @@ -283,7 +318,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin # temporary copied from sd3_minimal_inferece.py -def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): +def get_all_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): start = sampling.timestep(sampling.sigma_max) end = sampling.timestep(sampling.sigma_min) timesteps = torch.linspace(start, end, steps) @@ -319,6 +354,8 @@ def do_sample( # noise = get_noise(seed, latent).to(device) if seed is not None: generator = torch.manual_seed(seed) + else: + generator = None noise = ( torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu") .to(latent.dtype) @@ -327,7 +364,7 @@ def do_sample( model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 - sigmas = get_sigmas(model_sampling, steps).to(device) + sigmas = get_all_sigmas(model_sampling, steps).to(device) noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas)) @@ -337,71 +374,42 @@ def do_sample( x = noise_scaled.to(device).to(dtype) # print(x.shape) - with torch.no_grad(): - for i in tqdm(range(len(sigmas) - 1)): - sigma_hat = sigmas[i] + # with torch.no_grad(): + for i in tqdm(range(len(sigmas) - 1)): + sigma_hat = sigmas[i] - timestep = model_sampling.timestep(sigma_hat).float() - timestep = torch.FloatTensor([timestep, timestep]).to(device) + timestep = model_sampling.timestep(sigma_hat).float() + timestep = torch.FloatTensor([timestep, timestep]).to(device) - x_c_nc = torch.cat([x, x], dim=0) - # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + x_c_nc = torch.cat([x, x], dim=0) + # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) - model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) - model_output = model_output.float() - batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) + mmdit.prepare_block_swap_before_forward() + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + model_output = model_output.float() + batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) - pos_out, neg_out = batched.chunk(2) - denoised = neg_out + (pos_out - neg_out) * guidance_scale - # print(denoised.shape) + pos_out, neg_out = batched.chunk(2) + denoised = neg_out + (pos_out - neg_out) * guidance_scale + # print(denoised.shape) - # d = to_d(x, sigma_hat, denoised) - dims_to_append = x.ndim - sigma_hat.ndim - sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] - # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) - """Converts a denoiser output to a Karras ODE derivative.""" - d = (x - denoised) / sigma_hat_dims + # d = to_d(x, sigma_hat, denoised) + dims_to_append = x.ndim - sigma_hat.ndim + sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] + # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) + """Converts a denoiser output to a Karras ODE derivative.""" + d = (x - denoised) / sigma_hat_dims - dt = sigmas[i + 1] - sigma_hat + dt = sigmas[i + 1] - sigma_hat - # Euler method - x = x + d * dt - x = x.to(dtype) + # Euler method + x = x + d * dt + x = x.to(dtype) + mmdit.prepare_block_swap_before_forward() return x -def load_prompts(prompt_file: str) -> List[Dict]: - # read prompts - if prompt_file.endswith(".txt"): - with open(prompt_file, "r", encoding="utf-8") as f: - 
lines = f.readlines() - prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] - elif prompt_file.endswith(".toml"): - with open(prompt_file, "r", encoding="utf-8") as f: - data = toml.load(f) - prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] - elif prompt_file.endswith(".json"): - with open(prompt_file, "r", encoding="utf-8") as f: - prompts = json.load(f) - - # preprocess prompts - for i in range(len(prompts)): - prompt_dict = prompts[i] - if isinstance(prompt_dict, str): - from library.train_util import line_to_prompt_dict - - prompt_dict = line_to_prompt_dict(prompt_dict) - prompts[i] = prompt_dict - assert isinstance(prompt_dict, dict) - - # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. - prompt_dict["enum"] = i - prompt_dict.pop("subset", None) - - return prompts - - def sample_images( accelerator: Accelerator, args: argparse.Namespace, @@ -429,7 +437,7 @@ def sample_images( logger.info("") logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") - if not os.path.isfile(args.sample_prompts): + if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None: logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return @@ -437,10 +445,10 @@ def sample_images( # unwrap unet and text_encoder(s) mmdit = accelerator.unwrap_model(mmdit) - text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + text_encoders = None if text_encoders is None else [accelerator.unwrap_model(te) for te in text_encoders] # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) - prompts = load_prompts(args.sample_prompts) + prompts = train_util.load_prompts(args.sample_prompts) save_dir = args.output_dir + "/sample" os.makedirs(save_dir, exist_ok=True) @@ -453,12 +461,9 @@ def sample_images( except Exception: pass - org_vae_device = vae.device # will be on cpu - vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device - if distributed_state.num_processes <= 1: # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. 
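(For reference: the branch above is the single-process fast path; with more than one process the prompt list is typically split so that each device renders its own share. A minimal, self-contained sketch of that pattern with accelerate's `PartialState` follows — the `generate` helper and the literal prompt list are placeholders for illustration, not code from this patch.)

```python
# Sketch only: splitting sample prompts across processes with accelerate's
# PartialState. Mirrors the single-/multi-process branches above; `generate`
# stands in for sample_image_inference and is an assumption for illustration.
from accelerate import PartialState

state = PartialState()
prompts = ["a cat", "a dog", "a bird", "a fish"]  # placeholder prompt list

def generate(prompt: str) -> None:
    # stand-in for the actual image generation call
    print(f"rank {state.process_index}: {prompt}")

if state.num_processes <= 1:
    # single device: iterate the original prompt list as-is
    for p in prompts:
        generate(p)
else:
    # multiple devices: each process receives its own slice of the list
    with state.split_between_processes(prompts) as local_prompts:
        for p in local_prompts:
            generate(p)
```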
- with torch.no_grad(): + with torch.no_grad(), accelerator.autocast(): for prompt_dict in prompts: sample_image_inference( accelerator, @@ -501,8 +506,6 @@ def sample_images( if cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) - vae.to(org_vae_device) - clean_memory_on_device(accelerator.device) @@ -510,7 +513,7 @@ def sample_image_inference( accelerator: Accelerator, args: argparse.Namespace, mmdit: sd3_models.MMDiT, - text_encoders: List[Union[sd3_models.SDClipModel, sd3_models.SDXLClipG, sd3_models.T5XXLModel]], + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], vae: sd3_models.SDVAE, save_dir, prompt_dict, @@ -562,32 +565,49 @@ def sample_image_inference( tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() - if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: - te_outputs = sample_prompts_te_outputs[prompt] - else: - l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(prompt) - te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens]) + def encode_prompt(prpt): + text_encoder_conds = [] + if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs: + text_encoder_conds = sample_prompts_te_outputs[prpt] + print(f"Using cached text encoder outputs for prompt: {prpt}") + if text_encoders is not None: + print(f"Encoding prompt: {prpt}") + tokens_and_masks = tokenize_strategy.tokenize(prpt) + # strategy has apply_t5_attn_mask option + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) - lg_out, t5_out, pooled = te_outputs + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + return text_encoder_conds + + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(prompt) cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) # encode negative prompts - if sample_prompts_te_outputs and negative_prompt in sample_prompts_te_outputs: - neg_te_outputs = sample_prompts_te_outputs[negative_prompt] - else: - l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(negative_prompt) - neg_te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens]) - - lg_out, t5_out, pooled = neg_te_outputs + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(negative_prompt) neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) # sample image - latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device) - latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype)) + clean_memory_on_device(accelerator.device) + with accelerator.autocast(), torch.no_grad(): + # mmdit may be fp8, so we need weight_dtype here. vae is always in that dtype. 
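(In other words, fp8 here is a storage format for the MMDiT weights; sampling itself runs in a regular half/bf16 dtype, which is why the VAE's dtype is passed as the compute dtype in the call below. A minimal sketch of that dtype selection, as an illustration only — `compute_dtype_for` is a hypothetical helper, not part of this patch.)

```python
# Sketch only (assumption): choose a compute dtype when model weights may be
# stored as fp8 (float8_e4m3fn). fp8 is treated as a storage format, so
# activations and latents stay in a higher-precision dtype such as bf16.
import torch

def compute_dtype_for(weight_dtype: torch.dtype, fallback: torch.dtype = torch.bfloat16) -> torch.dtype:
    fp8_dtypes = {torch.float8_e4m3fn, torch.float8_e5m2}
    return fallback if weight_dtype in fp8_dtypes else weight_dtype

print(compute_dtype_for(torch.float8_e4m3fn))  # -> torch.bfloat16
print(compute_dtype_for(torch.float16))        # -> torch.float16
```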
+ latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, vae.dtype, accelerator.device) # latent to image - with torch.no_grad(): - image = vae.decode(latents) + clean_memory_on_device(accelerator.device) + org_vae_device = vae.device # will be on cpu + vae.to(accelerator.device) + latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype)) + image = vae.decode(latents) + vae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + image = image.float() image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) @@ -609,14 +629,9 @@ def sample_image_inference( wandb_tracker = accelerator.get_tracker("wandb") import wandb + # not to commit images to avoid inconsistency between training and logging steps - wandb_tracker.log( - {f"sample_{i}": wandb.Image( - image, - caption=prompt # positive prompt as a caption - )}, - commit=False - ) + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption # region Diffusers @@ -886,4 +901,78 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): return self.config.num_train_timesteps +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. 
+ """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +def get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz = latents.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps[indices].to(device=device) + + # Add noise according to flow matching. + sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + return noisy_model_input, timesteps, sigmas + + # endregion diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 5849518f..1861dfbc 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -1,9 +1,12 @@ +from dataclasses import dataclass import math -from typing import Dict, Optional, Union +import re +from typing import Dict, List, Optional, Union import torch import safetensors from safetensors.torch import load_file from accelerate import init_empty_weights +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPConfig, CLIPTextConfig from .utils import setup_logging @@ -19,18 +22,62 @@ from library import sdxl_model_util # region models +# TODO remove dependency on flux_utils +from library.utils import load_safetensors +from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl -def load_safetensors(path: str, dvc: Union[str, torch.device], disable_mmap: bool = False): - if disable_mmap: - return safetensors.torch.load(open(path, "rb").read()) + +def analyze_state_dict_state(state_dict: Dict, prefix: str = ""): + logger.info(f"Analyzing state dict state...") + + # analyze configs + patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2] + depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64 + num_patches = state_dict[f"{prefix}pos_embed"].shape[1] + pos_embed_max_size = round(math.sqrt(num_patches)) + adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1] + context_shape = state_dict[f"{prefix}context_embedder.weight"].shape + qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys() else None + + # x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])) + x_block_self_attn_layers = [] + re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight") + for key in list(state_dict.keys()): + m = re_attn.search(key) + if m: + x_block_self_attn_layers.append(int(m.group(1))) + + context_embedder_in_features = context_shape[1] + context_embedder_out_features = context_shape[0] + + # only supports 3-5-large, medium or 3-medium + if qk_norm is not None: + if len(x_block_self_attn_layers) == 0: + model_type = "3-5-large" + else: + model_type = "3-5-medium" else: - try: - return load_file(path, device=dvc) - except: - return load_file(path) # prevent device invalid Error + model_type = "3-medium" + + params = sd3_models.SD3Params( + patch_size=patch_size, + 
depth=depth, + num_patches=num_patches, + pos_embed_max_size=pos_embed_max_size, + adm_in_channels=adm_in_channels, + qk_norm=qk_norm, + x_block_self_attn_layers=x_block_self_attn_layers, + context_embedder_in_features=context_embedder_in_features, + context_embedder_out_features=context_embedder_out_features, + model_type=model_type, + ) + logger.info(f"Analyzed state dict state: {params}") + return params -def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device]): +def load_mmdit( + state_dict: Dict, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], attn_mode: str = "torch" +) -> sd3_models.MMDiT: mmdit_sd = {} mmdit_prefix = "model.diffusion_model." @@ -40,30 +87,25 @@ def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torc # load MMDiT logger.info("Building MMDit") + params = analyze_state_dict_state(mmdit_sd) with init_empty_weights(): - mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) + mmdit = sd3_models.create_sd3_mmdit(params, attn_mode) logger.info("Loading state dict...") - info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype) + info = mmdit.load_state_dict(mmdit_sd, strict=False, assign=True) logger.info(f"Loaded MMDiT: {info}") return mmdit def load_clip_l( - state_dict: Dict, clip_l_path: Optional[str], - attn_mode: str, - clip_dtype: Optional[Union[str, torch.dtype]], + dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], disable_mmap: bool = False, + state_dict: Optional[Dict] = None, ): clip_l_sd = None - if clip_l_path: - logger.info(f"Loading clip_l from {clip_l_path}...") - clip_l_sd = load_safetensors(clip_l_path, device, disable_mmap) - for key in list(clip_l_sd.keys()): - clip_l_sd["transformer." + key] = clip_l_sd.pop(key) - else: + if clip_l_path is None: if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: # found clip_l: remove prefix "text_encoders.clip_l." 
logger.info("clip_l is included in the checkpoint") @@ -72,34 +114,58 @@ def load_clip_l( for k in list(state_dict.keys()): if k.startswith(prefix): clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) + elif clip_l_path is None: + logger.info("clip_l is not included in the checkpoint and clip_l_path is not provided") + return None + + # load clip_l + logger.info("Building CLIP-L") + config = CLIPTextConfig( + vocab_size=49408, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=77, + hidden_act="quick_gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=768, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + with init_empty_weights(): + clip = CLIPTextModelWithProjection(config) if clip_l_sd is None: - clip_l = None - else: - logger.info("Building ClipL") - clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd) - logger.info("Loading state dict...") - info = clip_l.load_state_dict(clip_l_sd) - logger.info(f"Loaded ClipL: {info}") - clip_l.set_attn_mode(attn_mode) - return clip_l + logger.info(f"Loading state dict from {clip_l_path}") + clip_l_sd = load_safetensors(clip_l_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + + if "text_projection.weight" not in clip_l_sd: + logger.info("Adding text_projection.weight to clip_l_sd") + clip_l_sd["text_projection.weight"] = torch.eye(768, dtype=dtype, device=device) + + info = clip.load_state_dict(clip_l_sd, strict=False, assign=True) + logger.info(f"Loaded CLIP-L: {info}") + return clip def load_clip_g( - state_dict: Dict, clip_g_path: Optional[str], - attn_mode: str, - clip_dtype: Optional[Union[str, torch.dtype]], + dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], disable_mmap: bool = False, + state_dict: Optional[Dict] = None, ): clip_g_sd = None - if clip_g_path: - logger.info(f"Loading clip_g from {clip_g_path}...") - clip_g_sd = load_safetensors(clip_g_path, device, disable_mmap) - for key in list(clip_g_sd.keys()): - clip_g_sd["transformer." + key] = clip_g_sd.pop(key) - else: + if state_dict is not None: if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: # found clip_g: remove prefix "text_encoders.clip_g." 
logger.info("clip_g is included in the checkpoint") @@ -108,34 +174,53 @@ def load_clip_g( for k in list(state_dict.keys()): if k.startswith(prefix): clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) + elif clip_g_path is None: + logger.info("clip_g is not included in the checkpoint and clip_g_path is not provided") + return None + + # load clip_g + logger.info("Building CLIP-G") + config = CLIPTextConfig( + vocab_size=49408, + hidden_size=1280, + intermediate_size=5120, + num_hidden_layers=32, + num_attention_heads=20, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=1280, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + with init_empty_weights(): + clip = CLIPTextModelWithProjection(config) if clip_g_sd is None: - clip_g = None - else: - logger.info("Building ClipG") - clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd) - logger.info("Loading state dict...") - info = clip_g.load_state_dict(clip_g_sd) - logger.info(f"Loaded ClipG: {info}") - clip_g.set_attn_mode(attn_mode) - return clip_g + logger.info(f"Loading state dict from {clip_g_path}") + clip_g_sd = load_safetensors(clip_g_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = clip.load_state_dict(clip_g_sd, strict=False, assign=True) + logger.info(f"Loaded CLIP-G: {info}") + return clip def load_t5xxl( - state_dict: Dict, t5xxl_path: Optional[str], - attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], disable_mmap: bool = False, + state_dict: Optional[Dict] = None, ): t5xxl_sd = None - if t5xxl_path: - logger.info(f"Loading t5xxl from {t5xxl_path}...") - t5xxl_sd = load_safetensors(t5xxl_path, device, disable_mmap) - for key in list(t5xxl_sd.keys()): - t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) - else: + if state_dict is not None: if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: # found t5xxl: remove prefix "text_encoders.t5xxl." 
logger.info("t5xxl is included in the checkpoint") @@ -144,29 +229,19 @@ def load_t5xxl( for k in list(state_dict.keys()): if k.startswith(prefix): t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) + elif t5xxl_path is None: + logger.info("t5xxl is not included in the checkpoint and t5xxl_path is not provided") + return None - if t5xxl_sd is None: - t5xxl = None - else: - logger.info("Building T5XXL") - - # workaround for T5XXL model creation: create with fp16 takes too long TODO support virtual device - t5xxl = sd3_models.create_t5xxl(device, torch.float32, t5xxl_sd) - t5xxl.to(dtype=dtype) - - logger.info("Loading state dict...") - info = t5xxl.load_state_dict(t5xxl_sd) - logger.info(f"Loaded T5XXL: {info}") - t5xxl.set_attn_mode(attn_mode) - return t5xxl + return flux_utils_load_t5xxl(t5xxl_path, dtype, device, disable_mmap, state_dict=t5xxl_sd) def load_vae( - state_dict: Dict, vae_path: Optional[str], vae_dtype: Optional[Union[str, torch.dtype]], device: Optional[Union[str, torch.device]], disable_mmap: bool = False, + state_dict: Optional[Dict] = None, ): vae_sd = {} if vae_path: @@ -181,299 +256,15 @@ def load_vae( vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) logger.info("Building VAE") - vae = sd3_models.SDVAE() + vae = sd3_models.SDVAE(vae_dtype, device) logger.info("Loading state dict...") info = vae.load_state_dict(vae_sd) logger.info(f"Loaded VAE: {info}") - vae.to(device=device, dtype=vae_dtype) + vae.to(device=device, dtype=vae_dtype) # make sure it's in the right device and dtype return vae -def load_models( - ckpt_path: str, - clip_l_path: str, - clip_g_path: str, - t5xxl_path: str, - vae_path: str, - attn_mode: str, - device: Union[str, torch.device], - weight_dtype: Optional[Union[str, torch.dtype]] = None, - disable_mmap: bool = False, - clip_dtype: Optional[Union[str, torch.dtype]] = None, - t5xxl_device: Optional[Union[str, torch.device]] = None, - t5xxl_dtype: Optional[Union[str, torch.dtype]] = None, - vae_dtype: Optional[Union[str, torch.dtype]] = None, -): - """ - Load SD3 models from checkpoint files. - - Args: - ckpt_path: Path to the SD3 checkpoint file. - clip_l_path: Path to the clip_l checkpoint file. - clip_g_path: Path to the clip_g checkpoint file. - t5xxl_path: Path to the t5xxl checkpoint file. - vae_path: Path to the VAE checkpoint file. - attn_mode: Attention mode for MMDiT model. - device: Device for MMDiT model. - weight_dtype: Default dtype of weights for all models. This is weight dtype, so the model dtype may be different. - disable_mmap: Disable memory mapping when loading state dict. - clip_dtype: Dtype for Clip models, or None to use default dtype. - t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device. - t5xxl_dtype: Dtype for T5XXL model, or None to use default dtype. - vae_dtype: Dtype for VAE model, or None to use default dtype. - - Returns: - Tuple of MMDiT, ClipL, ClipG, T5XXL, and VAE models. - """ - - # In SD1/2 and SDXL, the model is created with empty weights and then loaded with state dict. - # However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict. - # Therefore, we need clip_dtype and t5xxl_dtype. 
- - def load_state_dict(path: str, dvc: Union[str, torch.device] = device): - if disable_mmap: - return safetensors.torch.load(open(path, "rb").read()) - else: - try: - return load_file(path, device=dvc) - except: - return load_file(path) # prevent device invalid Error - - t5xxl_device = t5xxl_device or device - clip_dtype = clip_dtype or weight_dtype or torch.float32 - t5xxl_dtype = t5xxl_dtype or weight_dtype or torch.float32 - vae_dtype = vae_dtype or weight_dtype or torch.float32 - - logger.info(f"Loading SD3 models from {ckpt_path}...") - state_dict = load_state_dict(ckpt_path) - - # load clip_l - clip_l_sd = None - if clip_l_path: - logger.info(f"Loading clip_l from {clip_l_path}...") - clip_l_sd = load_state_dict(clip_l_path) - for key in list(clip_l_sd.keys()): - clip_l_sd["transformer." + key] = clip_l_sd.pop(key) - else: - if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: - # found clip_l: remove prefix "text_encoders.clip_l." - logger.info("clip_l is included in the checkpoint") - clip_l_sd = {} - prefix = "text_encoders.clip_l." - for k in list(state_dict.keys()): - if k.startswith(prefix): - clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) - - # load clip_g - clip_g_sd = None - if clip_g_path: - logger.info(f"Loading clip_g from {clip_g_path}...") - clip_g_sd = load_state_dict(clip_g_path) - for key in list(clip_g_sd.keys()): - clip_g_sd["transformer." + key] = clip_g_sd.pop(key) - else: - if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: - # found clip_g: remove prefix "text_encoders.clip_g." - logger.info("clip_g is included in the checkpoint") - clip_g_sd = {} - prefix = "text_encoders.clip_g." - for k in list(state_dict.keys()): - if k.startswith(prefix): - clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) - - # load t5xxl - t5xxl_sd = None - if t5xxl_path: - logger.info(f"Loading t5xxl from {t5xxl_path}...") - t5xxl_sd = load_state_dict(t5xxl_path, t5xxl_device) - for key in list(t5xxl_sd.keys()): - t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) - else: - if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: - # found t5xxl: remove prefix "text_encoders.t5xxl." - logger.info("t5xxl is included in the checkpoint") - t5xxl_sd = {} - prefix = "text_encoders.t5xxl." - for k in list(state_dict.keys()): - if k.startswith(prefix): - t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) - - # MMDiT and VAE - vae_sd = {} - if vae_path: - logger.info(f"Loading VAE from {vae_path}...") - vae_sd = load_state_dict(vae_path) - else: - # remove prefix "first_stage_model." - vae_sd = {} - vae_prefix = "first_stage_model." - for k in list(state_dict.keys()): - if k.startswith(vae_prefix): - vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) - - mmdit_prefix = "model.diffusion_model." 
- for k in list(state_dict.keys()): - if k.startswith(mmdit_prefix): - state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k) - else: - state_dict.pop(k) # remove other keys - - # load MMDiT - logger.info("Building MMDit") - with init_empty_weights(): - mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) - - logger.info("Loading state dict...") - info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype) - logger.info(f"Loaded MMDiT: {info}") - - # load ClipG and ClipL - if clip_l_sd is None: - clip_l = None - else: - logger.info("Building ClipL") - clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd) - logger.info("Loading state dict...") - info = clip_l.load_state_dict(clip_l_sd) - logger.info(f"Loaded ClipL: {info}") - clip_l.set_attn_mode(attn_mode) - - if clip_g_sd is None: - clip_g = None - else: - logger.info("Building ClipG") - clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd) - logger.info("Loading state dict...") - info = clip_g.load_state_dict(clip_g_sd) - logger.info(f"Loaded ClipG: {info}") - clip_g.set_attn_mode(attn_mode) - - # load T5XXL - if t5xxl_sd is None: - t5xxl = None - else: - logger.info("Building T5XXL") - t5xxl = sd3_models.create_t5xxl(t5xxl_device, t5xxl_dtype, t5xxl_sd) - logger.info("Loading state dict...") - info = t5xxl.load_state_dict(t5xxl_sd) - logger.info(f"Loaded T5XXL: {info}") - t5xxl.set_attn_mode(attn_mode) - - # load VAE - logger.info("Building VAE") - vae = sd3_models.SDVAE() - logger.info("Loading state dict...") - info = vae.load_state_dict(vae_sd) - logger.info(f"Loaded VAE: {info}") - vae.to(device=device, dtype=vae_dtype) - - return mmdit, clip_l, clip_g, t5xxl, vae - - # endregion -# region utils - - -def get_cond( - prompt: str, - tokenizer: sd3_models.SD3Tokenizer, - clip_l: sd3_models.SDClipModel, - clip_g: sd3_models.SDXLClipG, - t5xxl: Optional[sd3_models.T5XXLModel] = None, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, -): - l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt) - print(t5_tokens) - return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype) - - -def get_cond_from_tokens( - l_tokens, - g_tokens, - t5_tokens, - clip_l: sd3_models.SDClipModel, - clip_g: sd3_models.SDXLClipG, - t5xxl: Optional[sd3_models.T5XXLModel] = None, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, -): - l_out, l_pooled = clip_l.encode_token_weights(l_tokens) - g_out, g_pooled = clip_g.encode_token_weights(g_tokens) - lg_out = torch.cat([l_out, g_out], dim=-1) - lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) - if device is not None: - lg_out = lg_out.to(device=device) - l_pooled = l_pooled.to(device=device) - g_pooled = g_pooled.to(device=device) - if dtype is not None: - lg_out = lg_out.to(dtype=dtype) - l_pooled = l_pooled.to(dtype=dtype) - g_pooled = g_pooled.to(dtype=dtype) - - # t5xxl may be in another device (eg. 
cpu) - if t5_tokens is None: - t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype) - else: - t5_out, _ = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None - if device is not None: - t5_out = t5_out.to(device=device) - if dtype is not None: - t5_out = t5_out.to(dtype=dtype) - - # return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1) - return lg_out, t5_out, torch.cat((l_pooled, g_pooled), dim=-1) - - -# used if other sd3 models is available -r""" -def get_sd3_configs(state_dict: Dict): - # Important configuration values can be quickly determined by checking shapes in the source file - # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change) - # prefix = "model.diffusion_model." - prefix = "" - - patch_size = state_dict[prefix + "x_embedder.proj.weight"].shape[2] - depth = state_dict[prefix + "x_embedder.proj.weight"].shape[0] // 64 - num_patches = state_dict[prefix + "pos_embed"].shape[1] - pos_embed_max_size = round(math.sqrt(num_patches)) - adm_in_channels = state_dict[prefix + "y_embedder.mlp.0.weight"].shape[1] - context_shape = state_dict[prefix + "context_embedder.weight"].shape - context_embedder_config = { - "target": "torch.nn.Linear", - "params": {"in_features": context_shape[1], "out_features": context_shape[0]}, - } - return { - "patch_size": patch_size, - "depth": depth, - "num_patches": num_patches, - "pos_embed_max_size": pos_embed_max_size, - "adm_in_channels": adm_in_channels, - "context_embedder": context_embedder_config, - } - - -def create_mmdit_from_sd3_checkpoint(state_dict: Dict, attn_mode: str = "xformers"): - "" - Doesn't load state dict. - "" - sd3_configs = get_sd3_configs(state_dict) - - mmdit = sd3_models.MMDiT( - input_size=None, - pos_embed_max_size=sd3_configs["pos_embed_max_size"], - patch_size=sd3_configs["patch_size"], - in_channels=16, - adm_in_channels=sd3_configs["adm_in_channels"], - depth=sd3_configs["depth"], - mlp_ratio=4, - qk_norm=None, - num_patches=sd3_configs["num_patches"], - context_size=4096, - attn_mode=attn_mode, - ) - return mmdit -""" class ModelSamplingDiscreteFlow: @@ -509,6 +300,3 @@ class ModelSamplingDiscreteFlow: # assert max_denoise is False, "max_denoise not implemented" # max_denoise is always True, I'm not sure why it's there return sigma * noise + (1.0 - sigma) * latent_image - - -# endregion diff --git a/library/strategy_base.py b/library/strategy_base.py index e390c5f3..358e42f1 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -518,7 +518,7 @@ class LatentsCachingStrategy: self, npz_path: str, bucket_reso: Tuple[int, int] ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: """ - for SD/SDXL/SD3.0 + for SD/SDXL """ return self._default_load_latents_from_disk(None, npz_path, bucket_reso) diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 0b0c34af..5e65927f 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -190,6 +190,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): apply_t5_attn_mask=apply_t5_attn_mask_i, ) else: + # it's fine that attn mask is not None. 
it's overwritten before calling the model if necessary info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i) @@ -211,7 +212,7 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy): ) def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, True) + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) def load_latents_from_disk( self, npz_path: str, bucket_reso: Tuple[int, int] @@ -225,7 +226,7 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy): vae_dtype = vae.dtype self._default_cache_batch_latents( - encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, True + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True ) if not train_util.HIGH_VRAM: diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index 9fde0208..1d55fe21 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -1,9 +1,10 @@ import os import glob +import random from typing import Any, List, Optional, Tuple, Union import torch import numpy as np -from transformers import CLIPTokenizer, T5TokenizerFast +from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel from library import sd3_utils, train_util from library import sd3_models @@ -48,45 +49,200 @@ class Sd3TokenizeStrategy(TokenizeStrategy): class Sd3TextEncodingStrategy(TextEncodingStrategy): - def __init__(self) -> None: - pass + def __init__( + self, + apply_lg_attn_mask: Optional[bool] = None, + apply_t5_attn_mask: Optional[bool] = None, + l_dropout_rate: float = 0.0, + g_dropout_rate: float = 0.0, + t5_dropout_rate: float = 0.0, + ) -> None: + """ + Args: + apply_t5_attn_mask: Default value for apply_t5_attn_mask. + """ + self.apply_lg_attn_mask = apply_lg_attn_mask + self.apply_t5_attn_mask = apply_t5_attn_mask + self.l_dropout_rate = l_dropout_rate + self.g_dropout_rate = g_dropout_rate + self.t5_dropout_rate = t5_dropout_rate def encode_tokens( self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], - apply_lg_attn_mask: bool = False, - apply_t5_attn_mask: bool = False, + apply_lg_attn_mask: Optional[bool] = False, + apply_t5_attn_mask: Optional[bool] = False, + enable_dropout: bool = True, ) -> List[torch.Tensor]: """ returned embeddings are not masked """ clip_l, clip_g, t5xxl = models + clip_l: Optional[CLIPTextModel] + clip_g: Optional[CLIPTextModelWithProjection] + t5xxl: Optional[T5EncoderModel] - l_tokens, g_tokens, t5_tokens = tokens[:3] - l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] if len(tokens) > 3 else [None, None, None] - if l_tokens is None: + if apply_lg_attn_mask is None: + apply_lg_attn_mask = self.apply_lg_attn_mask + if apply_t5_attn_mask is None: + apply_t5_attn_mask = self.apply_t5_attn_mask + + l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens + + # dropout: if enable_dropout is False, dropout is not applied. 
dropout means zeroing out embeddings + + if l_tokens is None or clip_l is None: assert g_tokens is None, "g_tokens must be None if l_tokens is None" lg_out = None + lg_pooled = None + l_attn_mask = None + g_attn_mask = None else: assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" - l_out, l_pooled = clip_l(l_tokens) - g_out, g_pooled = clip_g(g_tokens) - if apply_lg_attn_mask: - l_out = l_out * l_attn_mask.to(l_out.device).unsqueeze(-1) - g_out = g_out * g_attn_mask.to(g_out.device).unsqueeze(-1) + + # drop some members of the batch: we do not call clip_l and clip_g for dropped members + batch_size, l_seq_len = l_tokens.shape + g_seq_len = g_tokens.shape[1] + + non_drop_l_indices = [] + non_drop_g_indices = [] + for i in range(l_tokens.shape[0]): + drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate) + drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate) + if not drop_l: + non_drop_l_indices.append(i) + if not drop_g: + non_drop_g_indices.append(i) + + # filter out dropped members + if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size: + l_tokens = l_tokens[non_drop_l_indices] + l_attn_mask = l_attn_mask[non_drop_l_indices] + if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size: + g_tokens = g_tokens[non_drop_g_indices] + g_attn_mask = g_attn_mask[non_drop_g_indices] + + # call clip_l for non-dropped members + if len(non_drop_l_indices) > 0: + nd_l_attn_mask = l_attn_mask.to(clip_l.device) + prompt_embeds = clip_l( + l_tokens.to(clip_l.device), nd_l_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True + ) + nd_l_pooled = prompt_embeds[0] + nd_l_out = prompt_embeds.hidden_states[-2] + if len(non_drop_g_indices) > 0: + nd_g_attn_mask = g_attn_mask.to(clip_g.device) + prompt_embeds = clip_g( + g_tokens.to(clip_g.device), nd_g_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True + ) + nd_g_pooled = prompt_embeds[0] + nd_g_out = prompt_embeds.hidden_states[-2] + + # fill in the dropped members + if len(non_drop_l_indices) == batch_size: + l_pooled = nd_l_pooled + l_out = nd_l_out + else: + # model output is always float32 because of the models are wrapped with Accelerator + l_pooled = torch.zeros((batch_size, 768), device=clip_l.device, dtype=torch.float32) + l_out = torch.zeros((batch_size, l_seq_len, 768), device=clip_l.device, dtype=torch.float32) + l_attn_mask = torch.zeros((batch_size, l_seq_len), device=clip_l.device, dtype=l_attn_mask.dtype) + if len(non_drop_l_indices) > 0: + l_pooled[non_drop_l_indices] = nd_l_pooled + l_out[non_drop_l_indices] = nd_l_out + l_attn_mask[non_drop_l_indices] = nd_l_attn_mask + + if len(non_drop_g_indices) == batch_size: + g_pooled = nd_g_pooled + g_out = nd_g_out + else: + g_pooled = torch.zeros((batch_size, 1280), device=clip_g.device, dtype=torch.float32) + g_out = torch.zeros((batch_size, g_seq_len, 1280), device=clip_g.device, dtype=torch.float32) + g_attn_mask = torch.zeros((batch_size, g_seq_len), device=clip_g.device, dtype=g_attn_mask.dtype) + if len(non_drop_g_indices) > 0: + g_pooled[non_drop_g_indices] = nd_g_pooled + g_out[non_drop_g_indices] = nd_g_out + g_attn_mask[non_drop_g_indices] = nd_g_attn_mask + + lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) lg_out = torch.cat([l_out, g_out], dim=-1) - if t5xxl is not None and t5_tokens is not None: - t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096] - if apply_t5_attn_mask: - t5_out = t5_out * 
t5_attn_mask.to(t5_out.device).unsqueeze(-1) - else: + if t5xxl is None or t5_tokens is None: t5_out = None + t5_attn_mask = None + else: + # drop some members of the batch: we do not call t5xxl for dropped members + batch_size, t5_seq_len = t5_tokens.shape + non_drop_t5_indices = [] + for i in range(t5_tokens.shape[0]): + drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate) + if not drop_t5: + non_drop_t5_indices.append(i) - lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None - return [lg_out, t5_out, lg_pooled] + # filter out dropped members + if len(non_drop_t5_indices) > 0 and len(non_drop_t5_indices) < batch_size: + t5_tokens = t5_tokens[non_drop_t5_indices] + t5_attn_mask = t5_attn_mask[non_drop_t5_indices] + + # call t5xxl for non-dropped members + if len(non_drop_t5_indices) > 0: + nd_t5_attn_mask = t5_attn_mask.to(t5xxl.device) + nd_t5_out, _ = t5xxl( + t5_tokens.to(t5xxl.device), + nd_t5_attn_mask if apply_t5_attn_mask else None, + return_dict=False, + output_hidden_states=True, + ) + + # fill in the dropped members + if len(non_drop_t5_indices) == batch_size: + t5_out = nd_t5_out + else: + t5_out = torch.zeros((batch_size, t5_seq_len, 4096), device=t5xxl.device, dtype=torch.float32) + t5_attn_mask = torch.zeros((batch_size, t5_seq_len), device=t5xxl.device, dtype=t5_attn_mask.dtype) + if len(non_drop_t5_indices) > 0: + t5_out[non_drop_t5_indices] = nd_t5_out + t5_attn_mask[non_drop_t5_indices] = nd_t5_attn_mask + + # masks are used for attention masking in transformer + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] + + def drop_cached_text_encoder_outputs( + self, + lg_out: torch.Tensor, + t5_out: torch.Tensor, + lg_pooled: torch.Tensor, + l_attn_mask: torch.Tensor, + g_attn_mask: torch.Tensor, + t5_attn_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # dropout: if enable_dropout is True, dropout is not applied. 
dropout means zeroing out embeddings + if lg_out is not None: + for i in range(lg_out.shape[0]): + drop_l = self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate + if drop_l: + lg_out[i, :, :768] = torch.zeros_like(lg_out[i, :, :768]) + lg_pooled[i, :768] = torch.zeros_like(lg_pooled[i, :768]) + if l_attn_mask is not None: + l_attn_mask[i] = torch.zeros_like(l_attn_mask[i]) + drop_g = self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate + if drop_g: + lg_out[i, :, 768:] = torch.zeros_like(lg_out[i, :, 768:]) + lg_pooled[i, 768:] = torch.zeros_like(lg_pooled[i, 768:]) + if g_attn_mask is not None: + g_attn_mask[i] = torch.zeros_like(g_attn_mask[i]) + + if t5_out is not None: + for i in range(t5_out.shape[0]): + drop_t5 = self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate + if drop_t5: + t5_out[i] = torch.zeros_like(t5_out[i]) + if t5_attn_mask is not None: + t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i]) + + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] def concat_encodings( self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor @@ -132,39 +288,38 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): return False if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used return False - # t5xxl is optional + if "apply_lg_attn_mask" not in npz: + return False + if "t5_out" not in npz: + return False + if "t5_attn_mask" not in npz: + return False + npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"] + if npz_apply_lg_attn_mask != self.apply_lg_attn_mask: + return False + if "apply_t5_attn_mask" not in npz: + return False + npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] + if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e return True - def mask_lg_attn(self, lg_out: np.ndarray, l_attn_mask: np.ndarray, g_attn_mask: np.ndarray) -> np.ndarray: - l_out = lg_out[..., :768] - g_out = lg_out[..., 768:] # 1280 - l_out = l_out * np.expand_dims(l_attn_mask, -1) # l_out = l_out * l_attn_mask. - g_out = g_out * np.expand_dims(g_attn_mask, -1) # g_out = g_out * g_attn_mask. 
- return np.concatenate([l_out, g_out], axis=-1) - - def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray: - return t5_out * np.expand_dims(t5_attn_mask, -1) - def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: data = np.load(npz_path) lg_out = data["lg_out"] lg_pooled = data["lg_pooled"] - t5_out = data["t5_out"] if "t5_out" in data else None + t5_out = data["t5_out"] - if self.apply_lg_attn_mask: - l_attn_mask = data["clip_l_attn_mask"] - g_attn_mask = data["clip_g_attn_mask"] - lg_out = self.mask_lg_attn(lg_out, l_attn_mask, g_attn_mask) + l_attn_mask = data["clip_l_attn_mask"] + g_attn_mask = data["clip_g_attn_mask"] + t5_attn_mask = data["t5_attn_mask"] - if self.apply_t5_attn_mask and t5_out is not None: - t5_attn_mask = data["t5_attn_mask"] - t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) - - return [lg_out, t5_out, lg_pooled] + # apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] def cache_batch_outputs( self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List @@ -174,46 +329,56 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): - lg_out, t5_out, lg_pooled = sd3_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask + # always disable dropout during caching + lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens( + tokenize_strategy, + models, + tokens_and_masks, + apply_lg_attn_mask=self.apply_lg_attn_mask, + apply_t5_attn_mask=self.apply_t5_attn_mask, + enable_dropout=False, ) if lg_out.dtype == torch.bfloat16: lg_out = lg_out.float() if lg_pooled.dtype == torch.bfloat16: lg_pooled = lg_pooled.float() - if t5_out is not None and t5_out.dtype == torch.bfloat16: + if t5_out.dtype == torch.bfloat16: t5_out = t5_out.float() lg_out = lg_out.cpu().numpy() lg_pooled = lg_pooled.cpu().numpy() - if t5_out is not None: - t5_out = t5_out.cpu().numpy() + t5_out = t5_out.cpu().numpy() + + l_attn_mask = tokens_and_masks[3].cpu().numpy() + g_attn_mask = tokens_and_masks[4].cpu().numpy() + t5_attn_mask = tokens_and_masks[5].cpu().numpy() for i, info in enumerate(infos): lg_out_i = lg_out[i] - t5_out_i = t5_out[i] if t5_out is not None else None + t5_out_i = t5_out[i] lg_pooled_i = lg_pooled[i] + l_attn_mask_i = l_attn_mask[i] + g_attn_mask_i = g_attn_mask[i] + t5_attn_mask_i = t5_attn_mask[i] + apply_lg_attn_mask = self.apply_lg_attn_mask + apply_t5_attn_mask = self.apply_t5_attn_mask if self.cache_to_disk: - clip_l_attn_mask, clip_g_attn_mask, t5_attn_mask = tokens_and_masks[3:6] - clip_l_attn_mask_i = clip_l_attn_mask[i].cpu().numpy() - clip_g_attn_mask_i = clip_g_attn_mask[i].cpu().numpy() - t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() if t5_attn_mask is not None else None # shouldn't be None - kwargs = {} - if t5_out is not None: - kwargs["t5_out"] = t5_out_i np.savez( info.text_encoder_outputs_npz, lg_out=lg_out_i, lg_pooled=lg_pooled_i, - clip_l_attn_mask=clip_l_attn_mask_i, - clip_g_attn_mask=clip_g_attn_mask_i, + t5_out=t5_out_i, + clip_l_attn_mask=l_attn_mask_i, + clip_g_attn_mask=g_attn_mask_i, t5_attn_mask=t5_attn_mask_i, - **kwargs, + apply_lg_attn_mask=apply_lg_attn_mask, + 
apply_t5_attn_mask=apply_t5_attn_mask, ) else: - info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i) + # it's fine that attn mask is not None. it's overwritten before calling the model if necessary + info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i) class Sd3LatentsCachingStrategy(LatentsCachingStrategy): @@ -234,7 +399,12 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy): ) def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): @@ -242,45 +412,9 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy): vae_device = vae.device vae_dtype = vae.dtype - self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + self._default_cache_batch_latents( + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True + ) if not train_util.HIGH_VRAM: train_util.clean_memory_on_device(vae.device) - - -if __name__ == "__main__": - # test code for Sd3TokenizeStrategy - # tokenizer = sd3_models.SD3Tokenizer() - strategy = Sd3TokenizeStrategy(256) - text = "hello world" - - l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) - # print(l_tokens.shape) - print(l_tokens) - print(g_tokens) - print(t5_tokens) - - texts = ["hello world", "the quick brown fox jumps over the lazy dog"] - l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") - g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") - t5_tokens_2 = strategy.t5xxl( - texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt" - ) - print(l_tokens_2) - print(g_tokens_2) - print(t5_tokens_2) - - # compare - print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0])) - print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0])) - print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0])) - - text = ",".join(["hello world! 
this is long text"] * 50) - l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) - print(l_tokens) - print(g_tokens) - print(t5_tokens) - - print(f"model max length l: {strategy.clip_l.model_max_length}") - print(f"model max length g: {strategy.clip_g.model_max_length}") - print(f"model max length t5: {strategy.t5xxl.model_max_length}") diff --git a/library/train_util.py b/library/train_util.py index 9595dfc3..18d3cf6c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1082,6 +1082,10 @@ class BaseDataset(torch.utils.data.Dataset): info.image = info.image.result() # future to image caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop) + # remove image from memory + for info in batch: + info.image = None + # define ThreadPoolExecutor to load images in parallel max_workers = min(os.cpu_count(), len(image_infos)) max_workers = max(1, max_workers // num_processes) # consider multi-gpu @@ -1397,7 +1401,17 @@ class BaseDataset(torch.utils.data.Dataset): ) def get_image_size(self, image_path): - return imagesize.get(image_path) + # return imagesize.get(image_path) + image_size = imagesize.get(image_path) + if image_size[0] <= 0: + # imagesize doesn't work for some images, so use cv2 + img = cv2.imread(image_path) + if img is not None: + image_size = (img.shape[1], img.shape[0]) + else: + logger.warning(f"failed to get image size: {image_path}") + image_size = (0, 0) + return image_size def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False): img = load_image(image_path, alpha_mask) @@ -1615,7 +1629,6 @@ class BaseDataset(torch.utils.data.Dataset): text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( image_info.text_encoder_outputs_npz ) - text_encoder_outputs = [torch.FloatTensor(x) for x in text_encoder_outputs] else: tokenization_required = True text_encoder_outputs_list.append(text_encoder_outputs) @@ -2511,6 +2524,9 @@ class DatasetGroup(torch.utils.data.ConcatDataset): for dataset in self.datasets: dataset.verify_bucket_reso_steps(min_steps) + def get_resolutions(self) -> List[Tuple[int, int]]: + return [(dataset.width, dataset.height) for dataset in self.datasets] + def is_latent_cacheable(self) -> bool: return all([dataset.is_latent_cacheable() for dataset in self.datasets]) @@ -5963,6 +5979,37 @@ def line_to_prompt_dict(line: str) -> dict: return prompt_dict +def load_prompts(prompt_file: str) -> List[Dict]: + # read prompts + if prompt_file.endswith(".txt"): + with open(prompt_file, "r", encoding="utf-8") as f: + lines = f.readlines() + prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] + elif prompt_file.endswith(".toml"): + with open(prompt_file, "r", encoding="utf-8") as f: + data = toml.load(f) + prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] + elif prompt_file.endswith(".json"): + with open(prompt_file, "r", encoding="utf-8") as f: + prompts = json.load(f) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from library.train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. 
+ prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + return prompts + + def sample_images_common( pipe_class, accelerator: Accelerator, diff --git a/library/utils.py b/library/utils.py index 8a0c782c..ca0f904d 100644 --- a/library/utils.py +++ b/library/utils.py @@ -13,12 +13,16 @@ from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncest import cv2 from PIL import Image import numpy as np +from safetensors.torch import load_file def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() +# region Logging + + def add_logging_arguments(parser): parser.add_argument( "--console_log_level", @@ -85,6 +89,11 @@ def setup_logging(args=None, log_level=None, reset=False): logger.info(msg_init) +# endregion + +# region PyTorch utils + + def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: """ Convert a string to a torch.dtype @@ -304,6 +313,35 @@ class MemoryEfficientSafeOpen: # return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape) raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)") +def load_safetensors( + path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32 +) -> dict[str, torch.Tensor]: + if disable_mmap: + # return safetensors.torch.load(open(path, "rb").read()) + # use experimental loader + # logger.info(f"Loading without mmap (experimental)") + state_dict = {} + with MemoryEfficientSafeOpen(path) as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key).to(device, dtype=dtype) + return state_dict + else: + try: + state_dict = load_file(path, device=device) + except: + state_dict = load_file(path) # prevent device invalid Error + if dtype is not None: + for key in state_dict.keys(): + state_dict[key] = state_dict[key].to(dtype=dtype) + return state_dict + + + +# endregion + +# region Image utils + + def pil_resize(image, size, interpolation=Image.LANCZOS): has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False @@ -323,9 +361,9 @@ def pil_resize(image, size, interpolation=Image.LANCZOS): return resized_cv2 +# endregion + # TODO make inf_utils.py - - # region Gradual Latent hires fix diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py new file mode 100644 index 00000000..efe20245 --- /dev/null +++ b/networks/lora_sd3.py @@ -0,0 +1,839 @@ +# temporary minimum implementation of LoRA +# SD3 doesn't have Conv2d, so we ignore it +# TODO commonize with the original/SD3/FLUX implementation + +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from transformers import CLIPTextModelWithProjection, T5EncoderModel +import numpy as np +import torch +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from networks.lora_flux import LoRAModule, LoRAInfModule +from library import sd3_models + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: sd3_models.SDVAE, + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], + mmdit, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + 
# extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv + context_attn_dim = kwargs.get("context_attn_dim", None) + context_mlp_dim = kwargs.get("context_mlp_dim", None) + context_mod_dim = kwargs.get("context_mod_dim", None) + x_attn_dim = kwargs.get("x_attn_dim", None) + x_mlp_dim = kwargs.get("x_mlp_dim", None) + x_mod_dim = kwargs.get("x_mod_dim", None) + if context_attn_dim is not None: + context_attn_dim = int(context_attn_dim) + if context_mlp_dim is not None: + context_mlp_dim = int(context_mlp_dim) + if context_mod_dim is not None: + context_mod_dim = int(context_mod_dim) + if x_attn_dim is not None: + x_attn_dim = int(x_attn_dim) + if x_mlp_dim is not None: + x_mlp_dim = int(x_mlp_dim) + if x_mod_dim is not None: + x_mod_dim = int(x_mod_dim) + type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim] + if all([d is None for d in type_dims]): + type_dims = None + + # emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear] + emb_dims = kwargs.get("emb_dims", None) + if emb_dims is not None: + emb_dims = emb_dims.strip() + if emb_dims.startswith("[") and emb_dims.endswith("]"): + emb_dims = emb_dims[1:-1] + emb_dims = [int(d) for d in emb_dims.split(",")] # is it better to use ast.literal_eval? + assert len(emb_dims) == 6, f"invalid emb_dims: {emb_dims}, must be 6 dimensions (context, t, x, y, final_mod, final_linear)" + + # double/single train blocks + def parse_block_selection(selection: str, total_blocks: int) -> List[bool]: + """ + Parse a block selection string and return a list of booleans. + + Args: + selection (str): A string specifying which blocks to select. + total_blocks (int): The total number of blocks available. + + Returns: + List[bool]: A list of booleans indicating which blocks are selected. 
+ """ + if selection == "all": + return [True] * total_blocks + if selection == "none" or selection == "": + return [False] * total_blocks + + selected = [False] * total_blocks + ranges = selection.split(",") + + for r in ranges: + if "-" in r: + start, end = map(str.strip, r.split("-")) + start = int(start) + end = int(end) + assert 0 <= start < total_blocks, f"invalid start index: {start}" + assert 0 <= end < total_blocks, f"invalid end index: {end}" + assert start <= end, f"invalid range: {start}-{end}" + for i in range(start, end + 1): + selected[i] = True + else: + index = int(r) + assert 0 <= index < total_blocks, f"invalid index: {index}" + selected[index] = True + + return selected + + train_block_indices = kwargs.get("train_block_indices", None) + if train_block_indices is not None: + train_block_indices = parse_block_selection(train_block_indices, 999) # 999 is a dummy number + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + + # train T5XXL + train_t5xxl = kwargs.get("train_t5xxl", False) + if train_t5xxl is not None: + train_t5xxl = True if train_t5xxl == "True" else False + + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoders, + mmdit, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + type_dims=type_dims, + emb_dims=emb_dims, + train_block_indices=train_block_indices, + verbose=verbose, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, mmdit, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping, and train t5xxl + modules_dim = {} + modules_alpha = {} + train_t5xxl = None + for key, value in weights_sd.items(): + if "." 
not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + if train_t5xxl is None or train_t5xxl is False: + train_t5xxl = "lora_te3" in lora_name + + if train_t5xxl is None: + train_t5xxl = False + + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoders, + mmdit, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + ) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + SD3_TARGET_REPLACE_MODULE = ["SingleDiTBlock"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"] + LORA_PREFIX_SD3 = "lora_unet" # make ComfyUI compatible + LORA_PREFIX_TEXT_ENCODER_CLIP_L = "lora_te1" + LORA_PREFIX_TEXT_ENCODER_CLIP_G = "lora_te2" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible + + def __init__( + self, + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], + unet: sd3_models.MMDiT, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + split_qkv: bool = False, + train_t5xxl: bool = False, + type_dims: Optional[List[int]] = None, + emb_dims: Optional[List[int]] = None, + train_block_indices: Optional[List[bool]] = None, + verbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.split_qkv = split_qkv + self.train_t5xxl = train_t5xxl + + self.type_dims = type_dims + self.emb_dims = emb_dims + self.train_block_indices = train_block_indices + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + self.emb_dims = [0] * 6 # create emb_dims + # verbose = True + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). 
dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) + + qkv_dim = 0 + if self.split_qkv: + logger.info(f"split qkv for LoRA") + qkv_dim = unet.joint_blocks[0].context_block.attn.qkv.weight.size(0) + if train_t5xxl: + logger.info(f"train T5XXL as well") + + # create module instances + def create_modules( + is_mmdit: bool, + text_encoder_idx: Optional[int], + root_module: torch.nn.Module, + target_replace_modules: List[str], + filter: Optional[str] = None, + default_dim: Optional[int] = None, + include_conv2d_if_filter: bool = False, + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_SD3 + if is_mmdit + else [self.LORA_PREFIX_TEXT_ENCODER_CLIP_L, self.LORA_PREFIX_TEXT_ENCODER_CLIP_G, self.LORA_PREFIX_TEXT_ENCODER_T5][ + text_encoder_idx + ] + ) + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: # dirty hack for all modules + module = root_module # search all modules + + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + (name + "." if name else "") + child_name + lora_name = lora_name.replace(".", "_") + + force_incl_conv2d = False + if filter is not None: + if not filter in lora_name: + continue + force_incl_conv2d = include_conv2d_if_filter + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + + if is_mmdit and type_dims is not None: + # type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim] + identifier = [ + ("context_block", "attn"), + ("context_block", "mlp"), + ("context_block", "adaLN_modulation"), + ("x_block", "attn"), + ("x_block", "mlp"), + ("x_block", "adaLN_modulation"), + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d # may be 0 for skip + break + + if is_mmdit and dim and self.train_block_indices is not None and "joint_blocks" in lora_name: + # "lora_unet_joint_blocks_0_x_block_attn_proj..." 
+ block_index = int(lora_name.split("_")[4]) # bit dirty + if self.train_block_indices is not None and not self.train_block_indices[block_index]: + dim = 0 + + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + elif force_incl_conv2d: + # x_embedder + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + # qkv split + split_dims = None + if is_mmdit and split_qkv: + if "joint_blocks" in lora_name and "qkv" in lora_name: + split_dims = [qkv_dim // 3] * 3 + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + split_dims=split_dims, + ) + loras.append(lora) + + if target_replace_modules is None: + break # all modules are searched + return loras, skipped + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + index = i + if not train_t5xxl and index >= 2: # 0: CLIP-L, 1: CLIP-G, 2: T5XXL, so we skip T5XXL if train_t5xxl is False + break + + logger.info(f"create LoRA for Text Encoder {index+1}:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.") + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + + # create LoRA for U-Net + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.SD3_TARGET_REPLACE_MODULE) + + # emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear] + if self.emb_dims: + for filter, in_dim in zip( + [ + "context_embedder", + "_t_embedder", # don't use "t_embedder" because it's used in "context_embedder" + "x_embedder", + "y_embedder", + "final_layer_adaLN_modulation", + "final_layer_linear", + ], + self.emb_dims, + ): + # x_embedder is conv2d, so we need to include it + loras, _ = create_modules( + True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder" + ) + # if len(loras) > 0: + # logger.info(f"create LoRA for {filter}: {len(loras)} modules.") + self.unet_loras.extend(loras) + + logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") + + skipped = skipped_te + skipped_un + if verbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, 
file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to split qkv + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # split qkv + for key in list(state_dict.keys()): + if not ("joint_blocks" in key and "qkv" in key): + continue + + weight = state_dict[key] + lora_name = key.split(".")[0] + if "lora_down" in key and "weight" in key: + # dense weight (rank*3, in_dim) + split_weight = torch.chunk(weight, 3, dim=0) + for i, split_w in enumerate(split_weight): + state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w + + del state_dict[key] + # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}") + elif "lora_up" in key and "weight" in key: + # sparse weight (out_dim=sum(split_dims), rank*3) + rank = weight.size(1) // 3 + i = 0 + split_dim = weight.shape[0] // 3 + for j in range(3): + state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dim, j * rank : (j + 1) * rank] + i += split_dim + del state_dict[key] + + # alpha is unchanged + + return super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination, prefix, keep_vars) + + # merge qkv + state_dict = super().state_dict(destination, prefix, keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if not ("joint_blocks" in key and "qkv" in key): + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(3)] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(3)] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + split_dim, rank = up_weights[0].size() + qkv_dim = split_dim * 3 + up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(3): + up_weight[i : i + split_dim, j * rank : (j + 1) * rank] = up_weights[j] + i += split_dim + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + + def apply_to(self, text_encoders, mmdit, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # 
マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, mmdit, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if ( + key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_L) + or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_G) + or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5) + ): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_MMDIT): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of three elements + # if float, use the same value for all three + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): + text_encoder_lr = [default_lr, default_lr, default_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr), float(text_encoder_lr)] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0], text_encoder_lr[0]] + elif len(text_encoder_lr) == 2: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[1], text_encoder_lr[1]] + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, loraplus_ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if loraplus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * loraplus_ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te1_loras = [ + lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_L) + ] + 
te2_loras = [ + lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_G) + ] + te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)] + if len(te1_loras) > 0: + logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te2_loras) > 0: + logger.info(f"Text Encoder 2 (CLIP-G): {len(te2_loras)} modules, LR {text_encoder_lr[1]}") + params, descriptions = assemble_params(te2_loras, text_encoder_lr[1], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te3_loras) > 0: + logger.info(f"Text Encoder 3 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[2]}") + params, descriptions = assemble_params(te3_loras, text_encoder_lr[2], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 3 " + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = 
org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index 630da7e0..86dba246 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -10,10 +10,13 @@ import numpy as np import torch from safetensors.torch import safe_open, load_file +import torch.amp from tqdm import tqdm from PIL import Image +from transformers import CLIPTextModelWithProjection, T5EncoderModel from library.device_utils import init_ipex, get_preferred_device +from networks import lora_sd3 init_ipex() @@ -25,11 +28,14 @@ import logging logger = logging.getLogger(__name__) from library import sd3_models, sd3_utils, strategy_sd3 +from library.utils import load_safetensors -def get_noise(seed, latent): - generator = torch.manual_seed(seed) - return torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu").to(latent.dtype) +def get_noise(seed, latent, device="cpu"): + # generator = torch.manual_seed(seed) + generator = torch.Generator(device) + generator.manual_seed(seed) + return torch.randn(latent.size(), dtype=latent.dtype, layout=latent.layout, generator=generator, device=device) def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): @@ -59,7 +65,7 @@ def do_sample( neg_cond: Tuple[torch.Tensor, torch.Tensor], mmdit: sd3_models.MMDiT, steps: int, - guidance_scale: float, + cfg_scale: float, dtype: torch.dtype, device: str, ): @@ -71,7 +77,7 @@ def do_sample( latent = latent.to(dtype).to(device) - noise = get_noise(seed, latent).to(device) + noise = get_noise(seed, latent, device) model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 @@ -100,12 +106,13 @@ def do_sample( x_c_nc = torch.cat([x, x], dim=0) # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) - model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + with 
torch.autocast(device_type=device.type, dtype=dtype): + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) model_output = model_output.float() batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) pos_out, neg_out = batched.chunk(2) - denoised = neg_out + (pos_out - neg_out) * guidance_scale + denoised = neg_out + (pos_out - neg_out) * cfg_scale # print(denoised.shape) # d = to_d(x, sigma_hat, denoised) @@ -122,230 +129,68 @@ def do_sample( x = x.to(dtype) latent = x - scale_factor = 1.5305 - shift_factor = 0.0609 - # def process_out(self, latent): - # return (latent / self.scale_factor) + self.shift_factor - latent = (latent / scale_factor) + shift_factor + latent = vae.process_out(latent) return latent -if __name__ == "__main__": - target_height = 1024 - target_width = 1024 - - # steps = 50 # 28 # 50 - guidance_scale = 5 - # seed = 1 # None # 1 - - device = get_preferred_device() - - parser = argparse.ArgumentParser() - parser.add_argument("--ckpt_path", type=str, required=True) - parser.add_argument("--clip_g", type=str, required=False) - parser.add_argument("--clip_l", type=str, required=False) - parser.add_argument("--t5xxl", type=str, required=False) - parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77") - parser.add_argument("--apply_lg_attn_mask", action="store_true") - parser.add_argument("--apply_t5_attn_mask", action="store_true") - parser.add_argument("--prompt", type=str, default="A photo of a cat") - # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders - parser.add_argument("--negative_prompt", type=str, default="") - parser.add_argument("--output_dir", type=str, default=".") - parser.add_argument("--do_not_use_t5xxl", action="store_true") - parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch") - parser.add_argument("--fp16", action="store_true") - parser.add_argument("--bf16", action="store_true") - parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--steps", type=int, default=50) - # parser.add_argument( - # "--lora_weights", - # type=str, - # nargs="*", - # default=[], - # help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)", - # ) - # parser.add_argument("--interactive", action="store_true") - args = parser.parse_args() - - seed = args.seed - steps = args.steps - - sd3_dtype = torch.float32 - if args.fp16: - sd3_dtype = torch.float16 - elif args.bf16: - sd3_dtype = torch.bfloat16 - - # TODO test with separated safetenors files for each model - - # load state dict - logger.info(f"Loading SD3 models from {args.ckpt_path}...") - state_dict = load_file(args.ckpt_path) - - if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: - # found clip_g: remove prefix "text_encoders.clip_g." - logger.info("clip_g is included in the checkpoint") - clip_g_sd = {} - prefix = "text_encoders.clip_g." - for k, v in list(state_dict.items()): - if k.startswith(prefix): - clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) - else: - logger.info(f"Lodaing clip_g from {args.clip_g}...") - clip_g_sd = load_file(args.clip_g) - for key in list(clip_g_sd.keys()): - clip_g_sd["transformer." + key] = clip_g_sd.pop(key) - - if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: - # found clip_l: remove prefix "text_encoders.clip_l." 
- logger.info("clip_l is included in the checkpoint") - clip_l_sd = {} - prefix = "text_encoders.clip_l." - for k, v in list(state_dict.items()): - if k.startswith(prefix): - clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) - else: - logger.info(f"Lodaing clip_l from {args.clip_l}...") - clip_l_sd = load_file(args.clip_l) - for key in list(clip_l_sd.keys()): - clip_l_sd["transformer." + key] = clip_l_sd.pop(key) - - if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: - # found t5xxl: remove prefix "text_encoders.t5xxl." - logger.info("t5xxl is included in the checkpoint") - if not args.do_not_use_t5xxl: - t5xxl_sd = {} - prefix = "text_encoders.t5xxl." - for k, v in list(state_dict.items()): - if k.startswith(prefix): - t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) - else: - logger.info("but not used") - for key in list(state_dict.keys()): - if key.startswith("text_encoders.t5xxl."): - state_dict.pop(key) - t5xxl_sd = None - elif args.t5xxl: - assert not args.do_not_use_t5xxl, "t5xxl is not used but specified" - logger.info(f"Lodaing t5xxl from {args.t5xxl}...") - t5xxl_sd = load_file(args.t5xxl) - for key in list(t5xxl_sd.keys()): - t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) - else: - logger.info("t5xxl is not used") - t5xxl_sd = None - - use_t5xxl = t5xxl_sd is not None - - # MMDiT and VAE - vae_sd = {} - vae_prefix = "first_stage_model." - mmdit_prefix = "model.diffusion_model." - for k, v in list(state_dict.items()): - if k.startswith(vae_prefix): - vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) - elif k.startswith(mmdit_prefix): - state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k) - - # load tokenizers - logger.info("Loading tokenizers...") - tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length) - - # load models - # logger.info("Create MMDiT from SD3 checkpoint...") - # mmdit = sd3_utils.create_mmdit_from_sd3_checkpoint(state_dict) - logger.info("Create MMDiT") - mmdit = sd3_models.create_mmdit_sd3_medium_configs(args.attn_mode) - - logger.info("Loading state dict...") - info = mmdit.load_state_dict(state_dict) - logger.info(f"Loaded MMDiT: {info}") - - logger.info(f"Move MMDiT to {device} and {sd3_dtype}...") - mmdit.to(device, dtype=sd3_dtype) - mmdit.eval() - - # load VAE - logger.info("Create VAE") - vae = sd3_models.SDVAE() - logger.info("Loading state dict...") - info = vae.load_state_dict(vae_sd) - logger.info(f"Loaded VAE: {info}") - - logger.info(f"Move VAE to {device} and {sd3_dtype}...") - vae.to(device, dtype=sd3_dtype) - vae.eval() - - # load text encoders - logger.info("Create clip_l") - clip_l = sd3_models.create_clip_l(device, sd3_dtype, clip_l_sd) - - logger.info("Loading state dict...") - info = clip_l.load_state_dict(clip_l_sd) - logger.info(f"Loaded clip_l: {info}") - - logger.info(f"Move clip_l to {device} and {sd3_dtype}...") - clip_l.to(device, dtype=sd3_dtype) - clip_l.eval() - logger.info(f"Set attn_mode to {args.attn_mode}...") - clip_l.set_attn_mode(args.attn_mode) - - logger.info("Create clip_g") - clip_g = sd3_models.create_clip_g(device, sd3_dtype, clip_g_sd) - - logger.info("Loading state dict...") - info = clip_g.load_state_dict(clip_g_sd) - logger.info(f"Loaded clip_g: {info}") - - logger.info(f"Move clip_g to {device} and {sd3_dtype}...") - clip_g.to(device, dtype=sd3_dtype) - clip_g.eval() - logger.info(f"Set attn_mode to {args.attn_mode}...") - clip_g.set_attn_mode(args.attn_mode) - - if use_t5xxl: - logger.info("Create t5xxl") - t5xxl = 
sd3_models.create_t5xxl(device, sd3_dtype, t5xxl_sd) - - logger.info("Loading state dict...") - info = t5xxl.load_state_dict(t5xxl_sd) - logger.info(f"Loaded t5xxl: {info}") - - logger.info(f"Move t5xxl to {device} and {sd3_dtype}...") - t5xxl.to(device, dtype=sd3_dtype) - # t5xxl.to("cpu", dtype=torch.float32) # run on CPU - t5xxl.eval() - logger.info(f"Set attn_mode to {args.attn_mode}...") - t5xxl.set_attn_mode(args.attn_mode) - else: - t5xxl = None - +def generate_image( + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, + clip_l: CLIPTextModelWithProjection, + clip_g: CLIPTextModelWithProjection, + t5xxl: T5EncoderModel, + steps: int, + prompt: str, + seed: int, + target_width: int, + target_height: int, + device: str, + negative_prompt: str, + cfg_scale: float, +): # prepare embeddings logger.info("Encoding prompts...") - encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() - tokens_and_masks = tokenize_strategy.tokenize(args.prompt) - lg_out, t5_out, pooled = encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask - ) - cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + # TODO support one-by-one offloading + clip_l.to(device) + clip_g.to(device) + t5xxl.to(device) - tokens_and_masks = tokenize_strategy.tokenize(args.negative_prompt) - lg_out, t5_out, pooled = encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask - ) - neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + with torch.autocast(device_type=device.type, dtype=mmdit.dtype), torch.no_grad(): + tokens_and_masks = tokenize_strategy.tokenize(prompt) + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask + ) + cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + tokens_and_masks = tokenize_strategy.tokenize(negative_prompt) + lg_out, t5_out, pooled, neg_l_attn_mask, neg_g_attn_mask, neg_t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask + ) + neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # attn masks are not used currently + + if args.offload: + clip_l.to("cpu") + clip_g.to("cpu") + t5xxl.to("cpu") # generate image logger.info("Generating image...") - latent_sampled = do_sample( - target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, guidance_scale, sd3_dtype, device - ) + mmdit.to(device) + latent_sampled = do_sample(target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, cfg_scale, sd3_dtype, device) + if args.offload: + mmdit.to("cpu") # latent to image + vae.to(device) with torch.no_grad(): image = vae.decode(latent_sampled) + + if args.offload: + vae.to("cpu") + image = image.float() image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) @@ -359,3 +204,204 @@ if __name__ == "__main__": out_image.save(output_path) logger.info(f"Saved image to {output_path}") + + +if __name__ == "__main__": + target_height = 1024 + target_width = 1024 + + # steps = 50 # 28 # 50 + # cfg_scale = 5 + # seed = 1 # None # 1 + + device = get_preferred_device() + + parser = argparse.ArgumentParser() + 
parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--clip_g", type=str, required=False) + parser.add_argument("--clip_l", type=str, required=False) + parser.add_argument("--t5xxl", type=str, required=False) + parser.add_argument("--t5xxl_token_length", type=int, default=256, help="t5xxl token length, default: 256") + parser.add_argument("--apply_lg_attn_mask", action="store_true") + parser.add_argument("--apply_t5_attn_mask", action="store_true") + parser.add_argument("--prompt", type=str, default="A photo of a cat") + # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders + parser.add_argument("--negative_prompt", type=str, default="") + parser.add_argument("--cfg_scale", type=float, default=5.0) + parser.add_argument("--offload", action="store_true", help="Offload to CPU") + parser.add_argument("--output_dir", type=str, default=".") + # parser.add_argument("--do_not_use_t5xxl", action="store_true") + # parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch") + parser.add_argument("--fp16", action="store_true") + parser.add_argument("--bf16", action="store_true") + parser.add_argument("--seed", type=int, default=1) + parser.add_argument("--steps", type=int, default=50) + parser.add_argument( + "--lora_weights", + type=str, + nargs="*", + default=[], + help="LoRA weights, only supports networks.lora_sd3, each argument is a `path;multiplier` (semi-colon separated)", + ) + parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") + parser.add_argument("--width", type=int, default=target_width) + parser.add_argument("--height", type=int, default=target_height) + parser.add_argument("--interactive", action="store_true") + args = parser.parse_args() + + seed = args.seed + steps = args.steps + + sd3_dtype = torch.float32 + if args.fp16: + sd3_dtype = torch.float16 + elif args.bf16: + sd3_dtype = torch.bfloat16 + + loading_device = "cpu" if args.offload else device + + # load state dict + logger.info(f"Loading SD3 models from {args.ckpt_path}...") + # state_dict = load_file(args.ckpt_path) + state_dict = load_safetensors(args.ckpt_path, loading_device, disable_mmap=True, dtype=sd3_dtype) + + # load text encoders + clip_l = sd3_utils.load_clip_l(args.clip_l, sd3_dtype, loading_device, state_dict=state_dict) + clip_g = sd3_utils.load_clip_g(args.clip_g, sd3_dtype, loading_device, state_dict=state_dict) + t5xxl = sd3_utils.load_t5xxl(args.t5xxl, sd3_dtype, loading_device, state_dict=state_dict) + + # MMDiT and VAE + vae = sd3_utils.load_vae(None, sd3_dtype, loading_device, state_dict=state_dict) + mmdit = sd3_utils.load_mmdit(state_dict, sd3_dtype, loading_device) + + clip_l.to(sd3_dtype) + clip_g.to(sd3_dtype) + t5xxl.to(sd3_dtype) + vae.to(sd3_dtype) + mmdit.to(sd3_dtype) + if not args.offload: + # make sure to move to the device: some tensors are created in the constructor on the CPU + clip_l.to(device) + clip_g.to(device) + t5xxl.to(device) + vae.to(device) + mmdit.to(device) + + clip_l.eval() + clip_g.eval() + t5xxl.eval() + mmdit.eval() + vae.eval() + + # load tokenizers + logger.info("Loading tokenizers...") + tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length) + encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() + + # LoRA + lora_models: list[lora_sd3.LoRANetwork] = [] + for weights_file in args.lora_weights: + if ";" in weights_file: + weights_file, multiplier = 
weights_file.split(";") + multiplier = float(multiplier) + else: + multiplier = 1.0 + + weights_sd = load_file(weights_file) + module = lora_sd3 + lora_model, _ = module.create_network_from_weights(multiplier, None, vae, [clip_l, clip_g, t5xxl], mmdit, weights_sd, True) + + if args.merge_lora_weights: + lora_model.merge_to([clip_l, clip_g, t5xxl], mmdit, weights_sd) + else: + lora_model.apply_to([clip_l, clip_g, t5xxl], mmdit) + info = lora_model.load_state_dict(weights_sd, strict=True) + logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + lora_model.eval() + lora_model.to(device) + + lora_models.append(lora_model) + + if not args.interactive: + generate_image( + mmdit, + vae, + clip_l, + clip_g, + t5xxl, + args.steps, + args.prompt, + args.seed, + args.width, + args.height, + device, + args.negative_prompt, + args.cfg_scale, + ) + else: + # loop for interactive + width = args.width + height = args.height + steps = None + cfg_scale = args.cfg_scale + + while True: + print( + "Enter prompt (empty to exit). Options: --w --h --s --d " + " --n , `--n -` for empty negative prompt" + "Options are kept for the next prompt. Current options:" + f" width={width}, height={height}, steps={steps}, seed={seed}, cfg_scale={cfg_scale}" + ) + prompt = input() + if prompt == "": + break + + # parse options + options = prompt.split("--") + prompt = options[0].strip() + seed = None + negative_prompt = None + for opt in options[1:]: + try: + opt = opt.strip() + if opt.startswith("w"): + width = int(opt[1:].strip()) + elif opt.startswith("h"): + height = int(opt[1:].strip()) + elif opt.startswith("s"): + steps = int(opt[1:].strip()) + elif opt.startswith("d"): + seed = int(opt[1:].strip()) + elif opt.startswith("m"): + mutipliers = opt[1:].strip().split(",") + if len(mutipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(mutipliers[i])) + elif opt.startswith("n"): + negative_prompt = opt[1:].strip() + if negative_prompt == "-": + negative_prompt = "" + elif opt.startswith("c"): + cfg_scale = float(opt[1:].strip()) + except ValueError as e: + logger.error(f"Invalid option: {opt}, {e}") + + generate_image( + mmdit, + vae, + clip_l, + clip_g, + t5xxl, + steps if steps is not None else args.steps, + prompt, + seed if seed is not None else args.seed, + width, + height, + device, + negative_prompt if negative_prompt is not None else args.negative_prompt, + cfg_scale, + ) + + logger.info("Done!") diff --git a/sd3_train.py b/sd3_train.py index ef18c32c..f64e2da2 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -1,6 +1,7 @@ # training with captions import argparse +from concurrent.futures import ThreadPoolExecutor import copy import math import os @@ -11,6 +12,7 @@ import toml from tqdm import tqdm import torch +from library import utils from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -38,7 +40,7 @@ from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) -import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments # from library.custom_train_functions import ( # apply_snr_weight, @@ -61,22 +63,17 @@ def train(args): if not args.skip_cache_check: args.skip_cache_check = args.skip_latents_validity_check - assert ( - not args.weighted_captions - ), "weighted_captions is not supported currently / 
weighted_captionsは現在サポートされていません" + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" # assert ( # not args.train_text_encoder or not args.cache_text_encoder_outputs # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" - - # # training text encoder is not supported - # assert ( - # not args.train_text_encoder - # ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" - - # # training without text encoder cache is not supported: because T5XXL must be cached - # assert ( - # args.cache_text_encoder_outputs - # ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), ( "when training text encoder, text encoder outputs must not be cached (except for T5XXL)" @@ -90,13 +87,13 @@ def train(args): ) args.cache_text_encoder_outputs = True - # if args.block_lr: - # block_lrs = [float(lr) for lr in args.block_lr.split(",")] - # assert ( - # len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR - # ), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください" - # else: - # block_lrs = None + if args.train_t5xxl: + assert ( + args.train_text_encoder + ), "when training T5XXL, text encoder (CLIP-L/G) must be trained / T5XXLを学習するときはtext encoder (CLIP-L/G)も学習する必要があります" + assert ( + not args.cache_text_encoder_outputs + ), "when training T5XXL, t5xxl output must not be cached / T5XXLを学習するときはt5xxlの出力をキャッシュできません" cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None @@ -111,11 +108,6 @@ def train(args): ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) - # load tokenizer and prepare tokenize strategy - sd3_tokenizer = sd3_models.SD3Tokenizer(t5xxl_max_length=args.t5xxl_max_token_length) - sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length) - strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy) - # データセットを準備する if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) @@ -156,10 +148,10 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[sd3_tokenizer]) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, [sd3_tokenizer]) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -205,41 +197,151 @@ def train(args): # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) - vae_dtype = weight_dtype # torch.float32 if args.no_half_vae else weight_dtype # SD3 VAE works with fp16 - - t5xxl_dtype = weight_dtype - if args.t5xxl_dtype is not None: - if args.t5xxl_dtype == "fp16": - t5xxl_dtype = torch.float16 - 
elif args.t5xxl_dtype == "bf16": - t5xxl_dtype = torch.bfloat16 - elif args.t5xxl_dtype == "fp32" or args.t5xxl_dtype == "float": - t5xxl_dtype = torch.float32 - else: - raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") - t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device - - clip_dtype = weight_dtype # if not args.train_text_encoder else None # モデルを読み込む - attn_mode = "xformers" if args.xformers else "torch" - assert ( - attn_mode == "torch" - ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + # t5xxl_dtype = weight_dtype + # if args.t5xxl_dtype is not None: + # if args.t5xxl_dtype == "fp16": + # t5xxl_dtype = torch.float16 + # elif args.t5xxl_dtype == "bf16": + # t5xxl_dtype = torch.bfloat16 + # elif args.t5xxl_dtype == "fp32" or args.t5xxl_dtype == "float": + # t5xxl_dtype = torch.float32 + # else: + # raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") + # t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device + # clip_dtype = weight_dtype # if not args.train_text_encoder else None - # SD3 state dict may contain multiple models, so we need to load it and extract one by one. annoying. - logger.info(f"Loading SD3 models from {args.pretrained_model_name_or_path}") - device_to_load = accelerator.device if args.lowram else "cpu" - sd3_state_dict = sd3_utils.load_safetensors( - args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors + # if clip_l is not specified, the checkpoint must contain clip_l, so we load state dict here + # if full_fp16/bf16, model_dtype is casted to fp16/bf16. If not, model_dtype is None (float32). + # by loading with model_dtype, we can reduce memory usage. 
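+    # Illustrative sketch only (not the actual helper): match_mixed_precision effectively does
+    # something like
+    #     model_dtype = torch.float16 if args.full_fp16 else torch.bfloat16 if args.full_bf16 else None
+    # and passing that dtype down to load_safetensors casts each tensor as it is read, so a full
+    # float32 copy of the checkpoint never has to be materialized.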
+ model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx) + if args.clip_l is None: + sd3_state_dict = utils.load_safetensors( + args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype + ) + else: + sd3_state_dict = None + + # load tokenizer and prepare tokenize strategy + sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy) + + # load clip_l, clip_g, t5xxl for caching text encoder outputs + # clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) + # clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) + clip_l = sd3_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) + clip_g = sd3_utils.load_clip_g(args.clip_g, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) + t5xxl = sd3_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) + assert clip_l is not None and clip_g is not None and t5xxl is not None, "clip_l, clip_g, t5xxl must be specified" + + # prepare text encoding strategy + text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy( + args.apply_lg_attn_mask, args.apply_t5_attn_mask, args.clip_l_dropout_rate, args.clip_g_dropout_rate, args.t5_dropout_rate ) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # 学習を準備する:モデルを適切な状態にする + train_clip = False + train_t5xxl = False + + if args.train_text_encoder: + accelerator.print("enable text encoder training") + if args.gradient_checkpointing: + clip_l.gradient_checkpointing_enable() + clip_g.gradient_checkpointing_enable() + if args.train_t5xxl: + t5xxl.gradient_checkpointing_enable() + + lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train + lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train + lr_t5xxl = args.learning_rate_te3 if args.learning_rate_te3 is not None else args.learning_rate # 0 means not train + train_clip = lr_te1 != 0 or lr_te2 != 0 + train_t5xxl = lr_t5xxl != 0 and args.train_t5xxl + + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + t5xxl.to(weight_dtype) + clip_l.requires_grad_(train_clip) + clip_g.requires_grad_(train_clip) + t5xxl.requires_grad_(train_t5xxl) + else: + print("disable text encoder training") + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + t5xxl.to(weight_dtype) + clip_l.requires_grad_(False) + clip_g.requires_grad_(False) + t5xxl.requires_grad_(False) + lr_te1 = 0 + lr_te2 = 0 + lr_t5xxl = 0 + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + clip_l.to(accelerator.device) + clip_g.to(accelerator.device) + t5xxl.to(accelerator.device) + clip_l.eval() + clip_g.eval() + t5xxl.eval() + + text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + train_clip or args.use_t5xxl_cache_only, # if clip is trained or t5xxl is cached, caching is partial + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) 
+ + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = sd3_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, + [clip_l, clip_g, t5xxl], + tokens_and_masks, + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + enable_dropout=False, + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + if not args.use_t5xxl_cache_only: + clip_l = None + clip_g = None + t5xxl = None + + clean_memory_on_device(accelerator.device) # load VAE for caching latents - vae: sd3_models.SDVAE = None + if sd3_state_dict is None: + logger.info(f"load state dict for MMDiT and VAE from {args.pretrained_model_name_or_path}") + sd3_state_dict = utils.load_safetensors( + args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype + ) + + vae = sd3_utils.load_vae(args.vae, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) if cache_latents: - vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) - vae.to(accelerator.device, dtype=vae_dtype) + # vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) + vae.to(accelerator.device, dtype=weight_dtype) vae.requires_grad_(False) vae.eval() @@ -250,127 +352,45 @@ def train(args): accelerator.wait_for_everyone() - # load clip_l, clip_g, t5xxl for caching text encoder outputs - # # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0. 
- # mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( - # args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype - # ) - clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) - clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) - assert clip_l is not None, "clip_l is required / clip_lは必須です" - assert clip_g is not None, "clip_g is required / clip_gは必須です" - - t5xxl = sd3_train_utils.load_target_model("t5xxl", args, sd3_state_dict, accelerator, attn_mode, t5xxl_dtype, device_to_load) - # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) - - # should be deleted after caching text encoder outputs when not training text encoder - # this strategy should not be used other than this process - text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() - strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) - - # 学習を準備する:モデルを適切な状態にする - train_clip_l = False - train_clip_g = False - train_t5xxl = False - - if args.train_text_encoder: - accelerator.print("enable text encoder training") - if args.gradient_checkpointing: - clip_l.gradient_checkpointing_enable() - clip_g.gradient_checkpointing_enable() - lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train - lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train - train_clip_l = lr_te1 != 0 - train_clip_g = lr_te2 != 0 - - if not train_clip_l: - clip_l.to(weight_dtype) - if not train_clip_g: - clip_g.to(weight_dtype) - clip_l.requires_grad_(train_clip_l) - clip_g.requires_grad_(train_clip_g) - clip_l.train(train_clip_l) - clip_g.train(train_clip_g) - else: - clip_l.to(weight_dtype) - clip_g.to(weight_dtype) - clip_l.requires_grad_(False) - clip_g.requires_grad_(False) - clip_l.eval() - clip_g.eval() - - if t5xxl is not None: - t5xxl.to(t5xxl_dtype) - t5xxl.requires_grad_(False) - t5xxl.eval() - - # cache text encoder outputs - sample_prompts_te_outputs = None - if args.cache_text_encoder_outputs: - # Text Encodes are eval and no grad here - clip_l.to(accelerator.device) - clip_g.to(accelerator.device) - if t5xxl is not None: - t5xxl.to(t5xxl_device) - - text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, - args.text_encoder_batch_size, - args.skip_cache_check, - train_clip_g or train_clip_l or args.use_t5xxl_cache_only, - args.apply_lg_attn_mask, - args.apply_t5_attn_mask, - ) - strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) - - clip_l.to(accelerator.device, dtype=weight_dtype) - clip_g.to(accelerator.device, dtype=weight_dtype) - if t5xxl is not None: - t5xxl.to(t5xxl_device, dtype=t5xxl_dtype) - - with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator) - - # cache sample prompt's embeddings to free text encoder's memory - if args.sample_prompts is not None: - logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") - prompts = sd3_train_utils.load_prompts(args.sample_prompts) - sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs - with accelerator.autocast(), torch.no_grad(): - for prompt_dict in prompts: - for p in [prompt_dict.get("prompt", ""), 
prompt_dict.get("negative_prompt", "")]: - if p not in sample_prompts_te_outputs: - logger.info(f"cache Text Encoder outputs for prompt: {p}") - tokens_list = sd3_tokenize_strategy.tokenize(p) - sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( - sd3_tokenize_strategy, - [clip_l, clip_g, t5xxl], - tokens_list, - args.apply_lg_attn_mask, - args.apply_t5_attn_mask, - ) - - accelerator.wait_for_everyone() - # load MMDIT - # if full_fp16/bf16, model_dtype is casted to fp16/bf16. If not, model_dtype is None (float32). - # by loading with model_dtype, we can reduce memory usage. - model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx) - mmdit = sd3_train_utils.load_target_model("mmdit", args, sd3_state_dict, accelerator, attn_mode, model_dtype, device_to_load) + mmdit = sd3_utils.load_mmdit(sd3_state_dict, model_dtype, "cpu") + + # attn_mode = "xformers" if args.xformers else "torch" + # assert ( + # attn_mode == "torch" + # ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + + mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate) + + # set resolutions for positional embeddings + if args.enable_scaled_pos_embed: + resolutions = train_dataset_group.get_resolutions() + latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in resolutions] # 8 is stride for latent + latent_sizes = list(set(latent_sizes)) # remove duplicates + logger.info(f"Prepare scaled positional embeddings for resolutions: {resolutions}, sizes: {latent_sizes}") + mmdit.enable_scaled_pos_embed(True, latent_sizes) + if args.gradient_checkpointing: mmdit.enable_gradient_checkpointing() train_mmdit = args.learning_rate != 0 mmdit.requires_grad_(train_mmdit) if not train_mmdit: - mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdie will not be prepared + mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdit will not be prepared + + # block swap + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! 
+ logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + mmdit.enable_block_swap(args.blocks_to_swap) if not cache_latents: - # load VAE here if not cached - vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) + # move to accelerator device vae.requires_grad_(False) vae.eval() - vae.to(accelerator.device, dtype=vae_dtype) + vae.to(accelerator.device, dtype=weight_dtype) mmdit.requires_grad_(train_mmdit) if not train_mmdit: @@ -394,19 +414,24 @@ def train(args): training_models = [] params_to_optimize = [] - # if train_unet: + param_names = [] training_models.append(mmdit) - # if block_lrs is None: params_to_optimize.append({"params": list(filter(lambda p: p.requires_grad, mmdit.parameters())), "lr": args.learning_rate}) - # else: - # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs)) + param_names.append([n for n, _ in mmdit.named_parameters()]) - # if train_clip_l: - # training_models.append(clip_l) - # params_to_optimize.append({"params": list(clip_l.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) - # if train_clip_g: - # training_models.append(clip_g) - # params_to_optimize.append({"params": list(clip_g.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) + if train_clip: + if lr_te1 > 0: + training_models.append(clip_l) + params_to_optimize.append({"params": list(clip_l.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) + param_names.append([n for n, _ in clip_l.named_parameters()]) + if lr_te2 > 0: + training_models.append(clip_g) + params_to_optimize.append({"params": list(clip_g.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) + param_names.append([n for n, _ in clip_g.named_parameters()]) + if train_t5xxl: + training_models.append(t5xxl) + params_to_optimize.append({"params": list(t5xxl.parameters()), "lr": args.learning_rate_te3 or args.learning_rate}) + param_names.append([n for n, _ in t5xxl.named_parameters()]) # calculate number of trainable parameters n_params = 0 @@ -414,47 +439,49 @@ def train(args): for p in group["params"]: n_params += p.numel() - accelerator.print(f"train mmdit: {train_mmdit}") # , clip_l: {train_clip_l}, clip_g: {train_clip_g}") + accelerator.print(f"train mmdit: {train_mmdit} , clip:{train_clip}, t5xxl:{train_t5xxl}") accelerator.print(f"number of models: {len(training_models)}") accelerator.print(f"number of trainable parameters: {n_params}") # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html - # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. # This balances memory usage and management complexity. - # calculate total number of parameters - n_total_params = sum(len(params["params"]) for params in params_to_optimize) - params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) - - # split params into groups, keeping the learning rate the same for all params in a group - # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) + # split params into groups for mmdit. 
clip_l, clip_g, t5xxl are in each group grouped_params = [] - param_group = [] - param_group_lr = -1 - for group in params_to_optimize: - lr = group["lr"] - for p in group["params"]: - # if the learning rate is different for different params, start a new group - if lr != param_group_lr: - if param_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) - param_group = [] - param_group_lr = lr + param_group = {} + group = params_to_optimize[0] + named_parameters = list(mmdit.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # joint or other + if np[0].startswith("joint_blocks"): + block_idx = int(np[0].split(".")[1]) + block_type = "joint" + else: + block_idx = -1 - param_group.append(p) + param_group_key = (block_type, block_idx) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) - # if the group has enough parameters, start a new group - if len(param_group) == params_per_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) - param_group = [] - param_group_lr = -1 + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) - if param_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + grouped_params.extend(params_to_optimize[1:]) # add clip_l, clip_g, t5xxl if they are trained # prepare optimizers for each group optimizers = [] @@ -463,10 +490,15 @@ def train(args): optimizers.append(optimizer) optimizer = optimizers[0] # avoid error in the following code - logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) # prepare dataloader # strategies are set here because they cannot be referenced in another process. 
Copy them with the dataset @@ -497,7 +529,7 @@ def train(args): train_dataset_group.set_max_train_steps(args.max_train_steps) # lr schedulerを用意する - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: # prepare lr schedulers for each optimizer lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] lr_scheduler = lr_schedulers[0] # avoid error in the following code @@ -511,18 +543,22 @@ def train(args): ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" accelerator.print("enable full fp16 training.") mmdit.to(weight_dtype) - clip_l.to(weight_dtype) - clip_g.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + if clip_g is not None: + clip_g.to(weight_dtype) if t5xxl is not None: - t5xxl.to(weight_dtype) # TODO check works with fp16 or not + t5xxl.to(weight_dtype) elif args.full_bf16: assert ( args.mixed_precision == "bf16" ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" accelerator.print("enable full bf16 training.") mmdit.to(weight_dtype) - clip_l.to(weight_dtype) - clip_g.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + if clip_g is not None: + clip_g.to(weight_dtype) if t5xxl is not None: t5xxl.to(weight_dtype) @@ -532,15 +568,8 @@ def train(args): # clip_l.text_model.encoder.layers[-1].requires_grad_(False) # clip_l.text_model.final_layer_norm.requires_grad_(False) - # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する - if args.cache_text_encoder_outputs: - # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 - clip_l.to("cpu", dtype=torch.float32) - clip_g.to("cpu", dtype=torch.float32) - if t5xxl is not None: - t5xxl.to("cpu", dtype=torch.float32) - clean_memory_on_device(accelerator.device) - else: + # move Text Encoders to GPU if not caching outputs + if not args.cache_text_encoder_outputs: # make sure Text Encoders are on GPU # TODO support CPU for text encoders clip_l.to(accelerator.device) @@ -548,18 +577,11 @@ def train(args): if t5xxl is not None: t5xxl.to(accelerator.device) - # TODO cache sample prompt's embeddings to free text encoder's memory - if args.cache_text_encoder_outputs: - if not args.save_t5xxl: - t5xxl = None # free memory clean_memory_on_device(accelerator.device) if args.deepspeed: ds_model = deepspeed_utils.prepare_deepspeed_model( - args, - mmdit=mmdit, - clip_l=clip_l if train_clip_l else None, - clip_g=clip_g if train_clip_g else None, + args, mmdit=mmdit, clip_l=clip_l if train_clip else None, clip_g=clip_g if train_clip else None ) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. 
# pull/1139#issuecomment-1986790007 ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -570,11 +592,14 @@ def train(args): else: # acceleratorがなんかよろしくやってくれるらしい if train_mmdit: - mmdit = accelerator.prepare(mmdit) - if train_clip_l: + mmdit = accelerator.prepare(mmdit, device_placement=[not is_swapping_blocks]) + if is_swapping_blocks: + accelerator.unwrap_model(mmdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + if train_clip: clip_l = accelerator.prepare(clip_l) - if train_clip_g: clip_g = accelerator.prepare(clip_g) + if train_t5xxl: + t5xxl = accelerator.prepare(t5xxl) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする @@ -586,24 +611,112 @@ def train(args): # resumeする train_util.resume_from_local_or_hf_if_specified(accelerator, args) + # memory efficient block swapping + + def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, device): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda, dvc): + # print(f"Backward: Move block {bidx_to_cpu} to CPU") + block_to_cpu = block_to_cpu.to("cpu", non_blocking=True) + torch.cuda.empty_cache() + + # print(f"Backward: Move block {bidx_to_cuda} to CUDA") + block_to_cuda = block_to_cuda.to(dvc, non_blocking=True) + torch.cuda.synchronize() + # print(f"Backward: Done moving blocks {bidx_to_cpu} and {bidx_to_cuda}") + return bidx_to_cpu, bidx_to_cuda + + block_to_cpu = blocks[block_idx_to_cpu] + block_to_cuda = blocks[block_idx_to_cuda] + + futures[block_idx_to_cuda] = thread_pool.submit( + move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda, device + ) + + def wait_blocks_move(block_idx, futures): + if block_idx not in futures: + return + future = futures.pop(block_idx) + future.result() + if args.fused_backward_pass: # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) - for param_group in optimizer.param_groups: - for parameter in param_group["params"]: + + blocks_to_swap = args.blocks_to_swap + num_blocks = len(accelerator.unwrap_model(mmdit).joint_blocks) + handled_block_indices = set() + + n = 1 # only asynchronous purpose, no need to increase this number + # n = 2 + # n = max(1, os.cpu_count() // 2) + thread_pool = ThreadPoolExecutor(max_workers=n) + futures = {} + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: + grad_hook = None - def __grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None + if blocks_to_swap: + is_block = param_name.startswith("joint_blocks") + if is_block: + block_idx = int(param_name.split(".")[1]) + if block_idx not in handled_block_indices: + # swap following (already backpropagated) block + handled_block_indices.add(block_idx) - parameter.register_post_accumulate_grad_hook(__grad_hook) + # if n blocks were already backpropagated + num_blocks_propagated = num_blocks - block_idx - 1 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap + waiting = block_idx > 0 and block_idx <= 
blocks_to_swap + if swapping or waiting: + block_idx_to_cpu = num_blocks - num_blocks_propagated + block_idx_to_cuda = blocks_to_swap - num_blocks_propagated + block_idx_to_wait = block_idx - 1 - elif args.fused_optimizer_groups: + # create swap hook + def create_swap_grad_hook( + bidx_to_cpu, bidx_to_cuda, bidx_to_wait, bidx: int, swpng: bool, wtng: bool + ): + def __grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + if swpng: + submit_move_blocks( + futures, + thread_pool, + bidx_to_cpu, + bidx_to_cuda, + mmdit.joint_blocks, + accelerator.device, + ) + if wtng: + wait_blocks_move(bidx_to_wait, futures) + + return __grad_hook + + grad_hook = create_swap_grad_hook( + block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, block_idx, swapping, waiting + ) + + if grad_hook is None: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + grad_hook = __grad_hook + + parameter.register_post_accumulate_grad_hook(grad_hook) + + elif args.blockwise_fused_optimizers: # prepare for additional optimizers and lr schedulers for i in range(1, len(optimizers)): optimizers[i] = accelerator.prepare(optimizers[i]) @@ -618,22 +731,59 @@ def train(args): num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} + blocks_to_swap = args.blocks_to_swap + num_blocks = len(accelerator.unwrap_model(mmdit).joint_blocks) + + n = 1 # only asynchronous purpose, no need to increase this number + # n = max(1, os.cpu_count() // 2) + thread_pool = ThreadPoolExecutor(max_workers=n) + futures = {} + for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: + block_type, block_idx = block_types_and_indices[opt_idx] - def optimizer_hook(parameter: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + def create_optimizer_hook(btype, bidx): + def optimizer_hook(parameter: torch.Tensor): + # print(f"optimizer_hook: {btype}, {bidx}") + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) - i = parameter_optimizer_map[parameter] - optimizer_hooked_count[i] += 1 - if optimizer_hooked_count[i] == num_parameters_per_group[i]: - optimizers[i].step() - optimizers[i].zero_grad(set_to_none=True) + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) - parameter.register_post_accumulate_grad_hook(optimizer_hook) + # swap blocks if necessary + if blocks_to_swap and btype == "joint": + num_blocks_propagated = num_blocks - bidx + + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap + waiting = bidx > 0 and bidx <= blocks_to_swap + + if swapping: + block_idx_to_cpu = num_blocks - num_blocks_propagated + block_idx_to_cuda = blocks_to_swap - num_blocks_propagated + # print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}") + submit_move_blocks( + futures, + thread_pool, + block_idx_to_cpu, + block_idx_to_cuda, + 
mmdit.joint_blocks, + accelerator.device, + ) + + if waiting: + block_idx_to_wait = bidx - 1 + wait_blocks_move(block_idx_to_wait, futures) + + return optimizer_hook + + parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx)) parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 @@ -661,17 +811,9 @@ def train(args): progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 - # noise_scheduler = DDPMScheduler( - # beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - # ) - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) noise_scheduler_copy = copy.deepcopy(noise_scheduler) - # prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - # if args.zero_terminal_snr: - # custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) - if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: @@ -684,60 +826,23 @@ def train(args): init_kwargs=init_kwargs, ) + if is_swapping_blocks: + accelerator.unwrap_model(mmdit).prepare_block_swap_before_forward() + # For --sample_at_first + optimizer_eval_fn() sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs) + optimizer_train_fn() if len(accelerator.trackers) > 0: # log empty object to commit the sample images to wandb accelerator.log({}, step=0) - # following function will be moved to sd3_train_utils - - def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): - sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) - schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) - timesteps = timesteps.to(accelerator.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) - return sigma - - def compute_density_for_timestep_sampling( - weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None - ): - """Compute the density for sampling the timesteps when doing SD3 training. - - Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. - - SD3 paper reference: https://arxiv.org/abs/2403.03206v1. - """ - if weighting_scheme == "logit_normal": - # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") - u = torch.nn.functional.sigmoid(u) - elif weighting_scheme == "mode": - u = torch.rand(size=(batch_size,), device="cpu") - u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) - else: - u = torch.rand(size=(batch_size,), device="cpu") - return u - - def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): - """Computes loss weighting scheme for SD3 training. - - Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. - - SD3 paper reference: https://arxiv.org/abs/2403.03206v1. 
- """ - if weighting_scheme == "sigma_sqrt": - weighting = (sigmas**-2.0).float() - elif weighting_scheme == "cosmap": - bot = 1 - 2 * sigmas + 2 * sigmas**2 - weighting = 2 / (math.pi * bot) - else: - weighting = torch.ones_like(sigmas) - return weighting + # show model device and dtype + logger.info(f"mmdit device: {mmdit.device}, dtype: {mmdit.dtype}" if mmdit else "mmdit is None") + logger.info(f"clip_l device: {clip_l.device}, dtype: {clip_l.dtype}" if clip_l else "clip_l is None") + logger.info(f"clip_g device: {clip_g.device}, dtype: {clip_g.dtype}" if clip_g else "clip_g is None") + logger.info(f"t5xxl device: {t5xxl.device}, dtype: {t5xxl.dtype}" if t5xxl else "t5xxl is None") + logger.info(f"vae device: {vae.device}, dtype: {vae.dtype}" if vae is not None else "vae is None") loss_recorder = train_util.LossRecorder() epoch = 0 # avoid error when max_train_steps is 0 @@ -751,16 +856,16 @@ def train(args): for step, batch in enumerate(train_dataloader): current_step.value = global_step - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step with accelerator.accumulate(*training_models): if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) else: with torch.no_grad(): # encode images to latents. images are [-1, 1] - latents = vae.encode(batch["images"].to(vae_dtype)).to(weight_dtype) + latents = vae.encode(batch["images"]) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): @@ -772,7 +877,8 @@ def train(args): text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: - lg_out, t5_out, lg_pooled = text_encoder_outputs_list + text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) + lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list if args.use_t5xxl_cache_only: lg_out = None lg_pooled = None @@ -780,16 +886,19 @@ def train(args): lg_out = None t5_out = None lg_pooled = None + l_attn_mask = None + g_attn_mask = None + t5_attn_mask = None - if lg_out is None or (train_clip_l or train_clip_g): + if lg_out is None: # not cached or training, so get from text encoders input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"] - with torch.set_grad_enabled(args.train_text_encoder): + with torch.set_grad_enabled(train_clip): # TODO support weighted captions # text models in sd3_models require "cpu" for input_ids input_ids_clip_l = input_ids_clip_l.to("cpu") input_ids_clip_g = input_ids_clip_g.to("cpu") - lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens( + lg_out, _, lg_pooled, l_attn_mask, g_attn_mask, _ = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [clip_l, clip_g, None], [input_ids_clip_l, input_ids_clip_g, None, l_attn_mask, g_attn_mask, None], @@ -797,9 +906,9 @@ def train(args): if t5_out is None: _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] - with torch.no_grad(): + with torch.set_grad_enabled(train_t5xxl): input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None - _, t5_out, _ = text_encoding_strategy.encode_tokens( + _, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, 
None, t5_attn_mask] ) @@ -811,21 +920,10 @@ def train(args): noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image - # for weighting schemes where we sample timesteps non-uniformly - u = compute_density_for_timestep_sampling( - weighting_scheme=args.weighting_scheme, - batch_size=bsz, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - mode_scale=args.mode_scale, + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype ) - indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) - - # Add noise according to flow matching. - sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents # debug: NaN check for all inputs if torch.any(torch.isnan(noisy_model_input)): @@ -840,6 +938,7 @@ def train(args): # call model with accelerator.autocast(): + # TODO support attention mask model_pred = mmdit(noisy_model_input, timesteps, context=context, y=lg_pooled) # Follow: Section 5 of https://arxiv.org/abs/2206.00364. @@ -848,21 +947,34 @@ def train(args): # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss target = latents - # Compute regular loss. TODO simplify this - loss = torch.mean( - (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), - 1, + # # Compute regular loss. 
TODO simplify this + # loss = torch.mean( + # (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + # 1, + # ) + # calculate loss + loss = train_util.conditional_loss( + model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None ) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + if weighting is not None: + loss = loss * weighting + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights loss = loss.mean() accelerator.backward(loss) - if not (args.fused_backward_pass or args.fused_optimizer_groups): + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): if accelerator.sync_gradients and args.max_grad_norm != 0.0: params_to_clip = [] for m in training_models: @@ -875,7 +987,7 @@ def train(args): else: # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook lr_scheduler.step() - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: for i in range(1, len(optimizers)): lr_schedulers[i].step() @@ -884,6 +996,7 @@ def train(args): progress_bar.update(1) global_step += 1 + optimizer_eval_fn() sd3_train_utils.sample_images( accelerator, args, None, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs ) @@ -900,12 +1013,13 @@ def train(args): epoch, num_train_epochs, global_step, - accelerator.unwrap_model(clip_l) if args.save_clip else None, - accelerator.unwrap_model(clip_g) if args.save_clip else None, - accelerator.unwrap_model(t5xxl) if args.save_t5xxl else None, - accelerator.unwrap_model(mmdit), + accelerator.unwrap_model(clip_l) if train_clip else None, + accelerator.unwrap_model(clip_g) if train_clip else None, + accelerator.unwrap_model(t5xxl) if train_t5xxl else None, + accelerator.unwrap_model(mmdit) if train_mmdit else None, vae, ) + optimizer_train_fn() current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず if len(accelerator.trackers) > 0: @@ -928,6 +1042,7 @@ def train(args): accelerator.wait_for_everyone() + optimizer_eval_fn() if args.save_every_n_epochs is not None: if accelerator.is_main_process: sd3_train_utils.save_sd3_model_on_epoch_end_or_stepwise( @@ -938,10 +1053,10 @@ def train(args): epoch, num_train_epochs, global_step, - accelerator.unwrap_model(clip_l) if args.save_clip else None, - accelerator.unwrap_model(clip_g) if args.save_clip else None, - accelerator.unwrap_model(t5xxl) if args.save_t5xxl else None, - accelerator.unwrap_model(mmdit), + accelerator.unwrap_model(clip_l) if train_clip else None, + accelerator.unwrap_model(clip_g) if train_clip else None, + accelerator.unwrap_model(t5xxl) if train_t5xxl else None, + accelerator.unwrap_model(mmdit) if train_mmdit else None, vae, ) @@ -958,6 +1073,7 @@ def train(args): t5xxl = accelerator.unwrap_model(t5xxl) accelerator.end_training() + optimizer_eval_fn() if args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) @@ -970,10 +1086,10 @@ def train(args): save_dtype, epoch, global_step, - clip_l if args.save_clip else None, - clip_g if args.save_clip else None, - t5xxl if args.save_t5xxl else None, - mmdit, + clip_l if train_clip else None, + clip_g if train_clip else None, + t5xxl if train_t5xxl else None, + mmdit if train_mmdit else None, vae, ) logger.info("model saved.") @@ -991,46 +1107,35 @@ def setup_parser() -> argparse.ArgumentParser: 
train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) + add_custom_train_arguments(parser) sd3_train_utils.add_sd3_training_arguments(parser) parser.add_argument( "--train_text_encoder", action="store_true", help="train text encoder (CLIP-L and G) / text encoderも学習する" ) - # parser.add_argument("--train_t5xxl", action="store_true", help="train T5-XXL / T5-XXLも学習する") + parser.add_argument("--train_t5xxl", action="store_true", help="train T5-XXL / T5-XXLも学習する") parser.add_argument( "--use_t5xxl_cache_only", action="store_true", help="cache T5-XXL outputs only / T5-XXLの出力のみキャッシュする" ) - parser.add_argument( - "--t5xxl_max_token_length", - type=int, - default=None, - help="maximum token length for T5-XXL. 256 if omitted / T5-XXLの最大トークン数。省略時は256", - ) - parser.add_argument( - "--apply_lg_attn_mask", - action="store_true", - help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する", - ) - parser.add_argument( - "--apply_t5_attn_mask", - action="store_true", - help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", - ) - # TE training is disabled temporarily - # parser.add_argument( - # "--learning_rate_te1", - # type=float, - # default=None, - # help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率", - # ) - # parser.add_argument( - # "--learning_rate_te2", - # type=float, - # default=None, - # help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率", - # ) + parser.add_argument( + "--learning_rate_te1", + type=float, + default=None, + help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率", + ) + parser.add_argument( + "--learning_rate_te2", + type=float, + default=None, + help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率", + ) + parser.add_argument( + "--learning_rate_te3", + type=float, + default=None, + help="learning rate for text encoder 3 (T5-XXL) / text encoder 3 (T5-XXL)の学習率", + ) # parser.add_argument( # "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" @@ -1047,11 +1152,16 @@ def setup_parser() -> argparse.ArgumentParser: # help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " # + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", # ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) parser.add_argument( "--fused_optimizer_groups", type=int, default=None, - help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + help="[DOES NOT WORK] number of optimizer groups for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizerグループ数", ) parser.add_argument( "--skip_latents_validity_check", @@ -1059,9 +1169,14 @@ def setup_parser() -> argparse.ArgumentParser: help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", ) parser.add_argument( - "--skip_cache_check", - action="store_true", - help="skip cache (latents and text encoder outputs) check / キャッシュ(latentsとtext encoder outputs)のチェックをスキップする", + "--blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of blocks (~640MB) 
to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", ) parser.add_argument( "--num_last_block_to_freeze", diff --git a/sd3_train_network.py b/sd3_train_network.py new file mode 100644 index 00000000..0739e094 --- /dev/null +++ b/sd3_train_network.py @@ -0,0 +1,451 @@ +import argparse +import copy +import math +import random +from typing import Any, Optional + +import torch +from accelerate import Accelerator +from library import sd3_models, strategy_sd3, utils +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3, train_util +import train_network +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class Sd3NetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + + def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup): + # super().assert_extra_args(args, train_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for SD3 + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # prepare CLIP-L/CLIP-G/T5XXL training flags + self.train_clip = not args.network_train_unet_only + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + + if args.max_token_length is not None: + logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + + # enumerate resolutions from dataset for positional embeddings + self.resolutions = train_dataset_group.get_resolutions() + + def load_target_model(self, args, weight_dtype, accelerator): + # currently offload to cpu for some models + + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future + state_dict = utils.load_safetensors( + args.pretrained_model_name_or_path, "cpu", disable_mmap=args.disable_mmap_load_safetensors, dtype=loading_dtype + ) + mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu") + self.model_type = mmdit.model_type + mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate) + + # set resolutions for positional embeddings + if args.enable_scaled_pos_embed: + latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in 
self.resolutions] # 8 is stride for latent + latent_sizes = list(set(latent_sizes)) # remove duplicates + logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}") + mmdit.enable_scaled_pos_embed(True, latent_sizes) + + if args.fp8_base: + # check dtype of model + if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {mmdit.dtype}") + elif mmdit.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 SD3 model") + + clip_l = sd3_utils.load_clip_l( + args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + clip_l.eval() + clip_g = sd3_utils.load_clip_g( + args.clip_g, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + clip_g.eval() + + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + + # loading t5xxl to cpu takes a long time, so we should load to gpu in future + t5xxl = sd3_utils.load_t5xxl( + args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + vae = sd3_utils.load_vae( + args.vae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + + return mmdit.model_type, [clip_l, clip_g, t5xxl], vae, mmdit + + def get_tokenize_strategy(self, args): + logger.info(f"t5xxl_max_token_length: {args.t5xxl_max_token_length}") + return strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy): + return [tokenize_strategy.clip_l, tokenize_strategy.clip_g, tokenize_strategy.t5xxl] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sd3.Sd3TextEncodingStrategy( + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + args.clip_l_dropout_rate, + args.clip_g_dropout_rate, + args.t5_dropout_rate, + ) + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + if args.cache_text_encoder_outputs: + if self.train_clip and not self.train_t5xxl: + return text_encoders[0:2] + [None] # only CLIP-L/CLIP-G is needed for encoding because T5XXL is cached + else: + return None # no text encoders are needed for encoding because both are cached + else: + return text_encoders # CLIP-L, CLIP-G and T5XXL are needed for encoding + + def 
get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_clip, self.train_clip, self.train_t5xxl] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_clip or self.train_t5xxl, + apply_lg_attn_mask=args.apply_lg_attn_mask, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[2].to(accelerator.device) # may be fp8 + + if text_encoders[2].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(2, text_encoders[2], text_encoders[2].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[2].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, + text_encoders, + tokens_and_masks, + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # move back to cpu + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move CLIP-G back to cpu") + text_encoders[1].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[2].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, 
dtype=weight_dtype) + text_encoders[2].to(accelerator.device) + + # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # # get size embeddings + # orig_size = batch["original_sizes_hw"] + # crop_size = batch["crop_top_lefts"] + # target_size = batch["target_sizes_hw"] + # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # # concat embeddings + # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + # return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, mmdit): + text_encoders = text_encoder # for compatibility + text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + + sd3_train_utils.sample_images( + accelerator, args, epoch, global_step, mmdit, vae, text_encoders, self.sample_prompts_te_outputs + ) + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + # shift 3.0 is the default value + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images) + + def shift_scale_latents(self, args, latents): + return sd3_models.SDVAE.process_in(latents) + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: flux_models.Flux, + network, + weight_dtype, + train_unet, + ): + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps( + args, self.noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + ) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t is not None and t.dtype.is_floating_point: + t.requires_grad_(True) + + # Predict the noise residual + lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds + text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled) + if not args.apply_lg_attn_mask: + l_attn_mask = None + g_attn_mask = None + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + # call model + with accelerator.autocast(): + # TODO support attention mask + model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled) + + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + # Preconditioning of the model outputs. 
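+        # Note (assuming get_noisy_model_input_and_timesteps builds the input as in the previous
+        # implementation, noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents):
+        # model_pred * (-sigmas) + noisy_model_input maps a prediction of (noise - latents) back to
+        # an estimate of the clean latents, so the flow-matching loss can target `latents` directly.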
+ model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = latents + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + model_pred_prior = unet( + noisy_model_input[diff_output_pr_indices], + timesteps[diff_output_pr_indices], + context=context[diff_output_pr_indices], + y=lg_pooled[diff_output_pr_indices], + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + + model_pred_prior = model_pred_prior * (-sigmas[diff_output_pr_indices]) + noisy_model_input[diff_output_pr_indices] + + # weighting for differential output preservation is not needed because it is already applied + + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + + return model_pred, target, timesteps, None, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, sd3=self.model_type) + + def update_metadata(self, metadata, args): + metadata["ss_apply_lg_attn_mask"] = args.apply_lg_attn_mask + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + if index == 0 or index == 1: # CLIP-L/CLIP-G + return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) + else: # T5XXL + text_encoder.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + if index == 0 or index == 1: # CLIP-L/CLIP-G + clip_type = "CLIP-L" if index == 0 else "CLIP-G" + logger.info(f"prepare CLIP-{clip_type} for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + else: # T5XXL + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + if 
flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"T5XXL already prepared for fp8") + else: + logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + prepare_fp8(text_encoder, weight_dtype) + + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + # drop cached text encoder outputs + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) + batch["text_encoder_outputs_list"] = text_encoder_outputs_list + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + sd3_train_utils.add_sd3_training_arguments(parser) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = Sd3NetworkTrainer() + trainer.train(args) diff --git a/train_network.py b/train_network.py index 9943b60b..b90aa420 100644 --- a/train_network.py +++ b/train_network.py @@ -129,6 +129,7 @@ class NetworkTrainer: def get_models_for_text_encoding(self, args, accelerator, text_encoders): """ Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models. + FLUX.1 and SD3 may cache some outputs of the text encoder, so return the models that will be used for encoding (not cached). 
""" return text_encoders @@ -271,6 +272,9 @@ class NetworkTrainer: def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): text_encoder.text_model.embeddings.to(dtype=weight_dtype) + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + pass + # endregion def train(self, args): @@ -591,6 +595,7 @@ class NetworkTrainer: # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory + logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}") unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above unet.requires_grad_(False) @@ -1028,9 +1033,9 @@ class NetworkTrainer: # callback for step start if hasattr(accelerator.unwrap_model(network), "on_step_start"): - on_step_start = accelerator.unwrap_model(network).on_step_start + on_step_start_for_network = accelerator.unwrap_model(network).on_step_start else: - on_step_start = lambda *args, **kwargs: None + on_step_start_for_network = lambda *args, **kwargs: None # function for saving/removing def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): @@ -1111,7 +1116,10 @@ class NetworkTrainer: continue with accelerator.accumulate(training_model): - on_step_start(text_encoder, unet) + on_step_start_for_network(text_encoder, unet) + + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) @@ -1143,7 +1151,9 @@ class NetworkTrainer: text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs + if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: + # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: