diff --git a/README.md b/README.md
index 5d4f9621..d406fecd 100644
--- a/README.md
+++ b/README.md
@@ -4,9 +4,16 @@ This repository contains training, generation and utility scripts for Stable Dif

 SD3 training is done with `sd3_train.py`.

-__Jul 11, 2024__: Fixed to work t5xxl with `fp16`. If you change the dtype to `fp16` for t5xxl, please remove existing latents cache files (`*_sd3.npz`). The shift in `sd3_minimum_inference.py` is fixed to 3.0. Thanks to araleza!
+__Jul 27, 2024__:
+- The caching mechanism for latents and text encoder outputs has been significantly refactored.
+  - Existing cache files for SD3 need to be recreated. Please delete the previous cache files.
+  - With this change, dataset initialization is significantly faster, especially for large datasets.

-Jun 29, 2024: Fixed mixed precision training with fp16 is not working. Fixed the model is in bf16 dtype even without `--full_bf16` option (this could worsen the training result).
+- Architecture-dependent parts have been extracted from the dataset code (`train_util.py`). This is expected to make it easier to add future architectures.
+
+- Architecture-dependent parts, including the caching mechanism for SD1/2/SDXL, have also been extracted. Basic operation of SD1/2/SDXL training on the sd3 branch has been confirmed, but there may be bugs. Please use the main or dev branch for SD1/2/SDXL training.
+
+---

 `fp16` and `bf16` are available for mixed precision training. We are not sure which is better.

@@ -14,7 +21,7 @@ Jun 29, 2024: Fixed mixed precision training with fp16 is not working. Fixed the

 `clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them.

-~~t5xxl doesn't seem to work with `fp16`, so 1) use`bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`. ~~ t5xxl works with `fp16` now.
+t5xxl now works with `fp16`.

 There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype.

@@ -32,6 +39,14 @@ cache_latents = true
 cache_latents_to_disk = true
 ```

+__2024/7/27:__
+
+Latents およびテキストエンコーダ出力のキャッシュの仕組みを大きくリファクタリングしました。SD3 用の既存のキャッシュファイルの再作成が必要になりますが、ご了承ください(以前のキャッシュファイルは削除してください)。これにより、特にデータセットの規模が大きい場合のデータセット初期化が大幅に高速化されます。
+
+データセット (`train_util.py`) からアーキテクチャ依存の部分を切り出しました。これにより将来的なアーキテクチャ追加が容易になると期待しています。
+
+SD1/2/SDXL のキャッシュ機構を含むアーキテクチャ依存の部分も切り出しました。sd3 ブランチの SD1/2/SDXL 学習について、基本的な動作は確認していますが、不具合があるかもしれません。SD1/2/SDXL の学習には main または dev ブランチをお使いください。
+
 ---

 [__Change History__](#change-history) is moved to the bottom of the page.
diff --git a/fine_tune.py b/fine_tune.py
index d865cd2d..c9102f6c 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -10,7 +10,7 @@ import toml
 from tqdm import tqdm

 import torch
-from library import deepspeed_utils
+from library import deepspeed_utils, strategy_base
 from library.device_utils import init_ipex, clean_memory_on_device

 init_ipex()

@@ -39,6 +39,7 @@ from library.custom_train_functions import (
     scale_v_prediction_loss_like_noise_prediction,
     apply_debiased_estimation,
 )
+import library.strategy_sd as strategy_sd


 def train(args):
@@ -52,7 +53,15 @@ def train(args):
     if args.seed is not None:
         set_seed(args.seed)  # 乱数系列を初期化する

-    tokenizer = train_util.load_tokenizer(args)
+    tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
+    strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
+
+    # prepare the caching strategy: this must be set before preparing the dataset, because the dataset may use it during initialization
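+    # note: strategies are process-wide singletons (see library/strategy_base.py); the dataset
+    # later retrieves the current instance via LatentsCachingStrategy.get_strategy(), so anything
+    # set here is visible to the dataset and to new_cache_latents() below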
+ if cache_latents: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -81,10 +90,10 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -165,8 +174,9 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -192,6 +202,9 @@ def train(args): else: text_encoder.eval() + text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + if not cache_latents: vae.requires_grad_(False) vae.eval() @@ -214,7 +227,11 @@ def train(args): accelerator.print("prepare optimizer, data loader etc.") _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -317,7 +334,9 @@ def train(args): ) # For --sample_at_first - train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -342,8 +361,9 @@ def train(args): with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning if args.weighted_captions: + # TODO move to strategy_sd.py encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, + tokenize_strategy.tokenizer, text_encoder, batch["captions"], accelerator.device, @@ -351,10 +371,12 @@ def train(args): clip_skip=args.clip_skip, ) else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) + input_ids = batch["input_ids_list"][0].to(accelerator.device) + encoder_hidden_states = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder], [input_ids] + )[0] + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -409,7 +431,7 @@ def train(args): global_step += 1 
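+                # sample images at fixed step intervals; the tokenizer comes from the tokenize strategy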
                train_util.sample_images(
-                    accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
+                    accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
                )

                # 指定ステップごとにモデルを保存
@@ -472,7 +494,9 @@ def train(args):
                    vae,
                )

-        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+        train_util.sample_images(
+            accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
+        )

    is_main_process = accelerator.is_main_process
    if is_main_process:
diff --git a/library/config_util.py b/library/config_util.py
index 10b2457f..f8cdfe60 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -104,8 +104,6 @@ class ControlNetSubsetParams:

 @dataclass
 class BaseDatasetParams:
-    tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
-    max_token_length: int = None
     resolution: Optional[Tuple[int, int]] = None
     network_multiplier: float = 1.0
     debug_dataset: bool = False
diff --git a/library/sd3_models.py b/library/sd3_models.py
index ec8e1bbd..28378c73 100644
--- a/library/sd3_models.py
+++ b/library/sd3_models.py
@@ -38,7 +38,7 @@ class SDTokenizer:
        サブクラスで各種の設定を行ってる。このクラスはその設定に基づき重み付きのトークン化を行うようだ。
        Some settings are done in subclasses. This class seems to perform tokenization with weights based on those settings.
        """
-        self.tokenizer = tokenizer
+        self.tokenizer: CLIPTokenizer = tokenizer
         self.max_length = max_length
         self.min_length = min_length
         empty = self.tokenizer("")["input_ids"]
@@ -56,6 +56,19 @@ class SDTokenizer:
         self.inv_vocab = {v: k for k, v in vocab.items()}
         self.max_word_length = 8

+    def tokenize(self, text: Union[str, List[str]]) -> torch.Tensor:
+        """
+        Tokenize the text without weights. Returns the batched input_ids, padded/truncated to max_length.
+        """
+        if isinstance(text, str):
+            text = [text]
+        batch_tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt")
+
+        # sanity check: every sequence must start with the BOS token
+        for tokens in batch_tokens["input_ids"]:
+            assert tokens[0] == self.start_token, f"tokens[0]: {tokens[0]}, start_token: {self.start_token}"
+
+        return batch_tokens["input_ids"]
+
     def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None):
        """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. 
       The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3."""

        for word in to_tokenize:
            batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]])
        batch.append((self.end_token, 1.0))
        if self.pad_to_max_length:
            batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
        if self.min_length is not None and len(batch) < self.min_length:
            batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))

        # truncate to max_length
-        # print(f"batch: {batch}, truncate: {truncate}, len(batch): {len(batch)}, max_length: {self.max_length}")
+        # print(f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}")
        if truncate_to_max_length and len(batch) > self.max_length:
            batch = batch[: self.max_length]
        if truncate_length is not None and len(batch) > truncate_length:
@@ -110,27 +124,38 @@ class SDXLClipGTokenizer(SDTokenizer):


 class SD3Tokenizer:
-    def __init__(self, t5xxl=True):
+    def __init__(self, t5xxl=True, t5xxl_max_length: Optional[int] = 256):
+        if t5xxl_max_length is None:
+            t5xxl_max_length = 256
+
         # TODO cache tokenizer settings locally or hold them in the repo like ComfyUI
         clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
         self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
         self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
+        # self.clip_l = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+        # self.clip_g = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
         self.t5xxl = T5XXLTokenizer() if t5xxl else None
         # t5xxl has 99999999 max length, clip has 77
-        self.model_max_length = self.clip_l.max_length  # 77
+        self.t5xxl_max_length = t5xxl_max_length

     def tokenize_with_weights(self, text: str):
-        # temporary truncate to max_length even for t5xxl
         return (
             self.clip_l.tokenize_with_weights(text),
             self.clip_g.tokenize_with_weights(text),
             (
-                self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.model_max_length)
+                self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.t5xxl_max_length)
                 if self.t5xxl is not None
                 else None
             ),
         )

+    def tokenize(self, text: str):
+        return (
+            self.clip_l.tokenize(text),
+            self.clip_g.tokenize(text),
+            (self.t5xxl.tokenize(text) if self.t5xxl is not None else None),
+        )
+

 # endregion

@@ -1474,7 +1499,10 @@ class ClipTokenWeightEncoder:
                 tokens = [a[0] for a in pairs[0]]  # I'm not sure why this is [0]
                 list_of_tokens.append(tokens)
         else:
-            list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]]
+            if isinstance(list_of_token_weight_pairs[0], torch.Tensor):
+                list_of_tokens = [list(list_of_token_weight_pairs[0])]
+            else:
+                list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]]

         out, pooled = self(list_of_tokens)
         if has_batch:
@@ -1614,9 +1642,9 @@ class T5XXLModel(SDClipModel):
 ### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
 #################################################################################################

-
+# note: SD3Tokenizer above still instantiates this class, so it is kept rather than commented out
 class T5XXLTokenizer(SDTokenizer):
     """Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""

     def __init__(self):
         super().__init__(
@@ -1627,6 +1655,7 @@ class T5XXLTokenizer(SDTokenizer):
             max_length=99999999,
             min_length=77,
         )


 class 
T5LayerNorm(torch.nn.Module): diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 24591219..8f99d947 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -280,111 +280,6 @@ def sample_images(*args, **kwargs): return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) -class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy): - SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" - - def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: - super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) - self.vae = None - - def set_vae(self, vae: sd3_models.SDVAE): - self.vae = vae - - def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: - npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX) - if len(npz_file) == 0: - return None, None - w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") - return int(w), int(h) - - def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: - return ( - os.path.splitext(absolute_path)[0] - + f"_{image_size[0]:04d}x{image_size[1]:04d}" - + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX - ) - - def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - if not self.cache_to_disk: - return False - if not os.path.exists(npz_path): - return False - if self.skip_disk_cache_validity_check: - return True - - expected_latents_size = (bucket_reso[1] // 8, bucket_reso[0] // 8) # bucket_reso is (W, H) - - try: - npz = np.load(npz_path) - if npz["latents"].shape[1:3] != expected_latents_size: - return False - - if flip_aug: - if "latents_flipped" not in npz: - return False - if npz["latents_flipped"].shape[1:3] != expected_latents_size: - return False - - if alpha_mask: - if "alpha_mask" not in npz: - return False - if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]): - return False - else: - if "alpha_mask" in npz: - return False - except Exception as e: - logger.error(f"Error loading file: {npz_path}") - raise e - - return True - - def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): - img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching( - image_infos, alpha_mask, random_crop - ) - img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype) - - with torch.no_grad(): - latents_tensors = self.vae.encode(img_tensor).to("cpu") - if flip_aug: - img_tensor = torch.flip(img_tensor, dims=[3]) - with torch.no_grad(): - flipped_latents = self.vae.encode(img_tensor).to("cpu") - else: - flipped_latents = [None] * len(latents_tensors) - - # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks): - for i in range(len(image_infos)): - info = image_infos[i] - latents = latents_tensors[i] - flipped_latent = flipped_latents[i] - alpha_mask = alpha_masks[i] - original_size = original_sizes[i] - crop_ltrb = crop_ltrbs[i] - - if self.cache_to_disk: - kwargs = {} - if flipped_latent is not None: - kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy() - if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() - np.savez( - info.latents_npz, - latents=latents.float().cpu().numpy(), - 
original_size=np.array(original_size), - crop_ltrb=np.array(crop_ltrb), - **kwargs, - ) - else: - info.latents = latents - if flip_aug: - info.latents_flipped = flipped_latent - info.alpha_mask = alpha_mask - - if not train_util.HIGH_VRAM: - clean_memory_on_device(self.vae.device) - # region Diffusers diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 16f80c60..5849518f 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -384,6 +384,7 @@ def get_cond( dtype: Optional[torch.dtype] = None, ): l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt) + print(t5_tokens) return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index b74bea91..f009b577 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -327,7 +327,7 @@ def save_sd_model_on_epoch_end_or_stepwise( ) -def add_sdxl_training_arguments(parser: argparse.ArgumentParser): +def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True): parser.add_argument( "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" ) diff --git a/library/strategy_base.py b/library/strategy_base.py new file mode 100644 index 00000000..594cca5e --- /dev/null +++ b/library/strategy_base.py @@ -0,0 +1,328 @@ +# base class for platform strategies. this file defines the interface for strategies + +import os +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection + + +# TODO remove circular import by moving ImageInfo to a separate file +# from library.train_util import ImageInfo + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class TokenizeStrategy: + _strategy = None # strategy instance: actual strategy class + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. 
{cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TokenizeStrategy"]: + return cls._strategy + + def _load_tokenizer( + self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None + ) -> Any: + tokenizer = None + if tokenizer_cache_dir: + local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_")) + if os.path.exists(local_tokenizer_path): + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") + tokenizer = model_class.from_pretrained(local_tokenizer_path) # same for v1 and v2 + + if tokenizer is None: + tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder) + + if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") + tokenizer.save_pretrained(local_tokenizer_path) + + return tokenizer + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + raise NotImplementedError + + def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None) -> torch.Tensor: + """ + for SD1.5/2.0/SDXL + TODO support batch input + """ + if max_length is None: + max_length = tokenizer.model_max_length - 2 + + input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids + + if max_length > tokenizer.model_max_length: + input_ids = input_ids.squeeze(0) + iids_list = [] + if tokenizer.pad_token_id == tokenizer.eos_token_id: + # v1 + # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する + # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): # (1, 152, 75) + ids_chunk = ( + input_ids[0].unsqueeze(0), + input_ids[i : i + tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 or SDXL + # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): + ids_chunk = ( + input_ids[0].unsqueeze(0), # BOS + input_ids[i : i + tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) + + # 末尾が または の場合は、何もしなくてよい + # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) + if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id: + ids_chunk[-1] = tokenizer.eos_token_id + # 先頭が ... の場合は ... に変える + if ids_chunk[1] == tokenizer.pad_token_id: + ids_chunk[1] = tokenizer.eos_token_id + + iids_list.append(ids_chunk) + + input_ids = torch.stack(iids_list) # 3,77 + return input_ids + + +class TextEncodingStrategy: + _strategy = None # strategy instance: actual strategy class + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TextEncodingStrategy"]: + return cls._strategy + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Encode tokens into embeddings and outputs. 
+ :param tokens: list of token tensors for each TextModel + :return: list of output embeddings for each architecture + """ + raise NotImplementedError + + +class TextEncoderOutputsCachingStrategy: + _strategy = None # strategy instance: actual strategy class + + def __init__( + self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + ) -> None: + self._cache_to_disk = cache_to_disk + self._batch_size = batch_size + self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + self._is_partial = is_partial + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]: + return cls._strategy + + @property + def cache_to_disk(self): + return self._cache_to_disk + + @property + def batch_size(self): + return self._batch_size + + @property + def is_partial(self): + return self._is_partial + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + raise NotImplementedError + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + raise NotImplementedError + + def is_disk_cached_outputs_expected(self, npz_path: str) -> bool: + raise NotImplementedError + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List + ): + raise NotImplementedError + + +class LatentsCachingStrategy: + # TODO commonize utillity functions to this class, such as npz handling etc. + + _strategy = None # strategy instance: actual strategy class + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + self._cache_to_disk = cache_to_disk + self._batch_size = batch_size + self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. 
{cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["LatentsCachingStrategy"]: + return cls._strategy + + @property + def cache_to_disk(self): + return self._cache_to_disk + + @property + def batch_size(self): + return self._batch_size + + def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + raise NotImplementedError + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + raise NotImplementedError + + def is_disk_cached_latents_expected( + self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + ) -> bool: + raise NotImplementedError + + def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + raise NotImplementedError + + def _defualt_is_disk_cached_latents_expected( + self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + ): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) + + try: + npz = np.load(npz_path) + if npz["latents"].shape[1:3] != expected_latents_size: + return False + + if flip_aug: + if "latents_flipped" not in npz: + return False + if npz["latents_flipped"].shape[1:3] != expected_latents_size: + return False + + if alpha_mask: + if "alpha_mask" not in npz: + return False + if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]): + return False + else: + if "alpha_mask" in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + # TODO remove circular dependency for ImageInfo + def _default_cache_batch_latents( + self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool + ): + """ + Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common. 
+ """ + from library import train_util # import here to avoid circular import + + img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching( + image_infos, alpha_mask, random_crop + ) + img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype) + + with torch.no_grad(): + latents_tensors = encode_by_vae(img_tensor).to("cpu") + if flip_aug: + img_tensor = torch.flip(img_tensor, dims=[3]) + with torch.no_grad(): + flipped_latents = encode_by_vae(img_tensor).to("cpu") + else: + flipped_latents = [None] * len(latents_tensors) + + # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks): + for i in range(len(image_infos)): + info = image_infos[i] + latents = latents_tensors[i] + flipped_latent = flipped_latents[i] + alpha_mask = alpha_masks[i] + original_size = original_sizes[i] + crop_ltrb = crop_ltrbs[i] + + if self.cache_to_disk: + self.save_latents_to_disk(info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask) + else: + info.latents_original_size = original_size + info.latents_crop_ltrb = crop_ltrb + info.latents = latents + if flip_aug: + info.latents_flipped = flipped_latent + info.alpha_mask = alpha_mask + + def load_latents_from_disk( + self, npz_path: str + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + npz = np.load(npz_path) + if "latents" not in npz: + raise ValueError(f"error: npz is old format. please re-generate {npz_path}") + + latents = npz["latents"] + original_size = npz["original_size"].tolist() + crop_ltrb = npz["crop_ltrb"].tolist() + flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None + alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask + + def save_latents_to_disk( + self, npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None + ): + kwargs = {} + if flipped_latents_tensor is not None: + kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() + if alpha_mask is not None: + kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() + np.savez( + npz_path, + latents=latents_tensor.float().cpu().numpy(), + original_size=np.array(original_size), + crop_ltrb=np.array(crop_ltrb), + **kwargs, + ) diff --git a/library/strategy_sd.py b/library/strategy_sd.py new file mode 100644 index 00000000..10581614 --- /dev/null +++ b/library/strategy_sd.py @@ -0,0 +1,139 @@ +import glob +import os +from typing import Any, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTokenizer +from library import train_util +from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +TOKENIZER_ID = "openai/clip-vit-large-patch14" +V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ + + +class SdTokenizeStrategy(TokenizeStrategy): + def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None: + """ + max_length does not include and (None, 75, 150, 225) + """ + logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer") + if v2: + self.tokenizer = self._load_tokenizer( + CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir 
+ ) + else: + self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + + if max_length is None: + self.max_length = self.tokenizer.model_max_length + else: + self.max_length = max_length + 2 + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] + + +class SdTextEncodingStrategy(TextEncodingStrategy): + def __init__(self, clip_skip: Optional[int] = None) -> None: + self.clip_skip = clip_skip + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + text_encoder = models[0] + tokens = tokens[0] + sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy + + # tokens: b,n,77 + b_size = tokens.size()[0] + max_token_length = tokens.size()[1] * tokens.size()[2] + model_max_length = sd_tokenize_strategy.tokenizer.model_max_length + tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77 + + if self.clip_skip is None: + encoder_hidden_states = text_encoder(tokens)[0] + else: + enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip] + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + + # bs*3, 77, 768 or 1024 + encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) + + if max_token_length != model_max_length: + v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id + if not v1: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, model_max_length): + chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token: + # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: ... の三連を ... へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, model_max_length): + states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) + + return [encoder_hidden_states] + + +class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): + # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix. + # and we keep the old npz for the backward compatibility. 
+ + SD_OLD_LATENTS_NPZ_SUFFIX = ".npz" + SD_LATENTS_NPZ_SUFFIX = "_sd.npz" + SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz" + + def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + self.sd = sd + self.suffix = ( + SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX + ) + + def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + # does not include old npz + npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + self.suffix) + if len(npz_file) == 0: + return None, None + w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + # support old .npz + old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX + if os.path.exists(old_npz_file): + return old_npz_file + return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample() + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py new file mode 100644 index 00000000..42630ab2 --- /dev/null +++ b/library/strategy_sd3.py @@ -0,0 +1,229 @@ +import os +import glob +from typing import Any, List, Optional, Tuple, Union +import torch +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast + +from library import sd3_utils, train_util +from library import sd3_models +from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14" +CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" +T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl" + + +class Sd3TokenizeStrategy(TokenizeStrategy): + def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None: + self.t5xxl_max_length = t5xxl_max_length + self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.clip_g = self._load_tokenizer(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.clip_g.pad_token_id = 0 # use 0 as pad token for clip_g + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + + l_tokens = self.clip_l(text, max_length=77, 
padding="max_length", truncation=True, return_tensors="pt") + g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt") + + l_tokens = l_tokens["input_ids"] + g_tokens = g_tokens["input_ids"] + t5_tokens = t5_tokens["input_ids"] + + return [l_tokens, g_tokens, t5_tokens] + + +class Sd3TextEncodingStrategy(TextEncodingStrategy): + def __init__(self) -> None: + pass + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + clip_l, clip_g, t5xxl = models + + l_tokens, g_tokens, t5_tokens = tokens + if l_tokens is None: + assert g_tokens is None, "g_tokens must be None if l_tokens is None" + lg_out = None + else: + assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" + l_out, l_pooled = clip_l(l_tokens) + g_out, g_pooled = clip_g(g_tokens) + lg_out = torch.cat([l_out, g_out], dim=-1) + + if t5xxl is not None and t5_tokens is not None: + t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096] + else: + t5_out = None + + lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None + return [lg_out, t5_out, lg_pooled] + + def concat_encodings( + self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + if t5_out is None: + t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype) + return torch.cat([lg_out, t5_out], dim=-2), lg_pooled + + +class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz" + + def __init__( + self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, abs_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(self.get_outputs_npz_path(abs_path)): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(self.get_outputs_npz_path(abs_path)) + if "clip_l" not in npz or "clip_g" not in npz: + return False + if "clip_l_pool" not in npz or "clip_g_pool" not in npz: + return False + # t5xxl is optional + except Exception as e: + logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + lg_out = data["lg_out"] + lg_pooled = data["lg_pooled"] + t5_out = data["t5_out"] if "t5_out" in data else None + return [lg_out, t5_out, lg_pooled] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + captions = [info.caption for info in infos] + + clip_l_tokens, clip_g_tokens, t5xxl_tokens = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + lg_out, t5_out, lg_pooled = text_encoding_strategy.encode_tokens( + tokenize_strategy, models, [clip_l_tokens, 
clip_g_tokens, t5xxl_tokens] + ) + + if lg_out.dtype == torch.bfloat16: + lg_out = lg_out.float() + if lg_pooled.dtype == torch.bfloat16: + lg_pooled = lg_pooled.float() + if t5_out is not None and t5_out.dtype == torch.bfloat16: + t5_out = t5_out.float() + + lg_out = lg_out.cpu().numpy() + lg_pooled = lg_pooled.cpu().numpy() + if t5_out is not None: + t5_out = t5_out.cpu().numpy() + + for i, info in enumerate(infos): + lg_out_i = lg_out[i] + t5_out_i = t5_out[i] if t5_out is not None else None + lg_pooled_i = lg_pooled[i] + + if self.cache_to_disk: + kwargs = {} + if t5_out is not None: + kwargs["t5_out"] = t5_out_i + np.savez(info.text_encoder_outputs_npz, lg_out=lg_out_i, lg_pooled=lg_pooled_i, **kwargs) + else: + info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i) + + +class Sd3LatentsCachingStrategy(LatentsCachingStrategy): + SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX) + if len(npz_file) == 0: + return None, None + w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) + + +if __name__ == "__main__": + # test code for Sd3TokenizeStrategy + # tokenizer = sd3_models.SD3Tokenizer() + strategy = Sd3TokenizeStrategy(256) + text = "hello world" + + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + # print(l_tokens.shape) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + texts = ["hello world", "the quick brown fox jumps over the lazy dog"] + l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens_2 = strategy.t5xxl( + texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + print(l_tokens_2) + print(g_tokens_2) + print(t5_tokens_2) + + # compare + print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0])) + print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0])) + print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0])) + + text = ",".join(["hello world! 
this is long text"] * 50) + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + print(f"model max length l: {strategy.clip_l.model_max_length}") + print(f"model max length g: {strategy.clip_g.model_max_length}") + print(f"model max length t5: {strategy.t5xxl.model_max_length}") diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py new file mode 100644 index 00000000..a4513336 --- /dev/null +++ b/library/strategy_sdxl.py @@ -0,0 +1,247 @@ +import os +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection +from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy + + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +TOKENIZER1_PATH = "openai/clip-vit-large-patch14" +TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + + +class SdxlTokenizeStrategy(TokenizeStrategy): + def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None: + self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir) + self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir) + self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2 + + if max_length is None: + self.max_length = self.tokenizer1.model_max_length + else: + self.max_length = max_length + 2 + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + return ( + torch.stack([self._get_input_ids(self.tokenizer1, t, self.max_length) for t in text], dim=0), + torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0), + ) + + +class SdxlTextEncodingStrategy(TextEncodingStrategy): + def __init__(self) -> None: + pass + + def _pool_workaround( + self, text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int + ): + r""" + workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output + instead of the hidden states for the EOS token + If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output + + Original code from CLIP's pooling function: + + \# text_embeds.shape = [batch_size, sequence_length, transformer.width] + \# take features from the eot embedding (eot_token is the highest number in each sequence) + \# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + ] + """ + + # input_ids: b*n,77 + # find index for EOS token + + # Following code is not working if one of the input_ids has multiple EOS tokens (very odd case) + # eos_token_index = torch.where(input_ids == eos_token_id)[1] + # eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # Create a mask where the EOS tokens are + eos_token_mask = (input_ids == eos_token_id).int() + + # Use argmax to find the last index of the EOS token for each element in the batch + eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there 
is no EOS token, it's fine
+        eos_token_index = eos_token_index.to(device=last_hidden_state.device)
+
+        # get hidden states for EOS token
+        pooled_output = last_hidden_state[
+            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index
+        ]
+
+        # apply projection: projection may be of different dtype than last_hidden_state
+        pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype))
+        pooled_output = pooled_output.to(last_hidden_state.dtype)
+
+        return pooled_output
+
+    def _get_hidden_states_sdxl(
+        self,
+        input_ids1: torch.Tensor,
+        input_ids2: torch.Tensor,
+        tokenizer1: CLIPTokenizer,
+        tokenizer2: CLIPTokenizer,
+        text_encoder1: Union[CLIPTextModel, torch.nn.Module],
+        text_encoder2: Union[CLIPTextModelWithProjection, torch.nn.Module],
+        unwrapped_text_encoder2: Optional[CLIPTextModelWithProjection] = None,
+    ):
+        # input_ids: b,n,77 -> b*n, 77
+        b_size = input_ids1.size()[0]
+        max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
+        input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length))  # batch_size*n, 77
+        input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length))  # batch_size*n, 77
+        input_ids1 = input_ids1.to(text_encoder1.device)
+        input_ids2 = input_ids2.to(text_encoder2.device)
+
+        # text_encoder1
+        enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
+        hidden_states1 = enc_out["hidden_states"][11]
+
+        # text_encoder2
+        enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
+        hidden_states2 = enc_out["hidden_states"][-2]  # penultimate layer
+
+        # pool2 = enc_out["text_embeds"]
+        unwrapped_text_encoder2 = unwrapped_text_encoder2 or text_encoder2
+        pool2 = self._pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
+
+        # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
+        n_size = 1 if max_token_length is None else max_token_length // 75
+        hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
+        hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
+
+        if max_token_length is not None:
+            # bs*3, 77, 768 or 1024
+            # encoder1: merge the chunked <BOS>...<EOS> <BOS>...<EOS> <BOS>...<EOS> back into <BOS>...<EOS>
+            states_list = [hidden_states1[:, 0].unsqueeze(1)]  # <BOS>
+            for i in range(1, max_token_length, tokenizer1.model_max_length):
+                states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2])  # from after <BOS> to before <EOS>
+            states_list.append(hidden_states1[:, -1].unsqueeze(1))  # <EOS>
+            hidden_states1 = torch.cat(states_list, dim=1)
+
+            # encoder2: same merge as above (not sure this implementation is ideal)
+            states_list = [hidden_states2[:, 0].unsqueeze(1)]  # <BOS>
+            for i in range(1, max_token_length, tokenizer2.model_max_length):
+                chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2]  # from after <BOS> up to before the last <EOS>
+                # this causes an error:
+                # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
+                # if i > 1:
+                #     for j in range(len(chunk)):  # batch_size
+                #         if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id:  # empty prompt (<BOS><EOS><PAD>... pattern)
+                #             chunk[j, 0] = chunk[j, 1]  # copy the value of the next <EOS>
+                states_list.append(chunk)  # from after <BOS> to before <EOS>
+            states_list.append(hidden_states2[:, -1].unsqueeze(1))  # either <EOS> or <PAD>
+            hidden_states2 = torch.cat(states_list, dim=1)
+
+            # use the first pooled output within each group of n chunks
+            pool2 = pool2[::n_size]
+
+        return hidden_states1, hidden_states2, pool2
+
+    def encode_tokens(
+        self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
+    ) -> List[torch.Tensor]:
+        """
+        Args:
+            tokenize_strategy: TokenizeStrategy
+            models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)]
+            tokens: List of tokens, for text_encoder1 and text_encoder2
+        """
+        if len(models) == 2:
+            text_encoder1, text_encoder2 = models
+            unwrapped_text_encoder2 = None
+        else:
+            text_encoder1, text_encoder2, unwrapped_text_encoder2 = models
+        tokens1, tokens2 = tokens
+        sdxl_tokenize_strategy = tokenize_strategy  # type: SdxlTokenizeStrategy
+        tokenizer1, tokenizer2 = sdxl_tokenize_strategy.tokenizer1, sdxl_tokenize_strategy.tokenizer2
+
+        hidden_states1, hidden_states2, pool2 = self._get_hidden_states_sdxl(
+            tokens1, tokens2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, unwrapped_text_encoder2
+        )
+        return [hidden_states1, hidden_states2, pool2]
+
+
+class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
+    SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
+
+    def __init__(
+        self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
+    ) -> None:
+        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
+
+    def get_outputs_npz_path(self, image_abs_path: str) -> str:
+        return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
+
+    def is_disk_cached_outputs_expected(self, abs_path: str):
+        if not self.cache_to_disk:
+            return False
+        if not os.path.exists(self.get_outputs_npz_path(abs_path)):
+            return False
+        if self.skip_disk_cache_validity_check:
+            return True
+
+        try:
+            npz = np.load(self.get_outputs_npz_path(abs_path))
+            if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
+                return False
+        except Exception as e:
+            logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}")
+            raise e
+
+        return True
+
+    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
+        data = np.load(npz_path)
+        hidden_state1 = data["hidden_state1"]
+        hidden_state2 = data["hidden_state2"]
+        pool2 = data["pool2"]
+        return [hidden_state1, hidden_state2, pool2]
+
+    def cache_batch_outputs(
+        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
+    ):
+        sdxl_text_encoding_strategy = text_encoding_strategy  # type: SdxlTextEncodingStrategy
+        captions = [info.caption for info in infos]
+
+        tokens1, tokens2 = tokenize_strategy.tokenize(captions)
+        with torch.no_grad():
+            hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, [tokens1, tokens2] + ) + if hidden_state1.dtype == torch.bfloat16: + hidden_state1 = hidden_state1.float() + if hidden_state2.dtype == torch.bfloat16: + hidden_state2 = hidden_state2.float() + if pool2.dtype == torch.bfloat16: + pool2 = pool2.float() + + hidden_state1 = hidden_state1.cpu().numpy() + hidden_state2 = hidden_state2.cpu().numpy() + pool2 = pool2.cpu().numpy() + + for i, info in enumerate(infos): + hidden_state1_i = hidden_state1[i] + hidden_state2_i = hidden_state2[i] + pool2_i = pool2[i] + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + hidden_state1=hidden_state1_i, + hidden_state2=hidden_state2_i, + pool2=pool2_i, + ) + else: + info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i] diff --git a/library/train_util.py b/library/train_util.py index 7af0070e..a747e047 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -12,6 +12,7 @@ import re import shutil import time from typing import ( + Any, Dict, List, NamedTuple, @@ -34,6 +35,7 @@ from tqdm import tqdm import torch from library.device_utils import init_ipex, clean_memory_on_device +from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy init_ipex() @@ -81,10 +83,6 @@ logger = logging.getLogger(__name__) # from library.hypernetwork import replace_attentions_for_hypernetwork from library.original_unet import UNet2DConditionModel -# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う -TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ - HIGH_VRAM = False # checkpointファイル名 @@ -148,18 +146,24 @@ class ImageInfo: self.image_size: Tuple[int, int] = None self.resized_size: Tuple[int, int] = None self.bucket_reso: Tuple[int, int] = None - self.latents: torch.Tensor = None - self.latents_flipped: torch.Tensor = None - self.latents_npz: str = None - self.latents_original_size: Tuple[int, int] = None # original image size, not latents size - self.latents_crop_ltrb: Tuple[int, int] = None # crop left top right bottom in original pixel size, not latents size - self.cond_img_path: str = None + self.latents: Optional[torch.Tensor] = None + self.latents_flipped: Optional[torch.Tensor] = None + self.latents_npz: Optional[str] = None # set in cache_latents + self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size + self.latents_crop_ltrb: Optional[Tuple[int, int]] = ( + None # crop left top right bottom in original pixel size, not latents size + ) + self.cond_img_path: Optional[str] = None self.image: Optional[Image.Image] = None # optional, original PIL Image - # SDXL, optional - self.text_encoder_outputs_npz: Optional[str] = None + self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs + + # new + self.text_encoder_outputs: Optional[List[torch.Tensor]] = None + # old self.text_encoder_outputs1: Optional[torch.Tensor] = None self.text_encoder_outputs2: Optional[torch.Tensor] = None self.text_encoder_pool2: Optional[torch.Tensor] = None + self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime @@ -359,47 +363,6 @@ class AugHelper: return self.color_aug if use_color_aug else None -class LatentsCachingStrategy: - _strategy = None # strategy instance: actual strategy class - - def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: - 
self._cache_to_disk = cache_to_disk - self._batch_size = batch_size - self.skip_disk_cache_validity_check = skip_disk_cache_validity_check - - @classmethod - def set_strategy(cls, strategy): - if cls._strategy is not None: - raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") - cls._strategy = strategy - - @classmethod - def get_strategy(cls) -> Optional["LatentsCachingStrategy"]: - return cls._strategy - - @property - def cache_to_disk(self): - return self._cache_to_disk - - @property - def batch_size(self): - return self._batch_size - - def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: - raise NotImplementedError - - def get_latents_npz_path(self, absolute_path: str, bucket_reso: Tuple[int, int]) -> str: - raise NotImplementedError - - def is_disk_cached_latents_expected( - self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool - ) -> bool: - raise NotImplementedError - - def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): - raise NotImplementedError - - class BaseSubset: def __init__( self, @@ -639,17 +602,12 @@ class ControlNetSubset(BaseSubset): class BaseDataset(torch.utils.data.Dataset): def __init__( self, - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]], - max_token_length: int, resolution: Optional[Tuple[int, int]], network_multiplier: float, debug_dataset: bool, ) -> None: super().__init__() - self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer] - - self.max_token_length = max_token_length # width/height is used when enable_bucket==False self.width, self.height = (None, None) if resolution is None else resolution self.network_multiplier = network_multiplier @@ -670,8 +628,6 @@ class BaseDataset(torch.utils.data.Dataset): self.bucket_no_upscale = None self.bucket_info = None # for metadata - self.tokenizer_max_length = self.tokenizers[0].model_max_length if max_token_length is None else max_token_length + 2 - self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ self.current_step: int = 0 @@ -690,6 +646,15 @@ class BaseDataset(torch.utils.data.Dataset): # caching self.caching_mode = None # None, 'latents', 'text' + + self.tokenize_strategy = None + self.text_encoder_output_caching_strategy = None + self.latents_caching_strategy = None + + def set_current_strategies(self): + self.tokenize_strategy = TokenizeStrategy.get_strategy() + self.text_encoder_output_caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy() + self.latents_caching_strategy = LatentsCachingStrategy.get_strategy() def set_seed(self, seed): self.seed = seed @@ -979,22 +944,6 @@ class BaseDataset(torch.utils.data.Dataset): for batch_index in range(batch_count): self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index)) - # ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す - #  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる - # - # # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは - # # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう - # # そのためバッチサイズを画像種類までに制限する - # # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない? 
- # # TO DO 正則化画像をepochまたがりで利用する仕組み - # num_of_image_types = len(set(bucket)) - # bucket_batch_size = min(self.batch_size, num_of_image_types) - # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) - # # logger.info(bucket_index, num_of_image_types, bucket_batch_size, batch_count) - # for batch_index in range(batch_count): - # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) - # ↑ここまで - self.shuffle_buckets() self._length = len(self.buckets_indices) @@ -1027,12 +976,13 @@ class BaseDataset(torch.utils.data.Dataset): ] ) - def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCachingStrategy): + def new_cache_latents(self, model: Any, is_main_process: bool): r""" a brand new method to cache latents. This method caches latents with caching strategy. normal cache_latents method is used by default, but this method is used when caching strategy is specified. """ logger.info("caching latents with caching strategy.") + caching_strategy = LatentsCachingStrategy.get_strategy() image_infos = list(self.image_data.values()) # sort by resolution @@ -1088,7 +1038,7 @@ class BaseDataset(torch.utils.data.Dataset): logger.info("caching latents...") for batch in tqdm(batches, smoothing=1, total=len(batches)): # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) - caching_strategy.cache_batch_latents(batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + caching_strategy.cache_batch_latents(model, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと @@ -1145,6 +1095,56 @@ class BaseDataset(torch.utils.data.Dataset): for batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + r""" + a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy. 
+ """ + tokenize_strategy = TokenizeStrategy.get_strategy() + text_encoding_strategy = TextEncodingStrategy.get_strategy() + caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy() + batch_size = caching_strategy.batch_size or self.batch_size + + # if cache to disk, don't cache TE outputs in non-main process + if caching_strategy.cache_to_disk and not is_main_process: + return + + logger.info("caching Text Encoder outputs with caching strategy.") + image_infos = list(self.image_data.values()) + + # split by resolution + batches = [] + batch = [] + logger.info("checking cache validity...") + for info in tqdm(image_infos): + te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) + + # check disk cache exists and size of latents + if caching_strategy.cache_to_disk: + info.text_encoder_outputs_npz = te_out_npz + cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) + if cache_available: # do not add to batch + continue + + batch.append(info) + + # if number of data in batch is enough, flush the batch + if len(batch) >= batch_size: + batches.append(batch) + batch = [] + + if len(batch) > 0: + batches.append(batch) + + if len(batches) == 0: + logger.info("no Text Encoder outputs to cache") + return + + # iterate batches + logger.info("caching Text Encoder outputs...") + for batch in tqdm(batches, smoothing=1, total=len(batches)): + # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + caching_strategy.cache_batch_outputs(tokenize_strategy, models, text_encoding_strategy, batch) + # if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype # this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset # to support SD1/2, it needs a flag for v2, but it is postponed @@ -1188,6 +1188,8 @@ class BaseDataset(torch.utils.data.Dataset): # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching text encoder outputs.") + tokenize_strategy = TokenizeStrategy.get_strategy() + if batch_size is None: batch_size = self.batch_size @@ -1229,7 +1231,7 @@ class BaseDataset(torch.utils.data.Dataset): input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) batch.append((info, input_ids1, input_ids2)) else: - l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption) + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(info.caption) batch.append((info, l_tokens, g_tokens, t5_tokens)) if len(batch) >= batch_size: @@ -1347,7 +1349,6 @@ class BaseDataset(torch.utils.data.Dataset): loss_weights = [] captions = [] input_ids_list = [] - input_ids2_list = [] latents_list = [] alpha_mask_list = [] images = [] @@ -1355,16 +1356,14 @@ class BaseDataset(torch.utils.data.Dataset): crop_top_lefts = [] target_sizes_hw = [] flippeds = [] # 変数名が微妙 - text_encoder_outputs1_list = [] - text_encoder_outputs2_list = [] - text_encoder_pool2_list = [] + text_encoder_outputs_list = [] for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] - loss_weights.append( - self.prior_loss_weight if image_info.is_reg else 1.0 - ) # in case of fine tuning, is_reg is always False + + # in case of fine tuning, is_reg is always False + loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance @@ -1381,7 +1380,9 @@ 
class BaseDataset(torch.utils.data.Dataset): image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz) + latents, original_size, crop_ltrb, flipped_latents, alpha_mask = ( + self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz) + ) if flipped: latents = flipped_latents alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem @@ -1470,75 +1471,67 @@ class BaseDataset(torch.utils.data.Dataset): # captionとtext encoder outputを処理する caption = image_info.caption # default - if image_info.text_encoder_outputs1 is not None: - text_encoder_outputs1_list.append(image_info.text_encoder_outputs1) - text_encoder_outputs2_list.append(image_info.text_encoder_outputs2) - text_encoder_pool2_list.append(image_info.text_encoder_pool2) - captions.append(caption) + + tokenization_required = ( + self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial + ) + text_encoder_outputs = None + input_ids = None + + if image_info.text_encoder_outputs is not None: + # cached + text_encoder_outputs = image_info.text_encoder_outputs elif image_info.text_encoder_outputs_npz is not None: - text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk( + # on disk + text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( image_info.text_encoder_outputs_npz ) - text_encoder_outputs1_list.append(text_encoder_outputs1) - text_encoder_outputs2_list.append(text_encoder_outputs2) - text_encoder_pool2_list.append(text_encoder_pool2) - captions.append(caption) else: + tokenization_required = True + text_encoder_outputs_list.append(text_encoder_outputs) + + if tokenization_required: caption = self.process_caption(subset, image_info.caption) - if self.XTI_layers: - caption_layer = [] - for layer in self.XTI_layers: - token_strings_from = " ".join(self.token_strings) - token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) - caption_ = caption.replace(token_strings_from, token_strings_to) - caption_layer.append(caption_) - captions.append(caption_layer) - else: - captions.append(caption) + input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension + # if self.XTI_layers: + # caption_layer = [] + # for layer in self.XTI_layers: + # token_strings_from = " ".join(self.token_strings) + # token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) + # caption_ = caption.replace(token_strings_from, token_strings_to) + # caption_layer.append(caption_) + # captions.append(caption_layer) + # else: + # captions.append(caption) - if not self.token_padding_disabled: # this option might be omitted in future - # TODO get_input_ids must support SD3 - if self.XTI_layers: - token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) - else: - token_caption = self.get_input_ids(caption, self.tokenizers[0]) - input_ids_list.append(token_caption) + # if not self.token_padding_disabled: # this option might be omitted in future + # # TODO get_input_ids must support SD3 + # if self.XTI_layers: + # token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) + # else: + # token_caption = self.get_input_ids(caption, self.tokenizers[0]) + # input_ids_list.append(token_caption) - if len(self.tokenizers) > 1: - if 
self.XTI_layers:
-                token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
-            else:
-                token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
-            input_ids2_list.append(token_caption2)
+            # if len(self.tokenizers) > 1:
+            #     if self.XTI_layers:
+            #         token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
+            #     else:
+            #         token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
+            #     input_ids2_list.append(token_caption2)
+
+            input_ids_list.append(input_ids)
+            captions.append(caption)
+
+        def none_or_stack_elements(tensors_list, converter):
+            # [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)]
+            if len(tensors_list) == 0 or tensors_list[0] is None or len(tensors_list[0]) == 0 or tensors_list[0][0] is None:
+                return None
+            return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))]
 
         example = {}
         example["loss_weights"] = torch.FloatTensor(loss_weights)
-
-        if len(text_encoder_outputs1_list) == 0:
-            if self.token_padding_disabled:
-                # padding=True means pad in the batch
-                example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids
-                if len(self.tokenizers) > 1:
-                    example["input_ids2"] = self.tokenizer[1](
-                        captions, padding=True, truncation=True, return_tensors="pt"
-                    ).input_ids
-                else:
-                    example["input_ids2"] = None
-            else:
-                example["input_ids"] = torch.stack(input_ids_list)
-                example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
-            example["text_encoder_outputs1_list"] = None
-            example["text_encoder_outputs2_list"] = None
-            example["text_encoder_pool2_list"] = None
-        else:
-            example["input_ids"] = None
-            example["input_ids2"] = None
-            # # for assertion
-            # example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions])
-            # example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions])
-            example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list)
-            example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
-            example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)
+        example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor)
+        example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x)
 
         # if one of alpha_masks is not None, we need to replace None with ones
         none_or_not = [x is None for x in alpha_mask_list]
@@ -1652,8 +1645,6 @@ class DreamBoothDataset(BaseDataset):
         self,
         subsets: Sequence[DreamBoothSubset],
         batch_size: int,
-        tokenizer,
-        max_token_length,
         resolution,
         network_multiplier: float,
         enable_bucket: bool,
@@ -1664,7 +1655,7 @@ class DreamBoothDataset(BaseDataset):
         prior_loss_weight: float,
         debug_dataset: bool,
     ) -> None:
-        super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
+        super().__init__(resolution, network_multiplier, debug_dataset)
 
         assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
 
@@ -1750,10 +1741,10 @@ class DreamBoothDataset(BaseDataset):
         # new caching: get image size from cache files
         strategy = LatentsCachingStrategy.get_strategy()
         if strategy is not None:
-            logger.info("get image size from cache files")
+            logger.info("get image size from cache file names")
            size_set_count = 0
             for i, img_path in enumerate(tqdm(img_paths)):
-                w, h = 
strategy.get_image_size_from_image_absolute_path(img_path) + w, h = strategy.get_image_size_from_disk_cache_path(img_path) if w is not None and h is not None: sizes[i] = [w, h] size_set_count += 1 @@ -1886,8 +1877,6 @@ class FineTuningDataset(BaseDataset): self, subsets: Sequence[FineTuningSubset], batch_size: int, - tokenizer, - max_token_length, resolution, network_multiplier: float, enable_bucket: bool, @@ -1897,7 +1886,7 @@ class FineTuningDataset(BaseDataset): bucket_no_upscale: bool, debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset) self.batch_size = batch_size @@ -2111,8 +2100,6 @@ class ControlNetDataset(BaseDataset): self, subsets: Sequence[ControlNetSubset], batch_size: int, - tokenizer, - max_token_length, resolution, network_multiplier: float, enable_bucket: bool, @@ -2122,7 +2109,7 @@ class ControlNetDataset(BaseDataset): bucket_no_upscale: bool, debug_dataset: float, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset) db_subsets = [] for subset in subsets: @@ -2160,8 +2147,6 @@ class ControlNetDataset(BaseDataset): self.dreambooth_dataset_delegate = DreamBoothDataset( db_subsets, batch_size, - tokenizer, - max_token_length, resolution, network_multiplier, enable_bucket, @@ -2221,6 +2206,9 @@ class ControlNetDataset(BaseDataset): self.conditioning_image_transforms = IMAGE_TRANSFORMS + def set_current_strategies(self): + return self.dreambooth_dataset_delegate.set_current_strategies() + def make_buckets(self): self.dreambooth_dataset_delegate.make_buckets() self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager @@ -2229,6 +2217,12 @@ class ControlNetDataset(BaseDataset): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + def new_cache_latents(self, model: Any, is_main_process: bool): + return self.dreambooth_dataset_delegate.new_cache_latents(model, is_main_process) + + def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process) + def __len__(self): return self.dreambooth_dataset_delegate.__len__() @@ -2314,6 +2308,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset): # for dataset in self.datasets: # dataset.make_buckets() + def set_text_encoder_output_caching_strategy(self, strategy: TextEncoderOutputsCachingStrategy): + """ + DataLoader is run in multiple processes, so we need to set the strategy manually. 
+ """ + for dataset in self.datasets: + dataset.set_text_encoder_output_caching_strategy(strategy) + def enable_XTI(self, *args, **kwargs): for dataset in self.datasets: dataset.enable_XTI(*args, **kwargs) @@ -2323,10 +2324,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset): logger.info(f"[Dataset {i}]") dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) - def new_cache_latents(self, is_main_process: bool, strategy: LatentsCachingStrategy): + def new_cache_latents(self, model: Any, is_main_process: bool): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.new_cache_latents(is_main_process, strategy) + dataset.new_cache_latents(model, is_main_process) def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True @@ -2344,6 +2345,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset): tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size ) + def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.new_cache_text_encoder_outputs(models, is_main_process) + def set_caching_mode(self, caching_mode): for dataset in self.datasets: dataset.set_caching_mode(caching_mode) @@ -2358,6 +2364,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset): def is_text_encoder_output_cacheable(self) -> bool: return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets]) + def set_current_strategies(self): + for dataset in self.datasets: + dataset.set_current_strategies() + def set_current_epoch(self, epoch): for dataset in self.datasets: dataset.set_current_epoch(epoch) @@ -2411,34 +2421,34 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) # TODO update to use CachingStrategy -def load_latents_from_disk( - npz_path, -) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: - npz = np.load(npz_path) - if "latents" not in npz: - raise ValueError(f"error: npz is old format. please re-generate {npz_path}") +# def load_latents_from_disk( +# npz_path, +# ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: +# npz = np.load(npz_path) +# if "latents" not in npz: +# raise ValueError(f"error: npz is old format. 
please re-generate {npz_path}") - latents = npz["latents"] - original_size = npz["original_size"].tolist() - crop_ltrb = npz["crop_ltrb"].tolist() - flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None - alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None - return latents, original_size, crop_ltrb, flipped_latents, alpha_mask +# latents = npz["latents"] +# original_size = npz["original_size"].tolist() +# crop_ltrb = npz["crop_ltrb"].tolist() +# flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None +# alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None +# return latents, original_size, crop_ltrb, flipped_latents, alpha_mask -def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None): - kwargs = {} - if flipped_latents_tensor is not None: - kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() - if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() - np.savez( - npz_path, - latents=latents_tensor.float().cpu().numpy(), - original_size=np.array(original_size), - crop_ltrb=np.array(crop_ltrb), - **kwargs, - ) +# def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None): +# kwargs = {} +# if flipped_latents_tensor is not None: +# kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() +# if alpha_mask is not None: +# kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() +# np.savez( +# npz_path, +# latents=latents_tensor.float().cpu().numpy(), +# original_size=np.array(original_size), +# crop_ltrb=np.array(crop_ltrb), +# **kwargs, +# ) def debug_dataset(train_dataset, show_input_ids=False): @@ -2465,12 +2475,12 @@ def debug_dataset(train_dataset, show_input_ids=False): example = train_dataset[idx] if example["latents"] is not None: logger.info(f"sample has latents from npz file: {example['latents'].size()}") - for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate( + for j, (ik, cap, lw, orgsz, crptl, trgsz, flpdz) in enumerate( zip( example["image_keys"], example["captions"], example["loss_weights"], - example["input_ids"], + # example["input_ids"], example["original_sizes_hw"], example["crop_top_lefts"], example["target_sizes_hw"], @@ -2483,10 +2493,10 @@ def debug_dataset(train_dataset, show_input_ids=False): if "network_multipliers" in example: print(f"network multiplier: {example['network_multipliers'][j]}") - if show_input_ids: - logger.info(f"input ids: {iid}") - if "input_ids2" in example: - logger.info(f"input ids2: {example['input_ids2'][j]}") + # if show_input_ids: + # logger.info(f"input ids: {iid}") + # if "input_ids2" in example: + # logger.info(f"input ids2: {example['input_ids2'][j]}") if example["images"] is not None: im = example["images"][j] logger.info(f"image size: {im.size()}") @@ -2555,8 +2565,8 @@ def glob_images_pathlib(dir_path, recursive): class MinimalDataset(BaseDataset): - def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False): - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + def __init__(self, resolution, network_multiplier, debug_dataset=False): + super().__init__(resolution, network_multiplier, debug_dataset) self.num_train_images = 0 # update in subclass self.num_reg_images = 0 # update in subclass @@ -2773,14 +2783,15 @@ def cache_batch_latents( raise RuntimeError(f"NaN detected in 
latents: {info.absolute_path}") if cache_to_disk: - save_latents_to_disk( - info.latents_npz, - latent, - info.latents_original_size, - info.latents_crop_ltrb, - flipped_latent, - alpha_mask, - ) + # save_latents_to_disk( + # info.latents_npz, + # latent, + # info.latents_original_size, + # info.latents_crop_ltrb, + # flipped_latent, + # alpha_mask, + # ) + pass else: info.latents = latent if flip_aug: @@ -4662,33 +4673,6 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): ) -def load_tokenizer(args: argparse.Namespace): - logger.info("prepare tokenizer") - original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH - - tokenizer: CLIPTokenizer = None - if args.tokenizer_cache_dir: - local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) - if os.path.exists(local_tokenizer_path): - logger.info(f"load tokenizer from cache: {local_tokenizer_path}") - tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2 - - if tokenizer is None: - if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer") - else: - tokenizer = CLIPTokenizer.from_pretrained(original_path) - - if hasattr(args, "max_token_length") and args.max_token_length is not None: - logger.info(f"update token length: {args.max_token_length}") - - if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") - tokenizer.save_pretrained(local_tokenizer_path) - - return tokenizer - - def prepare_accelerator(args: argparse.Namespace): """ this function also prepares deepspeed plugin @@ -5550,6 +5534,7 @@ def sample_images_common( ): """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した + TODO Use strategies here """ if steps == 0: diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index ffa0d46d..e9e61af1 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -24,7 +24,7 @@ import logging logger = logging.getLogger(__name__) -from library import sd3_models, sd3_utils +from library import sd3_models, sd3_utils, strategy_sd3 def get_noise(seed, latent): @@ -145,6 +145,7 @@ if __name__ == "__main__": parser.add_argument("--clip_g", type=str, required=False) parser.add_argument("--clip_l", type=str, required=False) parser.add_argument("--t5xxl", type=str, required=False) + parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77") parser.add_argument("--prompt", type=str, default="A photo of a cat") # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders parser.add_argument("--negative_prompt", type=str, default="") @@ -247,7 +248,7 @@ if __name__ == "__main__": # load tokenizers logger.info("Loading tokenizers...") - tokenizer = sd3_models.SD3Tokenizer(use_t5xxl) # combined tokenizer + tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length) # load models # logger.info("Create MMDiT from SD3 checkpoint...") @@ -320,12 +321,19 @@ if __name__ == "__main__": # prepare embeddings logger.info("Encoding prompts...") - # embeds, pooled_embed - lg_out, t5_out, pooled = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl) - cond = torch.cat([lg_out, t5_out], dim=-2), pooled + encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() - lg_out, t5_out, pooled = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl) - 
neg_cond = torch.cat([lg_out, t5_out], dim=-2), pooled + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.prompt) + lg_out, t5_out, pooled = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens] + ) + cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.negative_prompt) + lg_out, t5_out, pooled = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens] + ) + neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) # generate image logger.info("Generating image...") diff --git a/sd3_train.py b/sd3_train.py index f34e4712..617e3027 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -17,7 +17,7 @@ init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils +from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3 from library.sdxl_train_util import match_mixed_precision # , sdxl_model_util @@ -69,10 +69,22 @@ def train(args): # not args.train_text_encoder # ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" - # training without text encoder cache is not supported - assert ( - args.cache_text_encoder_outputs - ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません" + # # training without text encoder cache is not supported: because T5XXL must be cached + # assert ( + # args.cache_text_encoder_outputs + # ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません" + + assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), ( + "when training text encoder, text encoder outputs must not be cached (except for T5XXL)" + + " / text encoderの学習時はtext encoderの出力はキャッシュできません(t5xxlのみキャッシュすることは可能です)" + ) + + if args.use_t5xxl_cache_only and not args.cache_text_encoder_outputs: + logger.warning( + "use_t5xxl_cache_only is enabled, so cache_text_encoder_outputs is automatically enabled." + + " / use_t5xxl_cache_onlyが有効なため、cache_text_encoder_outputsも自動的に有効になります" + ) + args.cache_text_encoder_outputs = True # if args.block_lr: # block_lrs = [float(lr) for lr in args.block_lr.split(",")] @@ -88,17 +100,17 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - # load tokenizer - sd3_tokenizer = sd3_models.SD3Tokenizer() - - # prepare caching strategy - if args.new_caching: - latents_caching_strategy = sd3_train_utils.Sd3LatentsCachingStrategy( + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. 
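To make the constraint in the comment above concrete: the strategy must be registered before the dataset group is generated, because dataset initialization can query it (for example, to read image sizes back from cache file names, as the `DreamBoothDataset` hunk earlier shows). A condensed sketch of the required order, with names taken from this diff:

```python
# Sketch of the setup order in sd3_train.py; the strategy registration must
# come first, because dataset construction may call
# LatentsCachingStrategy.get_strategy() internally.
if args.cache_latents:
    strategy_base.LatentsCachingStrategy.set_strategy(
        strategy_sd3.Sd3LatentsCachingStrategy(
            args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
        )
    )

# only now is it safe to build the dataset group
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
```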
+ if args.cache_latents: + latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check ) - else: - latents_caching_strategy = None - train_util.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # load tokenizer and prepare tokenize strategy + sd3_tokenizer = sd3_models.SD3Tokenizer(t5xxl_max_length=args.t5xxl_max_token_length) + sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy) # データセットを準備する if args.dataset_class is None: @@ -153,6 +165,16 @@ def train(args): train_dataset_group.verify_bucket_reso_steps(8) # TODO これでいいか確認 if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + False, + False, + ) + ) + train_dataset_group.set_current_strategies() train_util.debug_dataset(train_dataset_group, True) return if len(train_dataset_group) == 0: @@ -215,19 +237,8 @@ def train(args): vae.requires_grad_(False) vae.eval() - if not args.new_caching: - vae_wrapper = sd3_models.VAEWrapper(vae) # make SD/SDXL compatible - with torch.no_grad(): - train_dataset_group.cache_latents( - vae_wrapper, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - file_suffix="_sd3.npz", - ) - else: - latents_caching_strategy.set_vae(vae) - train_dataset_group.new_cache_latents(accelerator.is_main_process, latents_caching_strategy) + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") # if no sampling, vae can be deleted clean_memory_on_device(accelerator.device) @@ -246,60 +257,70 @@ def train(args): t5xxl = sd3_train_utils.load_target_model("t5xxl", args, sd3_state_dict, accelerator, attn_mode, t5xxl_dtype, device_to_load) # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) + # should be deleted after caching text encoder outputs when not training text encoder + # this strategy should not be used other than this process + text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # 学習を準備する:モデルを適切な状態にする train_clip_l = False train_clip_g = False train_t5xxl = False - # if args.train_text_encoder: - # # TODO each option for two text encoders? 
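The same `Sd3TextEncodingStrategy` drives both sides of the refactor: `sd3_minimal_inference.py` (above) encodes full prompts, and the training loop (below) encodes cached token ids. A condensed sketch of the three-step API, assuming the text encoder models are already loaded:

```python
# Sketch of the SD3 prompt-encoding flow; clip_l, clip_g and t5xxl are the
# loaded text encoder models (loading elided here), names follow this diff.
from library import strategy_sd3

tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(256)  # t5xxl max token length
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()

l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize("A photo of a cat")
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
    tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens]
)
context, pooled = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
```

Passing `None` entries in the model and token lists skips the corresponding encoder, which is how the training loop below encodes CLIP-L/G and T5-XXL separately.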
- # accelerator.print("enable text encoder training") - # if args.gradient_checkpointing: - # text_encoder1.gradient_checkpointing_enable() - # text_encoder2.gradient_checkpointing_enable() - # lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train - # lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train - # train_clip_l = lr_te1 != 0 - # train_clip_g = lr_te2 != 0 + if args.train_text_encoder: + accelerator.print("enable text encoder training") + if args.gradient_checkpointing: + clip_l.gradient_checkpointing_enable() + clip_g.gradient_checkpointing_enable() + lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train + lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train + train_clip_l = lr_te1 != 0 + train_clip_g = lr_te2 != 0 + + if not train_clip_l: + clip_l.to(weight_dtype) + if not train_clip_g: + clip_g.to(weight_dtype) + clip_l.requires_grad_(train_clip_l) + clip_g.requires_grad_(train_clip_g) + clip_l.train(train_clip_l) + clip_g.train(train_clip_g) + else: + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + clip_l.requires_grad_(False) + clip_g.requires_grad_(False) + clip_l.eval() + clip_g.eval() - # # caching one text encoder output is not supported - # if not train_clip_l: - # text_encoder1.to(weight_dtype) - # if not train_clip_g: - # text_encoder2.to(weight_dtype) - # text_encoder1.requires_grad_(train_clip_l) - # text_encoder2.requires_grad_(train_clip_g) - # text_encoder1.train(train_clip_l) - # text_encoder2.train(train_clip_g) - # else: - clip_l.to(weight_dtype) - clip_g.to(weight_dtype) - clip_l.requires_grad_(False) - clip_g.requires_grad_(False) - clip_l.eval() - clip_g.eval() if t5xxl is not None: t5xxl.to(t5xxl_dtype) t5xxl.requires_grad_(False) t5xxl.eval() - # TextEncoderの出力をキャッシュする + # cache text encoder outputs if args.cache_text_encoder_outputs: - # Text Encodes are eval and no grad + # Text Encodes are eval and no grad here + clip_l.to(accelerator.device) + clip_g.to(accelerator.device) + if t5xxl is not None: + t5xxl.to(t5xxl_device) - with torch.no_grad(), accelerator.autocast(): - train_dataset_group.cache_text_encoder_outputs_sd3( - sd3_tokenizer, - (clip_l, clip_g, t5xxl), - (accelerator.device, accelerator.device, t5xxl_device), - None, - (None, None, None), - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - args.text_encoder_batch_size, - ) + text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + False, + train_clip_g or train_clip_l or args.use_t5xxl_cache_only, + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) - # TODO we can delete text encoders after caching + clip_l.to(accelerator.device, dtype=weight_dtype) + clip_g.to(accelerator.device, dtype=weight_dtype) + if t5xxl is not None: + t5xxl.to(t5xxl_device, dtype=t5xxl_dtype) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process) accelerator.wait_for_everyone() # load MMDIT @@ -332,11 +353,11 @@ def train(args): # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs)) # if train_clip_l: - # training_models.append(text_encoder1) - # params_to_optimize.append({"params": 
list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) + # training_models.append(clip_l) + # params_to_optimize.append({"params": list(clip_l.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) # if train_clip_g: - # training_models.append(text_encoder2) - # params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) + # training_models.append(clip_g) + # params_to_optimize.append({"params": list(clip_g.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) # calculate number of trainable parameters n_params = 0 @@ -344,7 +365,7 @@ def train(args): for p in group["params"]: n_params += p.numel() - accelerator.print(f"train mmdit: {train_mmdit}") # , text_encoder1: {train_clip_l}, text_encoder2: {train_clip_g}") + accelerator.print(f"train mmdit: {train_mmdit}") # , clip_l: {train_clip_l}, clip_g: {train_clip_g}") accelerator.print(f"number of models: {len(training_models)}") accelerator.print(f"number of trainable parameters: {n_params}") @@ -398,7 +419,11 @@ def train(args): else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -455,8 +480,8 @@ def train(args): # TODO check if this is necessary. SD3 uses pool for clip_l and clip_g # # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer # if train_clip_l: - # text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) - # text_encoder1.text_model.final_layer_norm.requires_grad_(False) + # clip_l.text_model.encoder.layers[-1].requires_grad_(False) + # clip_l.text_model.final_layer_norm.requires_grad_(False) # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する if args.cache_text_encoder_outputs: @@ -484,9 +509,8 @@ def train(args): ds_model = deepspeed_utils.prepare_deepspeed_model( args, mmdit=mmdit, - # mmdie=mmdit if train_mmdit else None, - # text_encoder1=text_encoder1 if train_clip_l else None, - # text_encoder2=text_encoder2 if train_clip_g else None, + clip_l=clip_l if train_clip_l else None, + clip_g=clip_g if train_clip_g else None, ) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. 
# pull/1139#issuecomment-1986790007 ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -498,10 +522,10 @@ def train(args): # acceleratorがなんかよろしくやってくれるらしい if train_mmdit: mmdit = accelerator.prepare(mmdit) - # if train_clip_l: - # text_encoder1 = accelerator.prepare(text_encoder1) - # if train_clip_g: - # text_encoder2 = accelerator.prepare(text_encoder2) + if train_clip_l: + clip_l = accelerator.prepare(clip_l) + if train_clip_g: + clip_g = accelerator.prepare(clip_g) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする @@ -613,7 +637,7 @@ def train(args): # # For --sample_at_first # sd3_train_utils.sample_images( - # accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], mmdit + # accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [clip_l, clip_g], mmdit # ) # following function will be moved to sd3_train_utils @@ -666,6 +690,7 @@ def train(args): return weighting loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 for epoch in range(num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -687,37 +712,45 @@ def train(args): # encode images to latents. images are [-1, 1] latents = vae.encode(batch["images"].to(vae_dtype)).to(weight_dtype) - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.nan_to_num(latents, 0, out=latents) + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + # latents = latents * sdxl_model_util.VAE_SCALE_FACTOR latents = sd3_models.SDVAE.process_in(latents) - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - # not cached, get text encoder outputs - # XXX This does not work yet - input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl = batch["input_ids"] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + lg_out, t5_out, lg_pooled = text_encoder_outputs_list + if args.use_t5xxl_cache_only: + lg_out = None + lg_pooled = None + else: + lg_out = None + t5_out = None + lg_pooled = None + + if lg_out is None or (train_clip_l or train_clip_g): + # not cached or training, so get from text encoders + input_ids_clip_l, input_ids_clip_g, _ = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # TODO support weighted captions - # TODO support length > 75 input_ids_clip_l = input_ids_clip_l.to(accelerator.device) input_ids_clip_g = input_ids_clip_g.to(accelerator.device) - input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) - - # get text encoder outputs: outputs are concatenated - context, pool = sd3_utils.get_cond_from_tokens( - input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl, clip_l, clip_g, t5xxl + lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, [clip_l, clip_g, None], [input_ids_clip_l, input_ids_clip_g, None] ) - else: - # encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - # encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - # 
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) - # TODO this reuses SDXL keys, it should be fixed - lg_out = batch["text_encoder_outputs1_list"] - t5_out = batch["text_encoder_outputs2_list"] - pool = batch["text_encoder_pool2_list"] - context = torch.cat([lg_out, t5_out], dim=-2) + + if t5_out is None: + _, _, input_ids_t5xxl = batch["input_ids_list"] + with torch.no_grad(): + input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None + _, t5_out, _ = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl] + ) + + context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled) # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps @@ -748,13 +781,13 @@ def train(args): if torch.any(torch.isnan(context)): accelerator.print("NaN found in context, replacing with zeros") context = torch.nan_to_num(context, 0, out=context) - if torch.any(torch.isnan(pool)): + if torch.any(torch.isnan(lg_pooled)): accelerator.print("NaN found in pool, replacing with zeros") - pool = torch.nan_to_num(pool, 0, out=pool) + lg_pooled = torch.nan_to_num(lg_pooled, 0, out=lg_pooled) # call model with accelerator.autocast(): - model_pred = mmdit(noisy_model_input, timesteps, context=context, y=pool) + model_pred = mmdit(noisy_model_input, timesteps, context=context, y=lg_pooled) # Follow: Section 5 of https://arxiv.org/abs/2206.00364. # Preconditioning of the model outputs. @@ -806,7 +839,7 @@ def train(args): # accelerator.device, # vae, # [tokenizer1, tokenizer2], - # [text_encoder1, text_encoder2], + # [clip_l, clip_g], # mmdit, # ) @@ -875,7 +908,7 @@ def train(args): # accelerator.device, # vae, # [tokenizer1, tokenizer2], - # [text_encoder1, text_encoder2], + # [clip_l, clip_g], # mmdit, # ) @@ -924,7 +957,19 @@ def setup_parser() -> argparse.ArgumentParser: custom_train_functions.add_custom_train_arguments(parser) sd3_train_utils.add_sd3_training_arguments(parser) - # parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + parser.add_argument( + "--train_text_encoder", action="store_true", help="train text encoder (CLIP-L and G) / text encoderも学習する" + ) + # parser.add_argument("--train_t5xxl", action="store_true", help="train T5-XXL / T5-XXLも学習する") + parser.add_argument( + "--use_t5xxl_cache_only", action="store_true", help="cache T5-XXL outputs only / T5-XXLの出力のみキャッシュする" + ) + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=None, + help="maximum token length for T5-XXL. 
256 if omitted / T5-XXLの最大トークン数。省略時は256", + ) # TE training is disabled temporarily # parser.add_argument( @@ -962,7 +1007,6 @@ def setup_parser() -> argparse.ArgumentParser: help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", ) - parser.add_argument("--new_caching", action="store_true", help="use new caching method / 新しいキャッシング方法を使う") parser.add_argument( "--skip_latents_validity_check", action="store_true", diff --git a/sdxl_train.py b/sdxl_train.py index ae92d6a3..b6d4afd6 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -17,7 +17,7 @@ init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import deepspeed_utils, sdxl_model_util +from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl import library.train_util as train_util @@ -124,7 +124,16 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizers = [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] # will be removed in the future + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -166,10 +175,10 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2]) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -262,8 +271,9 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -276,6 +286,9 @@ def train(args): train_text_encoder1 = False train_text_encoder2 = False + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + if args.train_text_encoder: # TODO each option for two text encoders? 
accelerator.print("enable text encoder training") @@ -307,16 +320,17 @@ def train(args): # TextEncoderの出力をキャッシュする if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad - with torch.no_grad(), accelerator.autocast(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) - accelerator.wait_for_everyone() + text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) + + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + + accelerator.wait_for_everyone() if not cache_latents: vae.requires_grad_(False) @@ -403,7 +417,11 @@ def train(args): else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -597,7 +615,7 @@ def train(args): # For --sample_at_first sdxl_train_util.sample_images( - accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet + accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, [text_encoder1, text_encoder2], unet ) loss_recorder = train_util.LossRecorder() @@ -628,9 +646,15 @@ def train(args): latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + # Text Encoder outputs are cached + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list + encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype) + pool2 = pool2.to(accelerator.device, dtype=weight_dtype) + else: + input_ids1, input_ids2 = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning # TODO support weighted captions @@ -646,39 +670,13 @@ def train(args): # else: input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) - # unwrap_model is fine for models not wrapped by accelerator - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, - accelerator=accelerator, + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + 
tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] ) - else: - encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) - - # # verify that the text encoder outputs are correct - # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl( - # args.max_token_length, - # batch["input_ids"].to(text_encoder1.device), - # batch["input_ids2"].to(text_encoder1.device), - # tokenizer1, - # tokenizer2, - # text_encoder1, - # text_encoder2, - # None if not args.full_fp16 else weight_dtype, - # ) - # b_size = encoder_hidden_states1.shape[0] - # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # logger.info("text encoder outputs verified") + if args.full_fp16: + encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) + pool2 = pool2.to(weight_dtype) # get size embeddings orig_size = batch["original_sizes_hw"] @@ -765,7 +763,7 @@ def train(args): global_step, accelerator.device, vae, - [tokenizer1, tokenizer2], + tokenizers, [text_encoder1, text_encoder2], unet, ) @@ -847,7 +845,7 @@ def train(args): global_step, accelerator.device, vae, - [tokenizer1, tokenizer2], + tokenizers, [text_encoder1, text_encoder2], unet, ) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 5ff060a9..0eaec29b 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -23,7 +23,16 @@ from accelerate.utils import set_seed import accelerate from diffusers import DDPMScheduler, ControlNetModel from safetensors.torch import load_file -from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util +from library import ( + deepspeed_utils, + sai_model_spec, + sdxl_model_util, + sdxl_original_unet, + sdxl_train_util, + strategy_base, + strategy_sd, + strategy_sdxl, +) import library.model_util as model_util import library.train_util as train_util @@ -79,7 +88,14 @@ def train(args): args.seed = random.randint(0, 2**32) set_seed(args.seed) - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. 
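A note on the call that follows (and its twin occurrences in `sdxl_train.py` and the network/TI trainers): the positional booleans in `SdSdxlLatentsCachingStrategy(False, args.cache_latents_to_disk, args.vae_batch_size, False)` are easy to misread. Judging from the base-class `__init__` removed from `train_util.py` earlier in this diff, the last three arguments map onto `cache_to_disk`, `batch_size` and `skip_disk_cache_validity_check`; the meaning of the leading `False` is an assumption:

```python
# Assumed reading of the positional arguments; the first flag is not named in
# this diff and is presumed to distinguish SD (True) from SDXL (False) caches.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
    False,                       # sd: SD1/2-style cache naming if True (assumption)
    args.cache_latents_to_disk,  # cache_to_disk
    args.vae_batch_size,         # batch_size
    False,                       # skip_disk_cache_validity_check
)
```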
+ latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) @@ -106,7 +122,7 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) @@ -164,30 +180,30 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # TextEncoderの出力をキャッシュする if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad - with torch.no_grad(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) + text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) + + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + accelerator.wait_for_everyone() # prepare ControlNet-LLLite @@ -242,7 +258,11 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. 
Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -290,7 +310,7 @@ def train(args): unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) if isinstance(unet, DDP): - unet._set_static_graph() # avoid error for multiple use of the parameter + unet._set_static_graph() # avoid error for multiple use of the parameter if args.gradient_checkpointing: unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる @@ -357,7 +377,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -409,27 +431,26 @@ def train(args): latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + # Text Encoder outputs are cached + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list + encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype) + pool2 = pool2.to(accelerator.device, dtype=weight_dtype) + else: + input_ids1, input_ids2 = batch["input_ids_list"] with torch.no_grad(): - # Get the text embedding for conditioning input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] ) - else: - encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + if args.full_fp16: + encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) + pool2 = pool2.to(weight_dtype) # get size embeddings orig_size = batch["original_sizes_hw"] diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 83969bb1..67ccae62 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,16 +1,21 @@ import argparse import torch +from accelerate import Accelerator from library.device_utils import init_ipex, clean_memory_on_device + 
init_ipex()
-from library import sdxl_model_util, sdxl_train_util, train_util
+from library import sdxl_model_util, sdxl_train_util, strategy_base, strategy_sd, strategy_sdxl, train_util
import train_network
from library.utils import setup_logging
+
setup_logging()
import logging
+
logger = logging.getLogger(__name__)
+
class SdxlNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
@@ -49,15 +54,32 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
- def load_tokenizer(self, args):
- tokenizer = sdxl_train_util.load_tokenizers(args)
- return tokenizer
+ def get_tokenize_strategy(self, args):
+ return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
- def is_text_encoder_outputs_cached(self, args):
- return args.cache_text_encoder_outputs
+ def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy):
+ return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2]
+
+ def get_latents_caching_strategy(self, args):
+ latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
+ False, args.cache_latents_to_disk, args.vae_batch_size, False
+ )
+ return latents_caching_strategy
+
+ def get_text_encoding_strategy(self, args):
+ return strategy_sdxl.SdxlTextEncodingStrategy()
+
+ def get_models_for_text_encoding(self, args, accelerator, text_encoders):
+ return text_encoders + [accelerator.unwrap_model(text_encoders[-1])]
+
+ def get_text_encoder_outputs_caching_strategy(self, args):
+ if args.cache_text_encoder_outputs:
+ return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False)
+ else:
+ return None
def cache_text_encoder_outputs_if_needed(
- self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
+ self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
):
if args.cache_text_encoder_outputs:
if not args.lowram:
@@ -70,15 +92,13 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
clean_memory_on_device(accelerator.device)
# When TE is not trained, it is not prepared, so we need to use explicit autocast
+ text_encoders[0].to(accelerator.device, dtype=weight_dtype)
+ text_encoders[1].to(accelerator.device, dtype=weight_dtype)
with accelerator.autocast():
- dataset.cache_text_encoder_outputs(
- tokenizers,
- text_encoders,
- accelerator.device,
- weight_dtype,
- args.cache_text_encoder_outputs_to_disk,
- accelerator.is_main_process,
+ dataset.new_cache_text_encoder_outputs(
+ text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator.is_main_process
)
+ accelerator.wait_for_everyone()
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
text_encoders[1].to("cpu", dtype=torch.float32)
diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py
index 5df739e2..cbfcef55 100644
--- a/sdxl_train_textual_inversion.py
+++ b/sdxl_train_textual_inversion.py
@@ -5,10 +5,10 @@ import regex
import torch
from library.device_utils import init_ipex
+
init_ipex()
-from library import sdxl_model_util, sdxl_train_util, train_util
-
+from library import sdxl_model_util, sdxl_train_util, strategy_sd, strategy_sdxl, train_util
import train_textual_inversion
@@ -41,28 +41,20 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
return
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
- def load_tokenizer(self, args):
- tokenizer = sdxl_train_util.load_tokenizers(args)
- return tokenizer
+ def get_tokenize_strategy(self, args):
+ return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
- def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
- input_ids1 = batch["input_ids"]
- input_ids2 = batch["input_ids2"]
- with torch.enable_grad():
- input_ids1 = input_ids1.to(accelerator.device)
- input_ids2 = input_ids2.to(accelerator.device)
- encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
- args.max_token_length,
- input_ids1,
- input_ids2,
- tokenizers[0],
- tokenizers[1],
- text_encoders[0],
- text_encoders[1],
- None if not args.full_fp16 else weight_dtype,
- accelerator=accelerator,
- )
- return encoder_hidden_states1, encoder_hidden_states2, pool2
+ def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy):
+ return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2]
+
+ def get_latents_caching_strategy(self, args):
+ latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
+ False, args.cache_latents_to_disk, args.vae_batch_size, False
+ )
+ return latents_caching_strategy
+
+ def get_text_encoding_strategy(self, args):
+ return strategy_sdxl.SdxlTextEncodingStrategy()
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -81,9 +73,11 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
return noise_pred
- def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
+ def sample_images(
+ self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
+ ):
sdxl_train_util.sample_images(
- accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
+ accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
)
def save_weights(self, file, updated_embs, save_dtype, metadata):
@@ -122,8 +116,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
def setup_parser() -> argparse.ArgumentParser:
parser = train_textual_inversion.setup_parser()
- # don't add sdxl_train_util.add_sdxl_training_arguments(parser) because it only adds text encoder caching
- # sdxl_train_util.add_sdxl_training_arguments(parser)
+ sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False)
return parser
diff --git a/train_db.py b/train_db.py
index 39d8ea6e..7caee664 100644
--- a/train_db.py
+++ b/train_db.py
@@ -11,7 +11,7 @@ import toml
from tqdm import tqdm
import torch
-from library import deepspeed_utils
+from library import deepspeed_utils, strategy_base
from library.device_utils import init_ipex, clean_memory_on_device
@@ -38,6 +38,7 @@ from library.custom_train_functions import (
apply_masked_loss,
)
from library.utils import setup_logging, add_logging_arguments
+import library.strategy_sd as strategy_sd
setup_logging()
import logging
@@ -58,7 +59,14 @@ def train(args):
if args.seed is not None:
set_seed(args.seed) # initialize the random number sequence
- tokenizer =
train_util.load_tokenizer(args)
+ tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
+ strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
+
+ # prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy during initialization
+ latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
+ False, args.cache_latents_to_disk, args.vae_batch_size, False
+ )
+ strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# prepare the dataset
if args.dataset_class is None:
@@ -80,10 +88,10 @@ def train(args):
]
}
- blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+ blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
- train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
+ train_dataset_group = train_util.load_arbitrary_dataset(args)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -145,13 +153,17 @@ def train(args):
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
- with torch.no_grad():
- train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
+
+ train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
+
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
+ text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
+ strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
+
# prepare for training: put the model into the appropriate state
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
unet.requires_grad_(True) # added just in case
@@ -184,8 +196,11 @@ def train(args):
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
- # prepare the dataloader
- # number of DataLoader processes: note that 0 cannot be used with persistent_workers
+ # prepare dataloader
+ # strategies are set here because they cannot be referenced in another process.
They are copied with the dataset instead.
+ # some strategies can be None
+ train_dataset_group.set_current_strategies()
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
@@ -290,10 +305,16 @@ def train(args):
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
- accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
+ accelerator.init_trackers(
+ "dreambooth" if args.log_tracker_name is None else args.log_tracker_name,
+ config=train_util.get_sanitized_config_or_none(args),
+ init_kwargs=init_kwargs,
+ )
# For --sample_at_first
- train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+ train_util.sample_images(
+ accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
+ )
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -331,7 +352,7 @@ def train(args):
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
if args.weighted_captions:
encoder_hidden_states = get_weighted_text_embeddings(
- tokenizer,
+ tokenize_strategy.tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
@@ -339,14 +360,18 @@ def train(args):
clip_skip=args.clip_skip,
)
else:
- input_ids = batch["input_ids"].to(accelerator.device)
- encoder_hidden_states = train_util.get_hidden_states(
- args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
- )
+ input_ids = batch["input_ids_list"][0].to(accelerator.device)
+ encoder_hidden_states = text_encoding_strategy.encode_tokens(
+ tokenize_strategy, [text_encoder], [input_ids]
+ )[0]
+ if args.full_fp16:
+ encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
- noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+ args, noise_scheduler, latents
+ )
# Predict the noise residual
with accelerator.autocast():
@@ -358,7 +383,9 @@ def train(args):
else:
target = noise
- loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+ loss = train_util.conditional_loss(
+ noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+ )
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
@@ -393,7 +420,7 @@ def train(args):
global_step += 1
train_util.sample_images(
- accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
+ accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
# save the model at the specified step interval
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -457,7 +484,9 @@ def train(args):
vae,
)
- train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+ train_util.sample_images(
+ accelerator, args, epoch + 1, global_step, accelerator.device,
vae, tokenize_strategy.tokenizer, text_encoder, unet + ) is_main_process = accelerator.is_main_process if is_main_process: diff --git a/train_network.py b/train_network.py index 7ba07385..3828fed1 100644 --- a/train_network.py +++ b/train_network.py @@ -7,6 +7,7 @@ import random import time import json from multiprocessing import Value +from typing import Any, List import toml from tqdm import tqdm @@ -18,7 +19,7 @@ init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import deepspeed_utils, model_util +from library import deepspeed_utils, model_util, strategy_base, strategy_sd import library.train_util as train_util from library.train_util import DreamBoothDataset @@ -101,19 +102,31 @@ class NetworkTrainer: text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet - def load_tokenizer(self, args): - tokenizer = train_util.load_tokenizer(args) - return tokenizer + def get_tokenize_strategy(self, args): + return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) - def is_text_encoder_outputs_cached(self, args): - return False + def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]: + return [tokenize_strategy.tokenizer] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + True, args.cache_latents_to_disk, args.vae_batch_size, False + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sd.SdTextEncodingStrategy(args.clip_skip) + + def get_text_encoder_outputs_caching_strategy(self, args): + return None + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + return text_encoders def is_train_text_encoder(self, args): - return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args) + return not args.network_train_unet_only - def cache_text_encoder_outputs_if_needed( - self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype - ): + def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype): for t_enc in text_encoders: t_enc.to(accelerator.device, dtype=weight_dtype) @@ -123,7 +136,7 @@ class NetworkTrainer: return encoder_hidden_states def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - noise_pred = unet(noisy_latents, timesteps, text_conds).sample + noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred def all_reduce_network(self, accelerator, network): @@ -131,8 +144,8 @@ class NetworkTrainer: if param.grad is not None: param.grad = accelerator.reduce(param.grad, reduction="mean") - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): - train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet): + train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet) def train(self, args): session_id = random.randint(0, 2**32) @@ -150,9 +163,13 @@ class NetworkTrainer: args.seed = random.randint(0, 2**32) set_seed(args.seed) - # 
tokenizer is a single tokenizer or a list; tokenizers is always a list: kept for compatibility with existing code
- tokenizer = self.load_tokenizer(args)
- tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
+ tokenize_strategy = self.get_tokenize_strategy(args)
+ strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
+ tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored
+
+ # prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy during initialization
+ latents_caching_strategy = self.get_latents_caching_strategy(args)
+ strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# prepare the dataset
if args.dataset_class is None:
@@ -194,11 +211,11 @@ class NetworkTrainer:
]
}
- blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+ blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
# use arbitrary dataset class
- train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
+ train_dataset_group = train_util.load_arbitrary_dataset(args)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -268,8 +285,9 @@ class NetworkTrainer:
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
- with torch.no_grad():
- train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
+
+ train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
+
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -277,9 +295,13 @@ class NetworkTrainer:
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
- self.cache_text_encoder_outputs_if_needed(
- args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
- )
+ text_encoding_strategy = self.get_text_encoding_strategy(args)
+ strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
+
+ text_encoder_outputs_caching_strategy = self.get_text_encoder_outputs_caching_strategy(args)
+ if text_encoder_outputs_caching_strategy is not None:
+ strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy)
+ self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype)
# prepare network
net_kwargs = {}
@@ -366,7 +388,11 @@ class NetworkTrainer:
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
- # prepare the dataloader
+ # prepare dataloader
+ # strategies are set here because they cannot be referenced in another process.
They are copied with the dataset instead.
+ # some strategies can be None
+ train_dataset_group.set_current_strategies()
+
# number of DataLoader processes: note that 0 cannot be used with persistent_workers
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
@@ -878,7 +904,7 @@ class NetworkTrainer:
os.remove(old_ckpt_file)
# For --sample_at_first
- self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+ self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
# training loop
if initial_step > 0: # only if skip_until_initial_step is specified
@@ -933,21 +959,31 @@ class NetworkTrainer:
# print(f"set multiplier: {multipliers}")
accelerator.unwrap_model(network).set_multiplier(multipliers)
- with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
- # Get the text embedding for conditioning
- if args.weighted_captions:
- text_encoder_conds = get_weighted_text_embeddings(
- tokenizer,
- text_encoder,
- batch["captions"],
- accelerator.device,
- args.max_token_length // 75 if args.max_token_length else 1,
- clip_skip=args.clip_skip,
- )
- else:
- text_encoder_conds = self.get_text_cond(
- args, accelerator, batch, tokenizers, text_encoders, weight_dtype
- )
+ text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
+ if text_encoder_outputs_list is not None:
+ text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
+ else:
+ with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
+ # Get the text embedding for conditioning
+ if args.weighted_captions:
+ # SD only
+ text_encoder_conds = get_weighted_text_embeddings(
+ tokenizers[0],
+ text_encoder,
+ batch["captions"],
+ accelerator.device,
+ args.max_token_length // 75 if args.max_token_length else 1,
+ clip_skip=args.clip_skip,
+ )
+ else:
+ input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
+ text_encoder_conds = text_encoding_strategy.encode_tokens(
+ tokenize_strategy,
+ self.get_models_for_text_encoding(args, accelerator, text_encoders),
+ input_ids,
+ )
+ if args.full_fp16:
+ text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
@@ -1026,7 +1062,9 @@ class NetworkTrainer:
progress_bar.update(1)
global_step += 1
- self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+ self.sample_images(
+ accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
+ )
# save the model at the specified step interval
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -1082,7 +1120,7 @@ class NetworkTrainer:
if args.save_state:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
- self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+ self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
# end of epoch
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index ade077c3..9044f50d 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -2,6 +2,7 @@ import argparse
import math
import os
from multiprocessing import Value
+from typing import Any, List
import toml
from tqdm import tqdm
@@ -15,7 +16,7 @@
init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer
-from library import deepspeed_utils, model_util
+from library import deepspeed_utils, model_util, strategy_base, strategy_sd
import library.train_util as train_util
import library.huggingface_util as huggingface_util
@@ -103,28 +104,38 @@ class TextualInversionTrainer:
def load_target_model(self, args, weight_dtype, accelerator):
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
- return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
+ return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet
- def load_tokenizer(self, args):
- tokenizer = train_util.load_tokenizer(args)
- return tokenizer
+ def get_tokenize_strategy(self, args):
+ return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
+
+ def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]:
+ return [tokenize_strategy.tokenizer]
+
+ def get_latents_caching_strategy(self, args):
+ latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
+ True, args.cache_latents_to_disk, args.vae_batch_size, False
+ )
+ return latents_caching_strategy
def assert_token_string(self, token_string, tokenizers: CLIPTokenizer):
pass
- def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
- with torch.enable_grad():
- input_ids = batch["input_ids"].to(accelerator.device)
- encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None)
- return encoder_hidden_states
+ def get_text_encoding_strategy(self, args):
+ return strategy_sd.SdTextEncodingStrategy(args.clip_skip)
+
+ def get_models_for_text_encoding(self, args, accelerator, text_encoders) -> List[Any]:
+ return text_encoders
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
- noise_pred = unet(noisy_latents, timesteps, text_conds).sample
+ noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
return noise_pred
- def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
+ def sample_images(
+ self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
+ ):
train_util.sample_images(
- accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
+ accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoders[0], unet, prompt_replacement
)
def save_weights(self, file, updated_embs, save_dtype, metadata):
@@ -182,8 +193,13 @@ class TextualInversionTrainer:
if args.seed is not None:
set_seed(args.seed)
- tokenizer_or_list = self.load_tokenizer(args) # a single tokenizer or a list of tokenizers
- tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
+ tokenize_strategy = self.get_tokenize_strategy(args)
+ strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
+ tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored
+
+ # prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy during initialization
+ latents_caching_strategy = self.get_latents_caching_strategy(args)
+ strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# prepare the accelerator
logger.info("prepare accelerator")
@@ -194,14 +210,7 @@ class TextualInversionTrainer:
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# load the model
- model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
- text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list
-
- if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1:
- accelerator.print(
- "accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / "
- + "accelerateでは複数のモデル(テキストエンコーダー)のgradient_accumulation_stepsはサポートされていないようです"
- )
+ model_version, text_encoders, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
# Convert the init_word to token_id
init_token_ids_list = []
@@ -310,10 +319,10 @@ class TextualInversionTrainer:
]
}
- blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list)
+ blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
- train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list)
+ train_dataset_group = train_util.load_arbitrary_dataset(args)
self.assert_extra_args(args, train_dataset_group)
@@ -368,11 +377,10 @@ class TextualInversionTrainer:
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
- with torch.no_grad():
- train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
- vae.to("cpu")
- clean_memory_on_device(accelerator.device)
+ train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
+
+ clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
if args.gradient_checkpointing:
@@ -387,7 +395,11 @@ class TextualInversionTrainer:
trainable_params += text_encoder.get_input_embeddings().parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
- # prepare the dataloader
+ # prepare dataloader
+ # strategies are set here because they cannot be referenced in another process.
They are copied with the dataset instead.
+ # some strategies can be None
+ train_dataset_group.set_current_strategies()
+
# number of DataLoader processes: note that 0 cannot be used with persistent_workers
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
@@ -415,20 +427,8 @@ class TextualInversionTrainer:
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# accelerator apparently handles this for us
- if len(text_encoders) == 1:
- text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
- )
-
- elif len(text_encoders) == 2:
- text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
- )
-
- text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
-
- else:
- raise NotImplementedError()
+ optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
+ text_encoders = [accelerator.prepare(text_encoder) for text_encoder in text_encoders]
index_no_updates_list = []
orig_embeds_params_list = []
@@ -456,6 +456,9 @@ class TextualInversionTrainer:
else:
unet.eval()
+ text_encoding_strategy = self.get_text_encoding_strategy(args)
+ strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
+
if not cache_latents: # if latents are not cached, the VAE is used, so prepare it
vae.requires_grad_(False)
vae.eval()
@@ -510,7 +513,9 @@ class TextualInversionTrainer:
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
- "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
+ "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name,
+ config=train_util.get_sanitized_config_or_none(args),
+ init_kwargs=init_kwargs,
)
# function for saving/removing
@@ -540,8 +545,8 @@ class TextualInversionTrainer:
global_step,
accelerator.device,
vae,
- tokenizer_or_list,
- text_encoder_or_list,
+ tokenizers,
+ text_encoders,
unet,
prompt_replacement,
)
@@ -568,7 +573,12 @@ class TextualInversionTrainer:
latents = latents * self.vae_scale_factor
# Get the text embedding for conditioning
- text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype)
+ input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
+ text_encoder_conds = text_encoding_strategy.encode_tokens(
+ tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids
+ )
+ if args.full_fp16:
+ text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
@@ -588,7 +598,9 @@ class TextualInversionTrainer:
else:
target = noise
- loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+ loss = train_util.conditional_loss(
+ noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+ )
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
@@ -639,8 +651,8 @@ class TextualInversionTrainer:
global_step, accelerator.device, vae, - tokenizer_or_list, - text_encoder_or_list, + tokenizers, + text_encoders, unet, prompt_replacement, ) @@ -722,8 +734,8 @@ class TextualInversionTrainer: global_step, accelerator.device, vae, - tokenizer_or_list, - text_encoder_or_list, + tokenizers, + text_encoders, unet, prompt_replacement, )
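---

The `strategy_*` modules in this diff share one pattern: each strategy type holds a single process-wide instance that the trainer registers once at start-up (`strategy_base.TokenizeStrategy.set_strategy(...)`), and `set_current_strategies()` later copies the registered instances onto the dataset so they travel with it into DataLoader worker processes. Below is a minimal sketch of that pattern under simplified assumptions; the class bodies are illustrative stand-ins, not the library code:

```python
from typing import List, Optional

from torch.utils.data import DataLoader, Dataset


class TokenizeStrategy:
    _strategy: Optional["TokenizeStrategy"] = None  # one active strategy per process

    @classmethod
    def set_strategy(cls, strategy: "TokenizeStrategy") -> None:
        # set once at trainer start-up, before the dataset is built
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["TokenizeStrategy"]:
        return cls._strategy

    def tokenize(self, text: str) -> List[str]:
        return text.split()  # stand-in for real CLIP tokenization


class CaptionDataset(Dataset):
    """Hypothetical dataset mirroring how train_util copies strategies."""

    def __init__(self, captions: List[str]):
        self.captions = captions
        self.tokenize_strategy: Optional[TokenizeStrategy] = None

    def set_current_strategies(self) -> None:
        # copy the process-wide singleton onto the instance: instance attributes
        # are pickled with the dataset and so reach DataLoader worker processes,
        # where the class-level singleton may never have been set
        self.tokenize_strategy = TokenizeStrategy.get_strategy()

    def __len__(self) -> int:
        return len(self.captions)

    def __getitem__(self, idx: int) -> List[str]:
        return self.tokenize_strategy.tokenize(self.captions[idx])


if __name__ == "__main__":
    TokenizeStrategy.set_strategy(TokenizeStrategy())  # architecture chosen here
    dataset = CaptionDataset(["a photo of a cat", "an oil painting of a dog"])
    dataset.set_current_strategies()  # must run before the DataLoader is created
    for batch in DataLoader(dataset, batch_size=1, collate_fn=lambda b: b):
        print(batch)
```

Registering before the dataset is built matters because, as the comments in the diff note, the dataset may consult the caching strategy during its own initialization.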
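With `load_tokenizer`, `get_text_cond`, and `is_text_encoder_outputs_cached` removed, a trainer now describes its architecture entirely through small factory hooks, which is what makes future architectures easier to add. A hedged sketch of that hook surface; the base class mirrors `train_network.NetworkTrainer` from the diff, while the `NewArch*` classes are hypothetical placeholders:

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class NewArchTokenizeStrategy:  # hypothetical placeholder
    max_token_length: Optional[int]


@dataclass
class NewArchLatentsCachingStrategy:  # hypothetical placeholder
    cache_to_disk: bool


class NewArchTextEncodingStrategy:  # hypothetical placeholder
    pass


class NetworkTrainerHooks:
    """Subset of the hook surface train_network.NetworkTrainer now exposes."""

    def get_tokenize_strategy(self, args) -> Any:
        raise NotImplementedError

    def get_latents_caching_strategy(self, args) -> Any:
        raise NotImplementedError

    def get_text_encoding_strategy(self, args) -> Any:
        raise NotImplementedError

    def get_text_encoder_outputs_caching_strategy(self, args) -> Optional[Any]:
        return None  # None means: encode on the fly instead of caching

    def get_models_for_text_encoding(self, args, accelerator, text_encoders) -> List[Any]:
        return text_encoders  # the SDXL trainer appends an unwrapped encoder here


class NewArchNetworkTrainer(NetworkTrainerHooks):
    """A future architecture would plug in by overriding only the hooks."""

    def get_tokenize_strategy(self, args):
        return NewArchTokenizeStrategy(args.max_token_length)

    def get_latents_caching_strategy(self, args):
        return NewArchLatentsCachingStrategy(args.cache_latents_to_disk)

    def get_text_encoding_strategy(self, args):
        return NewArchTextEncodingStrategy()
```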
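Every training loop in the diff now follows the same per-step shape for conditioning: look for cached text encoder outputs in the batch first, otherwise move each tensor in `input_ids_list` to the device and call `encode_tokens` with a list of encoders, always getting a list of conditioning tensors back. A condensed sketch of that control flow; the function name is an assumption, the gradient and device handling varies slightly between the trainers:

```python
import torch


def get_text_encoder_conds(args, accelerator, batch, tokenize_strategy,
                           text_encoding_strategy, text_encoders,
                           weight_dtype, train_text_encoder=False):
    """Condensed version of the per-step conditioning logic in the trainers."""
    # cached path: the dataset already stored one output tensor per encoder
    conds = batch.get("text_encoder_outputs_list", None)
    if conds is not None:
        return [t.to(accelerator.device, dtype=weight_dtype) for t in conds]

    # fresh path: one input_ids tensor per tokenizer, encoded by the strategy
    input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
    with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
        conds = text_encoding_strategy.encode_tokens(
            tokenize_strategy, text_encoders, input_ids
        )
    if args.full_fp16:
        # full fp16 training expects the conditioning in the weight dtype
        conds = [c.to(weight_dtype) for c in conds]
    return conds
```

Because the return value is always a list, architecture-neutral call sites like `unet(noisy_latents, timesteps, text_conds[0])` can index into it without branching on the model family.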
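The SDXL text encoder caching path follows a fixed order: move both encoders to the GPU in the working dtype, cache under autocast, synchronize all processes, then park the encoders back on the CPU in fp32. A sketch of that sequence, assuming the call signatures shown in the diff; the function name is hypothetical, and the reason for appending the unwrapped second encoder is an assumption, not stated in the diff:

```python
import torch


def cache_sdxl_text_encoder_outputs(args, accelerator, dataset, text_encoders, weight_dtype):
    # both encoders must be on the device in the working dtype before caching
    text_encoders[0].to(accelerator.device, dtype=weight_dtype)
    text_encoders[1].to(accelerator.device, dtype=weight_dtype)
    with accelerator.autocast():
        # the diff appends the unwrapped second encoder as an extra model,
        # presumably so attributes of the raw module stay reachable behind
        # any accelerate/DDP wrapper (an assumption on our part)
        dataset.new_cache_text_encoder_outputs(
            text_encoders + [accelerator.unwrap_model(text_encoders[-1])],
            accelerator.is_main_process,
        )
    accelerator.wait_for_everyone()  # all ranks must finish before moving models

    # park the encoders on the CPU in fp32: as the diff's comment notes,
    # the Text Encoder doesn't work with fp16 on CPU
    text_encoders[0].to("cpu", dtype=torch.float32)
    text_encoders[1].to("cpu", dtype=torch.float32)
```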