From 6f3f701d3d6a4c1386ae770a3bbe997a14bbdef6 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 18 Jan 2024 18:23:21 +0200 Subject: [PATCH 01/13] Deduplicate ipex initialization code --- XTI_hijack.py | 10 +++------- fine_tune.py | 9 ++------- gen_img_diffusers.py | 9 ++------- library/ipex_interop.py | 24 ++++++++++++++++++++++++ library/model_util.py | 10 ++-------- sdxl_gen_img.py | 9 ++------- sdxl_minimal_inference.py | 12 +++++------- sdxl_train.py | 9 ++------- sdxl_train_control_net_lllite.py | 12 +++++------- sdxl_train_control_net_lllite_old.py | 12 +++++------- sdxl_train_network.py | 9 ++------- sdxl_train_textual_inversion.py | 10 +++------- train_controlnet.py | 9 ++------- train_db.py | 9 ++------- train_network.py | 9 ++------- train_textual_inversion.py | 9 ++------- train_textual_inversion_XTI.py | 12 +++++------- 17 files changed, 70 insertions(+), 113 deletions(-) create mode 100644 library/ipex_interop.py diff --git a/XTI_hijack.py b/XTI_hijack.py index ec084945..1dbc263a 100644 --- a/XTI_hijack.py +++ b/XTI_hijack.py @@ -1,11 +1,7 @@ import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.ipex_interop import init_ipex + +init_ipex() from typing import Union, List, Optional, Dict, Any, Tuple from diffusers.models.unet_2d_condition import UNet2DConditionOutput diff --git a/fine_tune.py b/fine_tune.py index be61b3d1..982dc8ae 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -11,15 +11,10 @@ import toml from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index be43847a..a207ad5a 100644 --- 
a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -66,15 +66,10 @@ import diffusers import numpy as np import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass import torchvision from diffusers import ( AutoencoderKL, diff --git a/library/ipex_interop.py b/library/ipex_interop.py new file mode 100644 index 00000000..6fe320c5 --- /dev/null +++ b/library/ipex_interop.py @@ -0,0 +1,24 @@ +import torch + + +def init_ipex(): + """ + Try to import `intel_extension_for_pytorch`, and apply + the hijacks using `library.ipex.ipex_init`. + + If IPEX is not installed, this function does nothing. + """ + try: + import intel_extension_for_pytorch as ipex # noqa + except ImportError: + return + + try: + from library.ipex import ipex_init + + if torch.xpu.is_available(): + is_initialized, error_message = ipex_init() + if not is_initialized: + print("failed to initialize ipex:", error_message) + except Exception as e: + print("failed to initialize ipex:", e) diff --git a/library/model_util.py b/library/model_util.py index 1f40ce32..4361b499 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -5,15 +5,9 @@ import math import os import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - - ipex_init() -except Exception: - pass +init_ipex() import diffusers from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index ab539984..0db9e340 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -18,15 +18,10 @@ import diffusers import numpy as np import torch -try: - import intel_extension_for_pytorch as ipex +from 
library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass import torchvision from diffusers import ( AutoencoderKL, diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 45b9edd6..15a70678 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -9,13 +9,11 @@ import random from einops import repeat import numpy as np import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from tqdm import tqdm from transformers import CLIPTokenizer from diffusers import EulerDiscreteScheduler diff --git a/sdxl_train.py b/sdxl_train.py index b4ce2770..a3f6f3a1 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -11,15 +11,10 @@ import toml from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import sdxl_model_util diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 4436dd3c..7a88feb8 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -14,13 +14,11 @@ import toml from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed import accelerate diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 
6ae5377b..b94bf5c1 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -11,13 +11,11 @@ import toml from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel diff --git a/sdxl_train_network.py b/sdxl_train_network.py index d810ce7d..5d363280 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,15 +1,10 @@ import argparse import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from library import sdxl_model_util, sdxl_train_util, train_util import train_network diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index f8a1d7bc..df393713 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -3,13 +3,9 @@ import os import regex import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.ipex_interop import init_ipex + +init_ipex() import open_clip from library import sdxl_model_util, sdxl_train_util, train_util diff --git a/train_controlnet.py b/train_controlnet.py index cc0eaab7..7b0b2bbf 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -12,15 +12,10 @@ import toml from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except 
Exception: - pass from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel diff --git a/train_db.py b/train_db.py index 14d9dff1..888cad25 100644 --- a/train_db.py +++ b/train_db.py @@ -12,15 +12,10 @@ import toml from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler diff --git a/train_network.py b/train_network.py index c2b7fbde..e1ff20c3 100644 --- a/train_network.py +++ b/train_network.py @@ -14,15 +14,10 @@ from tqdm import tqdm import torch from torch.nn.parallel import DistributedDataParallel as DDP -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import model_util diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 0e3912b1..ccf7596d 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -8,15 +8,10 @@ import toml from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from transformers import CLIPTokenizer diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 71b43549..7046a480 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -8,13 +8,11 @@ from multiprocessing import Value from 
tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from accelerate.utils import set_seed import diffusers from diffusers import DDPMScheduler From 9cfa68c92fc45fe6e78745d1ec5544b3d7b5cf5f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 20 Jan 2024 08:46:53 +0800 Subject: [PATCH 02/13] [Experimental Feature] FP8 weight dtype for base model when running train_network (or sdxl_train_network) (#1057) * Add fp8 support * remove some debug prints * Better implementation for te * Fix some misunderstanding * as same as unet, add explicit convert * better impl for convert TE to fp8 * fp8 for not only unet * Better cache TE and TE lr * match arg name * Fix with list * Add timeout settings * Fix arg style * Add custom seperator * Fix typo * Fix typo again * Fix dtype error * Fix gradient problem * Fix req grad * fix merge * Fix merge * Resolve merge * arrangement and document * Resolve merge error * Add assert for mixed precision --- library/train_util.py | 3 +++ train_network.py | 36 ++++++++++++++++++++++++++++++------ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index ff161fea..21e7638d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2904,6 +2904,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する" ) # TODO move to SDXL training, because it is not supported by SD1/2 + parser.add_argument( + "--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う" + ) parser.add_argument( "--ddp_timeout", type=int, diff --git a/train_network.py b/train_network.py index 
c2b7fbde..5f28a5e0 100644 --- a/train_network.py +++ b/train_network.py @@ -390,16 +390,36 @@ class NetworkTrainer: accelerator.print("enable full bf16 training.") network.to(weight_dtype) + unet_weight_dtype = te_weight_dtype = weight_dtype + # Experimental Feature: Put base model into fp8 to save vram + if args.fp8_base: + assert ( + torch.__version__ >= '2.1.0' + ), "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" + assert ( + args.mixed_precision != 'no' + ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" + accelerator.print("enable fp8 training.") + unet_weight_dtype = torch.float8_e4m3fn + te_weight_dtype = torch.float8_e4m3fn + unet.requires_grad_(False) - unet.to(dtype=weight_dtype) + unet.to(dtype=unet_weight_dtype) for t_enc in text_encoders: t_enc.requires_grad_(False) + t_enc.to(dtype=te_weight_dtype) + # nn.Embedding not support FP8 + t_enc.text_model.embeddings.to(dtype=( + weight_dtype + if te_weight_dtype == torch.float8_e4m3fn + else te_weight_dtype + )) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if train_unet: unet = accelerator.prepare(unet) else: - unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator + unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator if train_text_encoder: if len(text_encoders) > 1: text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders] @@ -421,9 +441,6 @@ class NetworkTrainer: if train_text_encoder: t_enc.text_model.embeddings.requires_grad_(True) - # set top parameter requires_grad = True for gradient checkpointing works - if not train_text_encoder: # train U-Net only - unet.parameters().__next__().requires_grad_(True) else: unet.eval() for t_enc in text_encoders: @@ -778,10 +795,17 @@ class NetworkTrainer: args, noise_scheduler, latents ) + # ensure the hidden state will require grad 
+ if args.gradient_checkpointing: + for x in noisy_latents: + x.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) + # Predict the noise residual with accelerator.autocast(): noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + args, accelerator, unet, noisy_latents.requires_grad_(train_unet), timesteps, text_encoder_conds, batch, weight_dtype ) if args.v_parameterization: From a7ef6422b658660dd4c4685397f9b56e24996eb2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 20 Jan 2024 10:00:30 +0900 Subject: [PATCH 03/13] fix to work with torch 2.0 --- train_network.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/train_network.py b/train_network.py index 5f28a5e0..b1291ed1 100644 --- a/train_network.py +++ b/train_network.py @@ -393,11 +393,9 @@ class NetworkTrainer: unet_weight_dtype = te_weight_dtype = weight_dtype # Experimental Feature: Put base model into fp8 to save vram if args.fp8_base: + assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" assert ( - torch.__version__ >= '2.1.0' - ), "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" - assert ( - args.mixed_precision != 'no' + args.mixed_precision != "no" ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" accelerator.print("enable fp8 training.") unet_weight_dtype = torch.float8_e4m3fn @@ -407,13 +405,12 @@ class NetworkTrainer: unet.to(dtype=unet_weight_dtype) for t_enc in text_encoders: t_enc.requires_grad_(False) - t_enc.to(dtype=te_weight_dtype) - # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to(dtype=( - weight_dtype - if te_weight_dtype == torch.float8_e4m3fn - else te_weight_dtype - )) + + # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 + if t_enc.device.type != "cpu": + 
t_enc.to(dtype=te_weight_dtype) + # nn.Embedding not support FP8 + t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if train_unet: @@ -805,7 +802,14 @@ class NetworkTrainer: # Predict the noise residual with accelerator.autocast(): noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents.requires_grad_(train_unet), timesteps, text_encoder_conds, batch, weight_dtype + args, + accelerator, + unet, + noisy_latents.requires_grad_(train_unet), + timesteps, + text_encoder_conds, + batch, + weight_dtype, ) if args.v_parameterization: From 1f77bb6e73573d0bde67b94c76b549590833ece9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 20 Jan 2024 10:57:42 +0900 Subject: [PATCH 04/13] fix to work sample generation in fp8 ref #1057 --- library/sdxl_lpw_stable_diffusion.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index e03ee405..0562e88a 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -923,7 +923,11 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: if up1 is not None: uncond_pool = up1 - dtype = self.unet.dtype + unet_dtype = self.unet.dtype + dtype = unet_dtype + if dtype.itemsize == 1: # fp8 + dtype = torch.float16 + self.unet.to(dtype) # 4. 
Preprocess image and mask if isinstance(image, PIL.Image.Image): @@ -1028,6 +1032,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: if is_cancelled_callback is not None and is_cancelled_callback(): return None + self.unet.to(unet_dtype) return latents def latents_to_image(self, latents): From 2a0f45aea951d8d2649917938f216705e4fc632a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 20 Jan 2024 11:08:20 +0900 Subject: [PATCH 05/13] update readme --- README.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/README.md b/README.md index a505c0b3..663f52e8 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,20 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History +### 作業中の内容 / Work in progress + +- [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf! + - Please specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`. + - PyTorch 2.1 or later is required. + - If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version. + - The sample image generation during training consumes a lot of memory. It is recommended to turn it off. 
+ +- (実験的) LoRA等の学習スクリプトで、ベースモデル(U-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。 + - `train_network.py` または `sdxl_train_network.py` で `--fp8_base` を指定してください。 + - PyTorch 2.1 以降が必要です。 + - PyTorch 2.1 で xformers を使用する場合は、[xformers のリポジトリ](https://github.com/facebookresearch/xformers) を参照し、CUDA バージョンに応じて適切なバージョンをインストールしてください。 + - 学習中のサンプル画像生成はメモリを大量に消費するため、オフにすることをお勧めします。 + ### Jan 17, 2024 / 2024/1/17: v0.8.1 - Fixed a bug that the VRAM usage without Text Encoder training is larger than before in training scripts for LoRA etc (`train_network.py`, `sdxl_train_network.py`). From 5a1ebc4c7c21600ac773c22b7f1b7d20d383e318 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 20 Jan 2024 13:10:45 +0900 Subject: [PATCH 06/13] format by black --- library/config_util.py | 949 ++++++++++++++++++++++------------------- 1 file changed, 506 insertions(+), 443 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 47868f3b..716cecff 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -1,462 +1,496 @@ import argparse from dataclasses import ( - asdict, - dataclass, + asdict, + dataclass, ) import functools import random from textwrap import dedent, indent import json from pathlib import Path + # from toolz import curry from typing import ( - List, - Optional, - Sequence, - Tuple, - Union, + List, + Optional, + Sequence, + Tuple, + Union, ) import toml import voluptuous from voluptuous import ( - Any, - ExactSequence, - MultipleInvalid, - Object, - Required, - Schema, + Any, + ExactSequence, + MultipleInvalid, + Object, + Required, + Schema, ) from transformers import CLIPTokenizer from . 
import train_util from .train_util import ( - DreamBoothSubset, - FineTuningSubset, - ControlNetSubset, - DreamBoothDataset, - FineTuningDataset, - ControlNetDataset, - DatasetGroup, + DreamBoothSubset, + FineTuningSubset, + ControlNetSubset, + DreamBoothDataset, + FineTuningDataset, + ControlNetDataset, + DatasetGroup, ) def add_config_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") + parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") + # TODO: inherit Params class in Subset, Dataset + @dataclass class BaseSubsetParams: - image_dir: Optional[str] = None - num_repeats: int = 1 - shuffle_caption: bool = False - caption_separator: str = ',', - keep_tokens: int = 0 - keep_tokens_separator: str = None, - color_aug: bool = False - flip_aug: bool = False - face_crop_aug_range: Optional[Tuple[float, float]] = None - random_crop: bool = False - caption_prefix: Optional[str] = None - caption_suffix: Optional[str] = None - caption_dropout_rate: float = 0.0 - caption_dropout_every_n_epochs: int = 0 - caption_tag_dropout_rate: float = 0.0 - token_warmup_min: int = 1 - token_warmup_step: float = 0 + image_dir: Optional[str] = None + num_repeats: int = 1 + shuffle_caption: bool = False + caption_separator: str = (",",) + keep_tokens: int = 0 + keep_tokens_separator: str = (None,) + color_aug: bool = False + flip_aug: bool = False + face_crop_aug_range: Optional[Tuple[float, float]] = None + random_crop: bool = False + caption_prefix: Optional[str] = None + caption_suffix: Optional[str] = None + caption_dropout_rate: float = 0.0 + caption_dropout_every_n_epochs: int = 0 + caption_tag_dropout_rate: float = 0.0 + token_warmup_min: int = 1 + token_warmup_step: float = 0 + @dataclass class DreamBoothSubsetParams(BaseSubsetParams): - is_reg: bool = False - class_tokens: Optional[str] = None - 
caption_extension: str = ".caption" + is_reg: bool = False + class_tokens: Optional[str] = None + caption_extension: str = ".caption" + @dataclass class FineTuningSubsetParams(BaseSubsetParams): - metadata_file: Optional[str] = None + metadata_file: Optional[str] = None + @dataclass class ControlNetSubsetParams(BaseSubsetParams): - conditioning_data_dir: str = None - caption_extension: str = ".caption" + conditioning_data_dir: str = None + caption_extension: str = ".caption" + @dataclass class BaseDatasetParams: - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None - max_token_length: int = None - resolution: Optional[Tuple[int, int]] = None - debug_dataset: bool = False + tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None + max_token_length: int = None + resolution: Optional[Tuple[int, int]] = None + debug_dataset: bool = False + @dataclass class DreamBoothDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False - prior_loss_weight: float = 1.0 + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + prior_loss_weight: float = 1.0 + @dataclass class FineTuningDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + @dataclass class ControlNetDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False + 
batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + @dataclass class SubsetBlueprint: - params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] + params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] + @dataclass class DatasetBlueprint: - is_dreambooth: bool - is_controlnet: bool - params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] - subsets: Sequence[SubsetBlueprint] + is_dreambooth: bool + is_controlnet: bool + params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] + subsets: Sequence[SubsetBlueprint] + @dataclass class DatasetGroupBlueprint: - datasets: Sequence[DatasetBlueprint] + datasets: Sequence[DatasetBlueprint] + + @dataclass class Blueprint: - dataset_group: DatasetGroupBlueprint + dataset_group: DatasetGroupBlueprint class ConfigSanitizer: - # @curry - @staticmethod - def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple: - Schema(ExactSequence([klass, klass]))(value) - return tuple(value) + # @curry + @staticmethod + def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple: + Schema(ExactSequence([klass, klass]))(value) + return tuple(value) - # @curry - @staticmethod - def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple: - Schema(Any(klass, ExactSequence([klass, klass])))(value) - try: - Schema(klass)(value) - return (value, value) - except: - return ConfigSanitizer.__validate_and_convert_twodim(klass, value) + # @curry + @staticmethod + def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple: + Schema(Any(klass, ExactSequence([klass, klass])))(value) + try: + Schema(klass)(value) + return (value, value) + except: + return ConfigSanitizer.__validate_and_convert_twodim(klass, value) - # subset schema - SUBSET_ASCENDABLE_SCHEMA = { - "color_aug": bool, - "face_crop_aug_range": 
functools.partial(__validate_and_convert_twodim.__func__, float), - "flip_aug": bool, - "num_repeats": int, - "random_crop": bool, - "shuffle_caption": bool, - "keep_tokens": int, - "keep_tokens_separator": str, - "token_warmup_min": int, - "token_warmup_step": Any(float,int), - "caption_prefix": str, - "caption_suffix": str, - } - # DO means DropOut - DO_SUBSET_ASCENDABLE_SCHEMA = { - "caption_dropout_every_n_epochs": int, - "caption_dropout_rate": Any(float, int), - "caption_tag_dropout_rate": Any(float, int), - } - # DB means DreamBooth - DB_SUBSET_ASCENDABLE_SCHEMA = { - "caption_extension": str, - "class_tokens": str, - } - DB_SUBSET_DISTINCT_SCHEMA = { - Required("image_dir"): str, - "is_reg": bool, - } - # FT means FineTuning - FT_SUBSET_DISTINCT_SCHEMA = { - Required("metadata_file"): str, - "image_dir": str, - } - CN_SUBSET_ASCENDABLE_SCHEMA = { - "caption_extension": str, - } - CN_SUBSET_DISTINCT_SCHEMA = { - Required("image_dir"): str, - Required("conditioning_data_dir"): str, - } + # subset schema + SUBSET_ASCENDABLE_SCHEMA = { + "color_aug": bool, + "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float), + "flip_aug": bool, + "num_repeats": int, + "random_crop": bool, + "shuffle_caption": bool, + "keep_tokens": int, + "keep_tokens_separator": str, + "token_warmup_min": int, + "token_warmup_step": Any(float, int), + "caption_prefix": str, + "caption_suffix": str, + } + # DO means DropOut + DO_SUBSET_ASCENDABLE_SCHEMA = { + "caption_dropout_every_n_epochs": int, + "caption_dropout_rate": Any(float, int), + "caption_tag_dropout_rate": Any(float, int), + } + # DB means DreamBooth + DB_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + "class_tokens": str, + } + DB_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + "is_reg": bool, + } + # FT means FineTuning + FT_SUBSET_DISTINCT_SCHEMA = { + Required("metadata_file"): str, + "image_dir": str, + } + CN_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + 
} + CN_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + Required("conditioning_data_dir"): str, + } - # datasets schema - DATASET_ASCENDABLE_SCHEMA = { - "batch_size": int, - "bucket_no_upscale": bool, - "bucket_reso_steps": int, - "enable_bucket": bool, - "max_bucket_reso": int, - "min_bucket_reso": int, - "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), - } + # datasets schema + DATASET_ASCENDABLE_SCHEMA = { + "batch_size": int, + "bucket_no_upscale": bool, + "bucket_reso_steps": int, + "enable_bucket": bool, + "max_bucket_reso": int, + "min_bucket_reso": int, + "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), + } - # options handled by argparse but not handled by user config - ARGPARSE_SPECIFIC_SCHEMA = { - "debug_dataset": bool, - "max_token_length": Any(None, int), - "prior_loss_weight": Any(float, int), - } - # for handling default None value of argparse - ARGPARSE_NULLABLE_OPTNAMES = [ - "face_crop_aug_range", - "resolution", - ] - # prepare map because option name may differ among argparse and user config - ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { - "train_batch_size": "batch_size", - "dataset_repeats": "num_repeats", - } + # options handled by argparse but not handled by user config + ARGPARSE_SPECIFIC_SCHEMA = { + "debug_dataset": bool, + "max_token_length": Any(None, int), + "prior_loss_weight": Any(float, int), + } + # for handling default None value of argparse + ARGPARSE_NULLABLE_OPTNAMES = [ + "face_crop_aug_range", + "resolution", + ] + # prepare map because option name may differ among argparse and user config + ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { + "train_batch_size": "batch_size", + "dataset_repeats": "num_repeats", + } - def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: - assert support_dreambooth or support_finetuning or support_controlnet, "Neither DreamBooth mode nor fine 
tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" + def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: + assert ( + support_dreambooth or support_finetuning or support_controlnet + ), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" - self.db_subset_schema = self.__merge_dict( - self.SUBSET_ASCENDABLE_SCHEMA, - self.DB_SUBSET_DISTINCT_SCHEMA, - self.DB_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) + self.db_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_DISTINCT_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) - self.ft_subset_schema = self.__merge_dict( - self.SUBSET_ASCENDABLE_SCHEMA, - self.FT_SUBSET_DISTINCT_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) + self.ft_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.FT_SUBSET_DISTINCT_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) - self.cn_subset_schema = self.__merge_dict( - self.SUBSET_ASCENDABLE_SCHEMA, - self.CN_SUBSET_DISTINCT_SCHEMA, - self.CN_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) + self.cn_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_DISTINCT_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) - self.db_dataset_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.DB_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - {"subsets": [self.db_subset_schema]}, - ) + self.db_dataset_schema = 
self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.db_subset_schema]}, + ) - self.ft_dataset_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - {"subsets": [self.ft_subset_schema]}, - ) + self.ft_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.ft_subset_schema]}, + ) - self.cn_dataset_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.CN_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - {"subsets": [self.cn_subset_schema]}, - ) + self.cn_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.cn_subset_schema]}, + ) - if support_dreambooth and support_finetuning: - def validate_flex_dataset(dataset_config: dict): - subsets_config = dataset_config.get("subsets", []) + if support_dreambooth and support_finetuning: - if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]): - return Schema(self.cn_dataset_schema)(dataset_config) - # check dataset meets FT style - # NOTE: all FT subsets should have "metadata_file" - elif all(["metadata_file" in subset for subset in subsets_config]): - return Schema(self.ft_dataset_schema)(dataset_config) - # check dataset meets DB style - # NOTE: all DB subsets should have no "metadata_file" - elif all(["metadata_file" not in subset for subset in subsets_config]): - return Schema(self.db_dataset_schema)(dataset_config) - else: - raise 
voluptuous.Invalid("DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。") + def validate_flex_dataset(dataset_config: dict): + subsets_config = dataset_config.get("subsets", []) - self.dataset_schema = validate_flex_dataset - elif support_dreambooth: - self.dataset_schema = self.db_dataset_schema - elif support_finetuning: - self.dataset_schema = self.ft_dataset_schema - elif support_controlnet: - self.dataset_schema = self.cn_dataset_schema + if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]): + return Schema(self.cn_dataset_schema)(dataset_config) + # check dataset meets FT style + # NOTE: all FT subsets should have "metadata_file" + elif all(["metadata_file" in subset for subset in subsets_config]): + return Schema(self.ft_dataset_schema)(dataset_config) + # check dataset meets DB style + # NOTE: all DB subsets should have no "metadata_file" + elif all(["metadata_file" not in subset for subset in subsets_config]): + return Schema(self.db_dataset_schema)(dataset_config) + else: + raise voluptuous.Invalid( + "DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. 
/ DreamBoothのサブセットとfine tuningのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。" + ) - self.general_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, - self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {}, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) + self.dataset_schema = validate_flex_dataset + elif support_dreambooth: + self.dataset_schema = self.db_dataset_schema + elif support_finetuning: + self.dataset_schema = self.ft_dataset_schema + elif support_controlnet: + self.dataset_schema = self.cn_dataset_schema - self.user_config_validator = Schema({ - "general": self.general_schema, - "datasets": [self.dataset_schema], - }) + self.general_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, + self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {}, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) - self.argparse_schema = self.__merge_dict( - self.general_schema, - self.ARGPARSE_SPECIFIC_SCHEMA, - {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES}, - {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()}, - ) + self.user_config_validator = Schema( + { + "general": self.general_schema, + "datasets": [self.dataset_schema], + } + ) - self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA) + self.argparse_schema = self.__merge_dict( + self.general_schema, + self.ARGPARSE_SPECIFIC_SCHEMA, + {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES}, + {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()}, + ) - def sanitize_user_config(self, user_config: dict) -> 
dict: - try: - return self.user_config_validator(user_config) - except MultipleInvalid: - # TODO: エラー発生時のメッセージをわかりやすくする - print("Invalid user config / ユーザ設定の形式が正しくないようです") - raise + self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA) - # NOTE: In nature, argument parser result is not needed to be sanitize - # However this will help us to detect program bug - def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace: - try: - return self.argparse_config_validator(argparse_namespace) - except MultipleInvalid: - # XXX: this should be a bug - print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") - raise + def sanitize_user_config(self, user_config: dict) -> dict: + try: + return self.user_config_validator(user_config) + except MultipleInvalid: + # TODO: エラー発生時のメッセージをわかりやすくする + print("Invalid user config / ユーザ設定の形式が正しくないようです") + raise - # NOTE: value would be overwritten by latter dict if there is already the same key - @staticmethod - def __merge_dict(*dict_list: dict) -> dict: - merged = {} - for schema in dict_list: - # merged |= schema - for k, v in schema.items(): - merged[k] = v - return merged + # NOTE: In nature, argument parser result is not needed to be sanitize + # However this will help us to detect program bug + def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace: + try: + return self.argparse_config_validator(argparse_namespace) + except MultipleInvalid: + # XXX: this should be a bug + print("Invalid cmdline parsed arguments. This should be a bug. 
/ コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") + raise + + # NOTE: value would be overwritten by latter dict if there is already the same key + @staticmethod + def __merge_dict(*dict_list: dict) -> dict: + merged = {} + for schema in dict_list: + # merged |= schema + for k, v in schema.items(): + merged[k] = v + return merged class BlueprintGenerator: - BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = { - } + BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {} - def __init__(self, sanitizer: ConfigSanitizer): - self.sanitizer = sanitizer + def __init__(self, sanitizer: ConfigSanitizer): + self.sanitizer = sanitizer - # runtime_params is for parameters which is only configurable on runtime, such as tokenizer - def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint: - sanitized_user_config = self.sanitizer.sanitize_user_config(user_config) - sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace) + # runtime_params is for parameters which is only configurable on runtime, such as tokenizer + def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint: + sanitized_user_config = self.sanitizer.sanitize_user_config(user_config) + sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace) - # convert argparse namespace to dict like config - # NOTE: it is ok to have extra entries in dict - optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME - argparse_config = {optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()} + # convert argparse namespace to dict like config + # NOTE: it is ok to have extra entries in dict + optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME + argparse_config = { + optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items() + } - general_config = 
sanitized_user_config.get("general", {}) + general_config = sanitized_user_config.get("general", {}) - dataset_blueprints = [] - for dataset_config in sanitized_user_config.get("datasets", []): - # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets - subsets = dataset_config.get("subsets", []) - is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) - is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets]) - if is_controlnet: - subset_params_klass = ControlNetSubsetParams - dataset_params_klass = ControlNetDatasetParams - elif is_dreambooth: - subset_params_klass = DreamBoothSubsetParams - dataset_params_klass = DreamBoothDatasetParams - else: - subset_params_klass = FineTuningSubsetParams - dataset_params_klass = FineTuningDatasetParams + dataset_blueprints = [] + for dataset_config in sanitized_user_config.get("datasets", []): + # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets + subsets = dataset_config.get("subsets", []) + is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) + is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets]) + if is_controlnet: + subset_params_klass = ControlNetSubsetParams + dataset_params_klass = ControlNetDatasetParams + elif is_dreambooth: + subset_params_klass = DreamBoothSubsetParams + dataset_params_klass = DreamBoothDatasetParams + else: + subset_params_klass = FineTuningSubsetParams + dataset_params_klass = FineTuningDatasetParams - subset_blueprints = [] - for subset_config in subsets: - params = self.generate_params_by_fallbacks(subset_params_klass, - [subset_config, dataset_config, general_config, argparse_config, runtime_params]) - subset_blueprints.append(SubsetBlueprint(params)) + subset_blueprints = [] + for subset_config in subsets: + params = self.generate_params_by_fallbacks( + subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, 
runtime_params] + ) + subset_blueprints.append(SubsetBlueprint(params)) - params = self.generate_params_by_fallbacks(dataset_params_klass, - [dataset_config, general_config, argparse_config, runtime_params]) - dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints)) + params = self.generate_params_by_fallbacks( + dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params] + ) + dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints)) - dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) + dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) - return Blueprint(dataset_group_blueprint) + return Blueprint(dataset_group_blueprint) - @staticmethod - def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]): - name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME - search_value = BlueprintGenerator.search_value - default_params = asdict(param_klass()) - param_names = default_params.keys() + @staticmethod + def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]): + name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME + search_value = BlueprintGenerator.search_value + default_params = asdict(param_klass()) + param_names = default_params.keys() - params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names} + params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names} - return param_klass(**params) + return param_klass(**params) - @staticmethod - def search_value(key: str, fallbacks: Sequence[dict], default_value = None): - for cand in fallbacks: - value = cand.get(key) - if value is not None: - return value + @staticmethod + def search_value(key: str, fallbacks: Sequence[dict], default_value=None): + for cand in fallbacks: + value = cand.get(key) + if value 
is not None: + return value - return default_value + return default_value def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): - datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] - for dataset_blueprint in dataset_group_blueprint.datasets: - if dataset_blueprint.is_controlnet: - subset_klass = ControlNetSubset - dataset_klass = ControlNetDataset - elif dataset_blueprint.is_dreambooth: - subset_klass = DreamBoothSubset - dataset_klass = DreamBoothDataset - else: - subset_klass = FineTuningSubset - dataset_klass = FineTuningDataset + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) - datasets.append(dataset) + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + datasets.append(dataset) - # print info - info = "" - for i, dataset in enumerate(datasets): - is_dreambooth = isinstance(dataset, DreamBoothDataset) - is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent(f"""\ + # print info + info = "" + for i, dataset in enumerate(datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent( + f"""\ [Dataset {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} 
enable_bucket: {dataset.enable_bucket} - """) + """ + ) - if dataset.enable_bucket: - info += indent(dedent(f"""\ + if dataset.enable_bucket: + info += indent( + dedent( + f"""\ min_bucket_reso: {dataset.min_bucket_reso} max_bucket_reso: {dataset.max_bucket_reso} bucket_reso_steps: {dataset.bucket_reso_steps} bucket_no_upscale: {dataset.bucket_no_upscale} - \n"""), " ") - else: - info += "\n" + \n""" + ), + " ", + ) + else: + info += "\n" - for j, subset in enumerate(dataset.subsets): - info += indent(dedent(f"""\ + for j, subset in enumerate(dataset.subsets): + info += indent( + dedent( + f"""\ [Subset {j} of Dataset {i}] image_dir: "{subset.image_dir}" image_count: {subset.img_count} @@ -475,147 +509,176 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu random_crop: {subset.random_crop} token_warmup_min: {subset.token_warmup_min}, token_warmup_step: {subset.token_warmup_step}, - """), " ") + """ + ), + " ", + ) - if is_dreambooth: - info += indent(dedent(f"""\ + if is_dreambooth: + info += indent( + dedent( + f"""\ is_reg: {subset.is_reg} class_tokens: {subset.class_tokens} caption_extension: {subset.caption_extension} - \n"""), " ") - elif not is_controlnet: - info += indent(dedent(f"""\ + \n""" + ), + " ", + ) + elif not is_controlnet: + info += indent( + dedent( + f"""\ metadata_file: {subset.metadata_file} - \n"""), " ") + \n""" + ), + " ", + ) - print(info) + print(info) - # make buckets first because it determines the length of dataset - # and set the same seed for all datasets - seed = random.randint(0, 2**31) # actual seed is seed + epoch_no - for i, dataset in enumerate(datasets): - print(f"[Dataset {i}]") - dataset.make_buckets() - dataset.set_seed(seed) + # make buckets first because it determines the length of dataset + # and set the same seed for all datasets + seed = random.randint(0, 2**31) # actual seed is seed + epoch_no + for i, dataset in enumerate(datasets): + print(f"[Dataset {i}]") + dataset.make_buckets() + 
dataset.set_seed(seed) - return DatasetGroup(datasets) + return DatasetGroup(datasets) def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): - def extract_dreambooth_params(name: str) -> Tuple[int, str]: - tokens = name.split('_') - try: - n_repeats = int(tokens[0]) - except ValueError as e: - print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") - return 0, "" - caption_by_folder = '_'.join(tokens[1:]) - return n_repeats, caption_by_folder + def extract_dreambooth_params(name: str) -> Tuple[int, str]: + tokens = name.split("_") + try: + n_repeats = int(tokens[0]) + except ValueError as e: + print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") + return 0, "" + caption_by_folder = "_".join(tokens[1:]) + return n_repeats, caption_by_folder - def generate(base_dir: Optional[str], is_reg: bool): - if base_dir is None: - return [] + def generate(base_dir: Optional[str], is_reg: bool): + if base_dir is None: + return [] - base_dir: Path = Path(base_dir) - if not base_dir.is_dir(): - return [] + base_dir: Path = Path(base_dir) + if not base_dir.is_dir(): + return [] + + subsets_config = [] + for subdir in base_dir.iterdir(): + if not subdir.is_dir(): + continue + + num_repeats, class_tokens = extract_dreambooth_params(subdir.name) + if num_repeats < 1: + continue + + subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens} + subsets_config.append(subset_config) + + return subsets_config subsets_config = [] - for subdir in base_dir.iterdir(): - if not subdir.is_dir(): - continue - - num_repeats, class_tokens = extract_dreambooth_params(subdir.name) - if num_repeats < 1: - continue - - subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens} - subsets_config.append(subset_config) + subsets_config += generate(train_data_dir, False) + 
subsets_config += generate(reg_data_dir, True) return subsets_config - subsets_config = [] - subsets_config += generate(train_data_dir, False) - subsets_config += generate(reg_data_dir, True) - return subsets_config +def generate_controlnet_subsets_config_by_subdirs( + train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt" +): + def generate(base_dir: Optional[str]): + if base_dir is None: + return [] + base_dir: Path = Path(base_dir) + if not base_dir.is_dir(): + return [] -def generate_controlnet_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"): - def generate(base_dir: Optional[str]): - if base_dir is None: - return [] + subsets_config = [] + subset_config = { + "image_dir": train_data_dir, + "conditioning_data_dir": conditioning_data_dir, + "caption_extension": caption_extension, + "num_repeats": 1, + } + subsets_config.append(subset_config) - base_dir: Path = Path(base_dir) - if not base_dir.is_dir(): - return [] + return subsets_config subsets_config = [] - subset_config = {"image_dir": train_data_dir, "conditioning_data_dir": conditioning_data_dir, "caption_extension": caption_extension, "num_repeats": 1} - subsets_config.append(subset_config) + subsets_config += generate(train_data_dir) return subsets_config - subsets_config = [] - subsets_config += generate(train_data_dir) - - return subsets_config - def load_user_config(file: str) -> dict: - file: Path = Path(file) - if not file.is_file(): - raise ValueError(f"file not found / ファイルが見つかりません: {file}") + file: Path = Path(file) + if not file.is_file(): + raise ValueError(f"file not found / ファイルが見つかりません: {file}") - if file.name.lower().endswith('.json'): - try: - with open(file, 'r') as f: - config = json.load(f) - except Exception: - print(f"Error on parsing JSON config file. Please check the format. 
/ JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") - raise - elif file.name.lower().endswith('.toml'): - try: - config = toml.load(file) - except Exception: - print(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") - raise - else: - raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") + if file.name.lower().endswith(".json"): + try: + with open(file, "r") as f: + config = json.load(f) + except Exception: + print( + f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) + raise + elif file.name.lower().endswith(".toml"): + try: + config = toml.load(file) + except Exception: + print( + f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) + raise + else: + raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") + + return config - return config # for config test if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--support_dreambooth", action="store_true") - parser.add_argument("--support_finetuning", action="store_true") - parser.add_argument("--support_controlnet", action="store_true") - parser.add_argument("--support_dropout", action="store_true") - parser.add_argument("dataset_config") - config_args, remain = parser.parse_known_args() + parser = argparse.ArgumentParser() + parser.add_argument("--support_dreambooth", action="store_true") + parser.add_argument("--support_finetuning", action="store_true") + parser.add_argument("--support_controlnet", action="store_true") + parser.add_argument("--support_dropout", action="store_true") + parser.add_argument("dataset_config") + config_args, remain = parser.parse_known_args() - parser = argparse.ArgumentParser() - train_util.add_dataset_arguments(parser, config_args.support_dreambooth, config_args.support_finetuning, 
config_args.support_dropout) - train_util.add_training_arguments(parser, config_args.support_dreambooth) - argparse_namespace = parser.parse_args(remain) - train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) + parser = argparse.ArgumentParser() + train_util.add_dataset_arguments( + parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout + ) + train_util.add_training_arguments(parser, config_args.support_dreambooth) + argparse_namespace = parser.parse_args(remain) + train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) - print("[argparse_namespace]") - print(vars(argparse_namespace)) + print("[argparse_namespace]") + print(vars(argparse_namespace)) - user_config = load_user_config(config_args.dataset_config) + user_config = load_user_config(config_args.dataset_config) - print("\n[user_config]") - print(user_config) + print("\n[user_config]") + print(user_config) - sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout) - sanitized_user_config = sanitizer.sanitize_user_config(user_config) + sanitizer = ConfigSanitizer( + config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout + ) + sanitized_user_config = sanitizer.sanitize_user_config(user_config) - print("\n[sanitized_user_config]") - print(sanitized_user_config) + print("\n[sanitized_user_config]") + print(sanitized_user_config) - blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) + blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) - print("\n[blueprint]") - print(blueprint) + print("\n[blueprint]") + print(blueprint) From fef172966fff02a5f918840843eb613ee1a6d50e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 20 Jan 2024 16:24:43 +0900 Subject: [PATCH 07/13] Add 
network_multiplier for dataset and train LoRA --- README.md | 31 +++++++++++++++++++++++++++++ library/config_util.py | 3 +++ library/train_util.py | 45 +++++++++++++++++++++++++----------------- train_network.py | 13 +++++++++++- 4 files changed, 73 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 663f52e8..c1043783 100644 --- a/README.md +++ b/README.md @@ -257,11 +257,42 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version. - The sample image generation during training consumes a lot of memory. It is recommended to turn it off. +- [Experimental] The network multiplier can be specified for each dataset in the training scripts for LoRA etc. + - This is an experimental option and may be removed or changed in the future. + - For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate. + - Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate. + - Please specify `network_multiplier` in `[[datasets]]` in `.toml` file. 
+ - (実験的) LoRA等の学習スクリプトで、ベースモデル(U-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。 - `train_network.py` または `sdxl_train_network.py` で `--fp8_base` を指定してください。 - PyTorch 2.1 以降が必要です。 - PyTorch 2.1 で xformers を使用する場合は、[xformers のリポジトリ](https://github.com/facebookresearch/xformers) を参照し、CUDA バージョンに応じて適切なバージョンをインストールしてください。 - 学習中のサンプル画像生成はメモリを大量に消費するため、オフにすることをお勧めします。 +- (実験的) LoRA 等の学習で、データセットごとに異なるネットワーク適用率を指定できるようになりました。 + - 実験的オプションのため、将来的に削除または仕様変更される可能性があります。 + - たとえば状態 A を `1.0`、状態 B を `-1.0` として学習すると、LoRA の適用率に応じて状態 A と B を切り替えつつ生成できるかもしれません。 + - また、五段階の状態を用意し、それぞれ `0.2`、`0.4`、`0.6`、`0.8`、`1.0` として学習すると、適用率でなめらかに状態を切り替えて生成できるかもしれません。 + - `.toml` ファイルで `[[datasets]]` に `network_multiplier` を指定してください。 + +- `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例 + +```toml +[general] +[[datasets]] +resolution = 512 +batch_size = 8 +network_multiplier = 1.0 + +... subset settings ... + +[[datasets]] +resolution = 512 +batch_size = 8 +network_multiplier = -1.0 + +... subset settings ... 
+``` + ### Jan 17, 2024 / 2024/1/17: v0.8.1 diff --git a/library/config_util.py b/library/config_util.py index 716cecff..a98c2b90 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -93,6 +93,7 @@ class BaseDatasetParams: tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None max_token_length: int = None resolution: Optional[Tuple[int, int]] = None + network_multiplier: float = 1.0 debug_dataset: bool = False @@ -219,6 +220,7 @@ class ConfigSanitizer: "max_bucket_reso": int, "min_bucket_reso": int, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), + "network_multiplier": float, } # options handled by argparse but not handled by user config @@ -469,6 +471,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} enable_bucket: {dataset.enable_bucket} + network_multiplier: {dataset.network_multiplier} """ ) diff --git a/library/train_util.py b/library/train_util.py index 21e7638d..4ac6728b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -558,6 +558,7 @@ class BaseDataset(torch.utils.data.Dataset): tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]], max_token_length: int, resolution: Optional[Tuple[int, int]], + network_multiplier: float, debug_dataset: bool, ) -> None: super().__init__() @@ -567,6 +568,7 @@ class BaseDataset(torch.utils.data.Dataset): self.max_token_length = max_token_length # width/height is used when enable_bucket==False self.width, self.height = (None, None) if resolution is None else resolution + self.network_multiplier = network_multiplier self.debug_dataset = debug_dataset self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = [] @@ -1106,7 +1108,9 @@ class BaseDataset(torch.utils.data.Dataset): for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] - 
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + loss_weights.append( + self.prior_loss_weight if image_info.is_reg else 1.0 + ) # in case of fine tuning, is_reg is always False flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance @@ -1272,6 +1276,8 @@ class BaseDataset(torch.utils.data.Dataset): example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw]) example["flippeds"] = flippeds + example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions)) + if self.debug_dataset: example["image_keys"] = bucket[image_index : image_index + self.batch_size] return example @@ -1346,15 +1352,16 @@ class DreamBoothDataset(BaseDataset): tokenizer, max_token_length, resolution, + network_multiplier: float, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, - debug_dataset, + debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -1520,14 +1527,15 @@ class FineTuningDataset(BaseDataset): tokenizer, max_token_length, resolution, + network_multiplier: float, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, - debug_dataset, + debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) self.batch_size = batch_size @@ -1724,14 +1732,15 @@ class ControlNetDataset(BaseDataset): tokenizer, max_token_length, resolution, + network_multiplier: float, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, 
bucket_no_upscale: bool, - debug_dataset, + debug_dataset: float, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) db_subsets = [] for subset in subsets: @@ -2039,6 +2048,8 @@ def debug_dataset(train_dataset, show_input_ids=False): print( f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}' ) + if "network_multipliers" in example: + print(f"network multiplier: {example['network_multipliers'][j]}") if show_input_ids: print(f"input ids: {iid}") @@ -2105,8 +2116,8 @@ def glob_images_pathlib(dir_path, recursive): class MinimalDataset(BaseDataset): - def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False): - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False): + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) self.num_train_images = 0 # update in subclass self.num_reg_images = 0 # update in subclass @@ -2850,14 +2861,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: ) parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う") parser.add_argument( - "--dynamo_backend", - type=str, - default="inductor", + "--dynamo_backend", + type=str, + default="inductor", # available backends: # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 # https://pytorch.org/docs/stable/torch.compiler.html - choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], - help="dynamo backend type 
(default is inductor) / dynamoのbackendの種類(デフォルトは inductor)" + choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], + help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)", ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") parser.add_argument( @@ -2904,9 +2915,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する" ) # TODO move to SDXL training, because it is not supported by SD1/2 - parser.add_argument( - "--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う" - ) + parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う") parser.add_argument( "--ddp_timeout", type=int, @@ -3889,7 +3898,7 @@ def prepare_accelerator(args: argparse.Namespace): os.environ["WANDB_DIR"] = logging_dir if args.wandb_api_key is not None: wandb.login(key=args.wandb_api_key) - + # torch.compile のオプション。 NO の場合は torch.compile は使わない dynamo_backend = "NO" if args.torch_compile: diff --git a/train_network.py b/train_network.py index b1291ed1..ef7d4197 100644 --- a/train_network.py +++ b/train_network.py @@ -310,6 +310,7 @@ class NetworkTrainer: ) if network is None: return + network_has_multiplier = hasattr(network, "set_multiplier") if hasattr(network, "prepare_network"): network.prepare_network(args) @@ -768,7 +769,17 @@ class NetworkTrainer: accelerator.print("NaN found in latents, replacing with zeros") latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * self.vae_scale_factor - b_size = latents.shape[0] + + # get multiplier for each sample + if network_has_multiplier: + multipliers = batch["network_multipliers"] + # if all multipliers are same, use single multiplier + if 
torch.all(multipliers == multipliers[0]): + multipliers = multipliers[0].item() + else: + raise NotImplementedError("multipliers for each sample is not supported yet") + # print(f"set multiplier: {multipliers}") + network.set_multiplier(multipliers) with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning From c59249a664a9202d8c7e2300587480d0a59a4c9b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 20 Jan 2024 18:45:54 +0900 Subject: [PATCH 08/13] Add options to reduce memory usage in extract_lora_from_models.py closes #1059 --- README.md | 7 +++ networks/extract_lora_from_models.py | 83 ++++++++++++++++++++++++---- 2 files changed, 79 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index c1043783..be6df588 100644 --- a/README.md +++ b/README.md @@ -262,6 +262,9 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate. - Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate. - Please specify `network_multiplier` in `[[datasets]]` in `.toml` file. +- Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage. + - `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision. + - `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. 
You can reduce the memory usage by loading one of them to GPU. - (実験的) LoRA等の学習スクリプトで、ベースモデル(U-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。 - `train_network.py` または `sdxl_train_network.py` で `--fp8_base` を指定してください。 @@ -273,6 +276,10 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - たとえば状態 A を `1.0`、状態 B を `-1.0` として学習すると、LoRA の適用率に応じて状態 A と B を切り替えつつ生成できるかもしれません。 - また、五段階の状態を用意し、それぞれ `0.2`、`0.4`、`0.6`、`0.8`、`1.0` として学習すると、適用率でなめらかに状態を切り替えて生成できるかもしれません。 - `.toml` ファイルで `[[datasets]]` に `network_multiplier` を指定してください。 +- `networks/extract_lora_from_models.py` に使用メモリ量を削減するいくつかのオプションを追加しました。 + - `--load_precision` で読み込み時の精度を指定できます。モデルが fp16 で保存されている場合は `--load_precision fp16` を指定して精度を変えずにメモリ量を削減できます。 + - `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。 + - `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例 diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index 6357df55..b9027adb 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -43,6 +43,9 @@ def svd( clamp_quantile=0.99, min_diff=0.01, no_metadata=False, + load_precision=None, + load_original_model_to=None, + load_tuned_model_to=None, ): def str_to_dtype(p): if p == "float": @@ -57,28 +60,51 @@ def svd( if v_parameterization is None: v_parameterization = v2 + load_dtype = str_to_dtype(load_precision) if load_precision else None save_dtype = str_to_dtype(save_precision) + work_device = "cpu" # load models if not sdxl: print(f"loading original SD model : {model_org}") text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org) text_encoders_o = [text_encoder_o] + if load_dtype is not None: + text_encoder_o = 
text_encoder_o.to(load_dtype) + unet_o = unet_o.to(load_dtype) + print(f"loading tuned SD model : {model_tuned}") text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned) text_encoders_t = [text_encoder_t] + if load_dtype is not None: + text_encoder_t = text_encoder_t.to(load_dtype) + unet_t = unet_t.to(load_dtype) + model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization) else: + device_org = load_original_model_to if load_original_model_to else "cpu" + device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu" + print(f"loading original SDXL model : {model_org}") text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org ) text_encoders_o = [text_encoder_o1, text_encoder_o2] + if load_dtype is not None: + text_encoder_o1 = text_encoder_o1.to(load_dtype) + text_encoder_o2 = text_encoder_o2.to(load_dtype) + unet_o = unet_o.to(load_dtype) + print(f"loading original SDXL model : {model_tuned}") text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned ) text_encoders_t = [text_encoder_t1, text_encoder_t2] + if load_dtype is not None: + text_encoder_t1 = text_encoder_t1.to(load_dtype) + text_encoder_t2 = text_encoder_t2.to(load_dtype) + unet_t = unet_t.to(load_dtype) + model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 # create LoRA network to extract weights: Use dim (rank) as alpha @@ -100,38 +126,54 @@ def svd( lora_name = lora_o.lora_name module_o = lora_o.org_module module_t = lora_t.org_module - diff = module_t.weight - module_o.weight + diff = module_t.weight.to(work_device) - module_o.weight.to(work_device) + + # 
clear weight to save memory + module_o.weight = None + module_t.weight = None # Text Encoder might be same if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff: text_encoder_different = True print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}") - diff = diff.float() diffs[lora_name] = diff + # clear target Text Encoder to save memory + for text_encoder in text_encoders_t: + del text_encoder + if not text_encoder_different: print("Text encoder is same. Extract U-Net only.") lora_network_o.text_encoder_loras = [] - diffs = {} + diffs = {} # clear diffs for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)): lora_name = lora_o.lora_name module_o = lora_o.org_module module_t = lora_t.org_module - diff = module_t.weight - module_o.weight - diff = diff.float() + diff = module_t.weight.to(work_device) - module_o.weight.to(work_device) - if args.device: - diff = diff.to(args.device) + # clear weight to save memory + module_o.weight = None + module_t.weight = None diffs[lora_name] = diff + # clear LoRA network, target U-Net to save memory + del lora_network_o + del lora_network_t + del unet_t + # make LoRA with svd print("calculating by svd") lora_weights = {} with torch.no_grad(): for lora_name, mat in tqdm(list(diffs.items())): + if args.device: + mat = mat.to(args.device) + mat = mat.to(torch.float) # calc by float + # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3 conv2d = len(mat.size()) == 4 kernel_size = None if not conv2d else mat.size()[2:4] @@ -171,8 +213,8 @@ def svd( U = U.reshape(out_dim, rank, 1, 1) Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1]) - U = U.to("cpu").contiguous() - Vh = Vh.to("cpu").contiguous() + U = U.to(work_device, dtype=save_dtype).contiguous() + Vh = Vh.to(work_device, dtype=save_dtype).contiguous() lora_weights[lora_name] = (U, Vh) @@ -230,6 +272,13 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( 
"--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む" ) + parser.add_argument( + "--load_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる" + ) parser.add_argument( "--save_precision", type=str, @@ -285,6 +334,18 @@ def setup_parser() -> argparse.ArgumentParser: help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", ) + parser.add_argument( + "--load_original_model_to", + type=str, + default=None, + help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効", + ) + parser.add_argument( + "--load_tuned_model_to", + type=str, + default=None, + help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効", + ) return parser From e0a3c69223d09f38a2fc0d86ed289378cf0f7958 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 20 Jan 2024 18:47:10 +0900 Subject: [PATCH 09/13] update readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index be6df588..23108b78 100644 --- a/README.md +++ b/README.md @@ -264,7 +264,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - Please specify `network_multiplier` in `[[datasets]]` in `.toml` file. - Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage. - `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision. - - `--load_original_model_to` option can be used to specify the device to load the original model. 
`--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. + - `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. This option is available only for SDXL. - (実験的) LoRA等の学習スクリプトで、ベースモデル(U-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。 - `train_network.py` または `sdxl_train_network.py` で `--fp8_base` を指定してください。 @@ -278,7 +278,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `.toml` ファイルで `[[datasets]]` に `network_multiplier` を指定してください。 - `networks/extract_lora_from_models.py` に使用メモリ量を削減するいくつかのオプションを追加しました。 - `--load_precision` で読み込み時の精度を指定できます。モデルが fp16 で保存されている場合は `--load_precision fp16` を指定して精度を変えずにメモリ量を削減できます。 - - `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。 + - `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。SDXL の場合のみ有効です。 - `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例 From 696dd7f66802091f973ef0ce5ffa1e8002e90789 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 22 Jan 2024 12:43:37 +0900 Subject: [PATCH 10/13] Fix dtype issue in PyTorch 2.0 for generating samples in training sdxl network --- library/sdxl_lpw_stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/sdxl_lpw_stable_diffusion.py 
b/library/sdxl_lpw_stable_diffusion.py index 0562e88a..03b18256 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -925,7 +925,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: unet_dtype = self.unet.dtype dtype = unet_dtype - if dtype.itemsize == 1: # fp8 + if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8 dtype = torch.float16 self.unet.to(dtype) From 711b40ccda0143cbf0014249ecdb9353231cbb79 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 23 Jan 2024 11:49:03 +0800 Subject: [PATCH 11/13] Avoid always sync --- train_network.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index ef7d4197..9036f486 100644 --- a/train_network.py +++ b/train_network.py @@ -847,10 +847,11 @@ class NetworkTrainer: loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし accelerator.backward(loss) - self.all_reduce_network(accelerator, network) # sync DDP grad manually - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = accelerator.unwrap_model(network).get_trainable_params() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + if accelerator.sync_gradients: + self.all_reduce_network(accelerator, network) # sync DDP grad manually + if args.max_grad_norm != 0.0: + params_to_clip = accelerator.unwrap_model(network).get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() From 6805cafa9be066e549b3aa2b9f53ec03a91ffda6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 23 Jan 2024 20:17:19 +0900 Subject: [PATCH 12/13] fix TI training crashes in multigpu #1019 --- train_textual_inversion.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 0e3912b1..f1cf6fbd 100644 --- a/train_textual_inversion.py +++ 
b/train_textual_inversion.py @@ -505,7 +505,7 @@ class TextualInversionTrainer: if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( @@ -730,14 +730,13 @@ class TextualInversionTrainer: is_main_process = accelerator.is_main_process if is_main_process: text_encoder = accelerator.unwrap_model(text_encoder) + updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() accelerator.end_training() if args.save_state and is_main_process: train_util.save_state_on_train_end(args, accelerator) - updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() - if is_main_process: ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True) From 7cb44e45023295548788449167bcc8cd6afc5417 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 23 Jan 2024 21:02:40 +0900 Subject: [PATCH 13/13] update readme --- README.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 23108b78..0de787b8 100644 --- a/README.md +++ b/README.md @@ -249,7 +249,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History -### 作業中の内容 / Work in progress +### Jan 23, 2024 / 2024/1/23: v0.8.2 - [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf! - Please specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`. 
@@ -266,6 +266,10 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision. - `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. This option is available only for SDXL. +- The gradient synchronization in LoRA training with multi-GPU is improved. PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) Thanks to KohakuBlueleaf! +- The code for Intel IPEX support is improved. PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) Thanks to akx! +- Fixed a bug in multi-GPU Textual Inversion training. + - (実験的) LoRA等の学習スクリプトで、ベースモデル(U-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。 - `train_network.py` または `sdxl_train_network.py` で `--fp8_base` を指定してください。 - PyTorch 2.1 以降が必要です。 @@ -279,7 +283,9 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. 
See [docum - `networks/extract_lora_from_models.py` に使用メモリ量を削減するいくつかのオプションを追加しました。 - `--load_precision` で読み込み時の精度を指定できます。モデルが fp16 で保存されている場合は `--load_precision fp16` を指定して精度を変えずにメモリ量を削減できます。 - `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。SDXL の場合のみ有効です。 - +- マルチ GPU での LoRA 等の学習時に勾配の同期が改善されました。 PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) KohakuBlueleaf 氏に感謝します。 +- Intel IPEX サポートのコードが改善されました。PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) akx 氏に感謝します。 +- マルチ GPU での Textual Inversion 学習の不具合を修正しました。 - `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例