Refactor caching mechanism for latents and text encoder outputs, etc.

This commit is contained in:
Kohya S
2024-07-27 13:50:05 +09:00
parent 082f13658b
commit 41dee60383
21 changed files with 1786 additions and 733 deletions


@@ -4,9 +4,16 @@ This repository contains training, generation and utility scripts for Stable Dif
SD3 training is done with `sd3_train.py`.
-__Jul 11, 2024__: Fixed to work t5xxl with `fp16`. If you change the dtype to `fp16` for t5xxl, please remove existing latents cache files (`*_sd3.npz`). The shift in `sd3_minimum_inference.py` is fixed to 3.0. Thanks to araleza!
-Jun 29, 2024: Fixed mixed precision training with fp16 is not working. Fixed the model is in bf16 dtype even without `--full_bf16` option (this could worsen the training result).
+__Jul 27, 2024__:
+- The caching mechanism for latents and text encoder outputs has been refactored significantly (see the sketch after this file's diff).
+- Existing cache files for SD3 need to be recreated. Please delete the previous cache files.
+- With this change, dataset initialization is significantly faster, especially for large datasets.
+- Architecture-dependent parts are extracted from the dataset (`train_util.py`). This is expected to make it easier to add future architectures.
+- Architecture-dependent parts, including the cache mechanism for SD1/2/SDXL, are also extracted. The basic operation of SD1/2/SDXL training on the sd3 branch has been confirmed, but there may be bugs. Please use the main or dev branch for SD1/2/SDXL training.
---
`fp16` and `bf16` are available for mixed precision training. We are not sure which is better.
@@ -14,7 +21,7 @@ Jun 29, 2024: Fixed mixed precision training with fp16 is not working. Fixed the
`clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them.
-~~t5xxl doesn't seem to work with `fp16`, so 1) use `bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`.~~ t5xxl works with `fp16` now.
+t5xxl works with `fp16` now.
There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype.
@@ -32,6 +39,14 @@ cache_latents = true
cache_latents_to_disk = true
```
__2024/7/27:__
The caching mechanism for latents and text encoder outputs has been significantly refactored. Existing cache files for SD3 will need to be recreated; please delete the previous cache files. With this change, dataset initialization is much faster, especially for large datasets.
Architecture-dependent parts have been extracted from the dataset (`train_util.py`). This is expected to make it easier to add future architectures.
Architecture-dependent parts, including the cache mechanism for SD1/2/SDXL, have also been extracted. The basic operation of SD1/2/SDXL training on the sd3 branch has been confirmed, but there may be bugs. Please use the main or dev branch for SD1/2/SDXL training.
---
[__Change History__](#change-history) is moved to the bottom of the page.
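The refactoring described above centers on a small strategy-registry pattern; here is a minimal sketch of how the pieces fit together, based on the new `library/strategy_base.py` and `library/strategy_sd.py` later in this diff (the argument values are illustrative, not prescribed):

```python
# Each trainer registers its concrete strategies once, up front; architecture-
# agnostic code such as the dataset then fetches them from the class-level
# singleton instead of receiving tokenizers and encoders as arguments.
from library import strategy_base, strategy_sd

tokenize_strategy = strategy_sd.SdTokenizeStrategy(v2=False, max_length=None, tokenizer_cache_dir=None)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)

# ... later, inside code that knows nothing about SD1/2/SDXL/SD3:
current = strategy_base.TokenizeStrategy.get_strategy()
input_ids_list = current.tokenize("a photo of a cat")  # list with one (1, 1, 77) tensor
```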


@@ -10,7 +10,7 @@ import toml
from tqdm import tqdm
import torch
-from library import deepspeed_utils
+from library import deepspeed_utils, strategy_base
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
@@ -39,6 +39,7 @@ from library.custom_train_functions import (
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
import library.strategy_sd as strategy_sd
def train(args):
@@ -52,7 +53,15 @@ def train(args):
if args.seed is not None:
set_seed(args.seed)  # initialize random seed
-tokenizer = train_util.load_tokenizer(args)
+tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
# prepare caching strategy: this must be set before preparing the dataset, because the dataset may use the strategy during initialization
if cache_latents:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# prepare the dataset
if args.dataset_class is None:
@@ -81,10 +90,10 @@ def train(args):
]
}
-blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
-train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
+train_dataset_group = train_util.load_arbitrary_dataset(args)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -165,8 +174,9 @@ def train(args):
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
-with torch.no_grad():
-    train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
+train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -192,6 +202,9 @@ def train(args):
else:
text_encoder.eval()
text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
@@ -214,7 +227,11 @@ def train(args):
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
-# dataloaderを準備する
+# prepare dataloader
# strategies are set here because they cannot be referenced from another process; copying them into the dataset lets them travel with it (see the sketch at the end of this file's diff)
# some strategies can be None
train_dataset_group.set_current_strategies()
# note: DataLoader with num_workers=0 cannot use persistent_workers
n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
@@ -317,7 +334,9 @@ def train(args):
)
# For --sample_at_first
-train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+train_util.sample_images(
+    accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
+)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -342,8 +361,9 @@ def train(args):
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
if args.weighted_captions:
# TODO move to strategy_sd.py
encoder_hidden_states = get_weighted_text_embeddings(
-tokenizer,
+tokenize_strategy.tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
@@ -351,10 +371,12 @@ def train(args):
clip_skip=args.clip_skip,
)
else:
-input_ids = batch["input_ids"].to(accelerator.device)
-encoder_hidden_states = train_util.get_hidden_states(
-    args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
-)
+input_ids = batch["input_ids_list"][0].to(accelerator.device)
+encoder_hidden_states = text_encoding_strategy.encode_tokens(
+    tokenize_strategy, [text_encoder], [input_ids]
+)[0]
+if args.full_fp16:
+    encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
@@ -409,7 +431,7 @@ def train(args):
global_step += 1
train_util.sample_images(
-accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
+accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
# save the model at each specified step interval
@@ -472,7 +494,9 @@ def train(args):
vae,
)
-train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+train_util.sample_images(
+    accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
+)
is_main_process = accelerator.is_main_process
if is_main_process:
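Stepping back from the hunks above, a sketch of why `set_current_strategies()` has to run before the DataLoader is built (the worker behavior is standard PyTorch; the keyword arguments shown are illustrative):

```python
# Class-level singletons set in the main process are not visible inside
# DataLoader worker processes: each worker receives a pickled copy of the
# dataset, so the dataset must hold its own strategy references beforehand.
train_dataset_group.set_current_strategies()  # copy the current singletons onto the dataset

train_dataloader = torch.utils.data.DataLoader(
    train_dataset_group,
    batch_size=1,
    num_workers=n_workers,
    persistent_workers=n_workers > 0,  # num_workers=0 cannot use persistent_workers
)
```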


@@ -104,8 +104,6 @@ class ControlNetSubsetParams(BaseSubsetParams):
@dataclass
class BaseDatasetParams:
-tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
-max_token_length: int = None
resolution: Optional[Tuple[int, int]] = None
network_multiplier: float = 1.0
debug_dataset: bool = False


@@ -38,7 +38,7 @@ class SDTokenizer:
Some settings are done in subclasses. This class seems to perform tokenization with weights based on those settings.
"""
-self.tokenizer = tokenizer
+self.tokenizer: CLIPTokenizer = tokenizer
self.max_length = max_length
self.min_length = min_length
empty = self.tokenizer("")["input_ids"]
@@ -56,6 +56,19 @@ class SDTokenizer:
self.inv_vocab = {v: k for k, v in vocab.items()}
self.max_word_length = 8
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
"""
Tokenize the text without weights.
"""
if type(text) == str:
text = [text]
batch_tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt")
pad_token = self.end_token if self.pad_with_end else 0  # currently unused; kept for parity with tokenize_with_weights
for tokens in batch_tokens["input_ids"]:
assert tokens[0] == self.start_token, f"tokens[0]: {tokens[0]}, start_token: {self.start_token}"
return batch_tokens["input_ids"]
def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None):
"""Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.
The details aren't relevant for a reference impl, and the weights themselves have a weak effect on SD3."""
@@ -75,13 +88,14 @@ class SDTokenizer:
for word in to_tokenize:
batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]])
batch.append((self.end_token, 1.0))
print(len(batch), self.max_length, self.min_length)
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
if self.min_length is not None and len(batch) < self.min_length:
batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
# truncate to max_length
-# print(f"batch: {batch}, truncate: {truncate}, len(batch): {len(batch)}, max_length: {self.max_length}")
+print(f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}")
if truncate_to_max_length and len(batch) > self.max_length:
batch = batch[: self.max_length]
if truncate_length is not None and len(batch) > truncate_length:
@@ -110,27 +124,38 @@ class SDXLClipGTokenizer(SDTokenizer):
class SD3Tokenizer:
-def __init__(self, t5xxl=True):
+def __init__(self, t5xxl=True, t5xxl_max_length: Optional[int] = 256):
+if t5xxl_max_length is None:
+t5xxl_max_length = 256
# TODO cache tokenizer settings locally or hold them in the repo like ComfyUI
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
+# self.clip_l = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+# self.clip_g = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
self.t5xxl = T5XXLTokenizer() if t5xxl else None
# t5xxl has 99999999 max length, clip has 77
-self.model_max_length = self.clip_l.max_length  # 77
+self.t5xxl_max_length = t5xxl_max_length
def tokenize_with_weights(self, text: str):
# temporarily truncate to max_length even for t5xxl
return (
self.clip_l.tokenize_with_weights(text),
self.clip_g.tokenize_with_weights(text),
(
-self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.model_max_length)
+self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.t5xxl_max_length)
if self.t5xxl is not None
else None
),
)
def tokenize(self, text: str):
return (
self.clip_l.tokenize(text),
self.clip_g.tokenize(text),
(self.t5xxl.tokenize(text) if self.t5xxl is not None else None),
)
# endregion
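A hedged usage sketch for the new batch `tokenize()`: it assumes `tokenize` returns the padded `input_ids` as reconstructed above, and passes `t5xxl=False` because the `T5XXLTokenizer` wrapper is commented out later in this file:

```python
# Both CLIP streams are padded to 77 ids; the T5 slot is None without t5xxl.
tokenizer = SD3Tokenizer(t5xxl=False, t5xxl_max_length=None)  # None falls back to 256
l_tokens, g_tokens, t5_tokens = tokenizer.tokenize("a photo of a cat")
print(l_tokens.shape, g_tokens.shape, t5_tokens)
# torch.Size([1, 77]) torch.Size([1, 77]) None
```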
@@ -1474,7 +1499,10 @@ class ClipTokenWeightEncoder:
tokens = [a[0] for a in pairs[0]]  # I'm not sure why this is [0]
list_of_tokens.append(tokens)
else:
-list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]]
+if isinstance(list_of_token_weight_pairs[0], torch.Tensor):
+    list_of_tokens = [list(list_of_token_weight_pairs[0])]
+else:
+    list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]]
out, pooled = self(list_of_tokens)
if has_batch:
@@ -1614,9 +1642,9 @@ class T5XXLModel(SDClipModel):
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
#################################################################################################
+"""
class T5XXLTokenizer(SDTokenizer):
-"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
+""Wraps the T5 Tokenizer from HF into the SDTokenizer interface""
def __init__(self):
super().__init__(
@@ -1627,6 +1655,7 @@ class T5XXLTokenizer(SDTokenizer):
max_length=99999999,
min_length=77,
)
+"""
class T5LayerNorm(torch.nn.Module):


@@ -280,111 +280,6 @@ def sample_images(*args, **kwargs):
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy):
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
self.vae = None
def set_vae(self, vae: sd3_models.SDVAE):
self.vae = vae
def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX)
if len(npz_file) == 0:
return None, None
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
return int(w), int(h)
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
expected_latents_size = (bucket_reso[1] // 8, bucket_reso[0] // 8) # bucket_reso is (W, H)
try:
npz = np.load(npz_path)
if npz["latents"].shape[1:3] != expected_latents_size:
return False
if flip_aug:
if "latents_flipped" not in npz:
return False
if npz["latents_flipped"].shape[1:3] != expected_latents_size:
return False
if alpha_mask:
if "alpha_mask" not in npz:
return False
if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
return False
else:
if "alpha_mask" in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
image_infos, alpha_mask, random_crop
)
img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype)
with torch.no_grad():
latents_tensors = self.vae.encode(img_tensor).to("cpu")
if flip_aug:
img_tensor = torch.flip(img_tensor, dims=[3])
with torch.no_grad():
flipped_latents = self.vae.encode(img_tensor).to("cpu")
else:
flipped_latents = [None] * len(latents_tensors)
# for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks):
for i in range(len(image_infos)):
info = image_infos[i]
latents = latents_tensors[i]
flipped_latent = flipped_latents[i]
alpha_mask = alpha_masks[i]
original_size = original_sizes[i]
crop_ltrb = crop_ltrbs[i]
if self.cache_to_disk:
kwargs = {}
if flipped_latent is not None:
kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy()
if alpha_mask is not None:
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
np.savez(
info.latents_npz,
latents=latents.float().cpu().numpy(),
original_size=np.array(original_size),
crop_ltrb=np.array(crop_ltrb),
**kwargs,
)
else:
info.latents = latents
if flip_aug:
info.latents_flipped = flipped_latent
info.alpha_mask = alpha_mask
if not train_util.HIGH_VRAM:
clean_memory_on_device(self.vae.device)
# region Diffusers


@@ -384,6 +384,7 @@ def get_cond(
dtype: Optional[torch.dtype] = None,
):
l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt)
print(t5_tokens)
return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype)


@@ -327,7 +327,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
)
-def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
+def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True):
parser.add_argument(
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
)

library/strategy_base.py (new file, 328 lines)

@@ -0,0 +1,328 @@
# base class for platform strategies. this file defines the interface for strategies
import os
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
# TODO remove circular import by moving ImageInfo to a separate file
# from library.train_util import ImageInfo
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class TokenizeStrategy:
_strategy = None # strategy instance: actual strategy class
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["TokenizeStrategy"]:
return cls._strategy
def _load_tokenizer(
self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None
) -> Any:
tokenizer = None
if tokenizer_cache_dir:
local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_"))
if os.path.exists(local_tokenizer_path):
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
tokenizer = model_class.from_pretrained(local_tokenizer_path) # same for v1 and v2
if tokenizer is None:
tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder)
if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
tokenizer.save_pretrained(local_tokenizer_path)
return tokenizer
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
raise NotImplementedError
def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None) -> torch.Tensor:
"""
for SD1.5/2.0/SDXL
TODO support batch input
"""
if max_length is None:
max_length = tokenizer.model_max_length - 2
input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids
if max_length > tokenizer.model_max_length:
input_ids = input_ids.squeeze(0)
iids_list = []
if tokenizer.pad_token_id == tokenizer.eos_token_id:
# v1
# for prompts longer than 75 tokens the ids are "<BOS> .... <EOS> <EOS> <EOS>" (e.g. 227 in total), so convert them into triplets of "<BOS>...<EOS>"
# A1111's implementation seems to split on commas; keep it simple here for now
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):  # (1, 152, 75)
ids_chunk = (
input_ids[0].unsqueeze(0),
input_ids[i : i + tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0),
)
ids_chunk = torch.cat(ids_chunk)
iids_list.append(ids_chunk)
else:
# v2 or SDXL
# for prompts longer than 75 tokens the ids are "<BOS> .... <EOS> <PAD> <PAD>..." (e.g. 227 in total), so convert them into triplets of "<BOS>...<EOS> <PAD> <PAD> ..."
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
ids_chunk = (
input_ids[0].unsqueeze(0), # BOS
input_ids[i : i + tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0),
) # PAD or EOS
ids_chunk = torch.cat(ids_chunk)
# if the chunk ends with <EOS> <PAD> or <PAD> <PAD>, nothing needs to be done
# if it ends with x <PAD/EOS>, change the last token to <EOS> (if it was x <EOS>, the result is unchanged)
if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id:
ids_chunk[-1] = tokenizer.eos_token_id
# if the chunk starts with <BOS> <PAD> ..., change it to <BOS> <EOS> <PAD> ...
if ids_chunk[1] == tokenizer.pad_token_id:
ids_chunk[1] = tokenizer.eos_token_id
iids_list.append(ids_chunk)
input_ids = torch.stack(iids_list) # 3,77
return input_ids
class TextEncodingStrategy:
_strategy = None # strategy instance: actual strategy class
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["TextEncodingStrategy"]:
return cls._strategy
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
"""
Encode tokens into embeddings and outputs.
:param tokens: list of token tensors for each TextModel
:return: list of output embeddings for each architecture
"""
raise NotImplementedError
class TextEncoderOutputsCachingStrategy:
_strategy = None # strategy instance: actual strategy class
def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
) -> None:
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
self._is_partial = is_partial
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
return cls._strategy
@property
def cache_to_disk(self):
return self._cache_to_disk
@property
def batch_size(self):
return self._batch_size
@property
def is_partial(self):
return self._is_partial
def get_outputs_npz_path(self, image_abs_path: str) -> str:
raise NotImplementedError
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
raise NotImplementedError
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
raise NotImplementedError
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List
):
raise NotImplementedError
class LatentsCachingStrategy:
# TODO commonize utility functions into this class, such as npz handling etc.
_strategy = None # strategy instance: actual strategy class
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
return cls._strategy
@property
def cache_to_disk(self):
return self._cache_to_disk
@property
def batch_size(self):
return self._batch_size
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
raise NotImplementedError
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
raise NotImplementedError
def is_disk_cached_latents_expected(
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
) -> bool:
raise NotImplementedError
def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
raise NotImplementedError
def _default_is_disk_cached_latents_expected(
self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
try:
npz = np.load(npz_path)
if npz["latents"].shape[1:3] != expected_latents_size:
return False
if flip_aug:
if "latents_flipped" not in npz:
return False
if npz["latents_flipped"].shape[1:3] != expected_latents_size:
return False
if alpha_mask:
if "alpha_mask" not in npz:
return False
if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
return False
else:
if "alpha_mask" in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
# TODO remove circular dependency for ImageInfo
def _default_cache_batch_latents(
self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool
):
"""
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
"""
from library import train_util # import here to avoid circular import
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
image_infos, alpha_mask, random_crop
)
img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)
with torch.no_grad():
latents_tensors = encode_by_vae(img_tensor).to("cpu")
if flip_aug:
img_tensor = torch.flip(img_tensor, dims=[3])
with torch.no_grad():
flipped_latents = encode_by_vae(img_tensor).to("cpu")
else:
flipped_latents = [None] * len(latents_tensors)
# for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks):
for i in range(len(image_infos)):
info = image_infos[i]
latents = latents_tensors[i]
flipped_latent = flipped_latents[i]
alpha_mask = alpha_masks[i]
original_size = original_sizes[i]
crop_ltrb = crop_ltrbs[i]
if self.cache_to_disk:
self.save_latents_to_disk(info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask)
else:
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
info.latents = latents
if flip_aug:
info.latents_flipped = flipped_latent
info.alpha_mask = alpha_mask
def load_latents_from_disk(
self, npz_path: str
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
npz = np.load(npz_path)
if "latents" not in npz:
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
latents = npz["latents"]
original_size = npz["original_size"].tolist()
crop_ltrb = npz["crop_ltrb"].tolist()
flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk(
self, npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None
):
kwargs = {}
if flipped_latents_tensor is not None:
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
if alpha_mask is not None:
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
np.savez(
npz_path,
latents=latents_tensor.float().cpu().numpy(),
original_size=np.array(original_size),
crop_ltrb=np.array(crop_ltrb),
**kwargs,
)
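A worked example of the long-prompt chunking performed by `_get_input_ids` above (a sketch; it assumes the SD1.x CLIP tokenizer, whose `model_max_length` is 77 and whose pad token equals EOS):

```python
import torch
from transformers import CLIPTokenizer
from library.strategy_base import TokenizeStrategy

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
strategy = TokenizeStrategy()  # the base class is instantiable; only tokenize() raises

# max_length=227 corresponds to --max_token_length=225 plus <BOS> and <EOS>
ids = strategy._get_input_ids(tokenizer, "a long caption, " * 60, max_length=227)
print(ids.shape)
# torch.Size([3, 77]): three <BOS>...<EOS> windows of 75 content tokens each,
# later consumed by the text encoder as a (batch*3, 77) tensor and re-joined
```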

library/strategy_sd.py (new file, 139 lines)

@@ -0,0 +1,139 @@
import glob
import os
from typing import Any, List, Optional, Tuple, Union
import torch
from transformers import CLIPTokenizer
from library import train_util
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
TOKENIZER_ID = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2"  # only the tokenizer is used from here; v2 and v2.1 share the same tokenizer spec
class SdTokenizeStrategy(TokenizeStrategy):
def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
"""
max_length does not include <BOS> and <EOS> (None, 75, 150, 225)
"""
logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer")
if v2:
self.tokenizer = self._load_tokenizer(
CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir
)
else:
self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
if max_length is None:
self.max_length = self.tokenizer.model_max_length
else:
self.max_length = max_length + 2
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
class SdTextEncodingStrategy(TextEncodingStrategy):
def __init__(self, clip_skip: Optional[int] = None) -> None:
self.clip_skip = clip_skip
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
text_encoder = models[0]
tokens = tokens[0]
sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy
# tokens: b,n,77
b_size = tokens.size()[0]
max_token_length = tokens.size()[1] * tokens.size()[2]
model_max_length = sd_tokenize_strategy.tokenizer.model_max_length
tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77
if self.clip_skip is None:
encoder_hidden_states = text_encoder(tokens)[0]
else:
enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip]
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
# bs*3, 77, 768 or 1024
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
if max_token_length != model_max_length:
v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id
if not v1:
# v2: rejoin triplets of <BOS>...<EOS> <PAD> ... back into <BOS>...<EOS> <PAD> ...; honestly not sure this implementation is right
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, model_max_length):
chunk = encoder_hidden_states[:, i : i + model_max_length - 2]  # from after <BOS> up to before the last token
if i > 0:
for j in range(len(chunk)):
if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token_id:  # compare token ids, not the token string
# empty prompt, i.e. the <BOS> <EOS> <PAD> ... pattern
chunk[j, 0] = chunk[j, 1]  # copy the value of the following <PAD>
states_list.append(chunk)  # from after <BOS> to before <EOS>
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))  # either <EOS> or <PAD>
encoder_hidden_states = torch.cat(states_list, dim=1)
else:
# v1: rejoin triplets of <BOS>...<EOS> back into <BOS>...<EOS>
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]  # <BOS>
for i in range(1, max_token_length, model_max_length):
states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2])  # from after <BOS> to before <EOS>
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))  # <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)
return [encoder_hidden_states]
class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
# and we keep the old npz for the backward compatibility.
SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"
def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
self.sd = sd
self.suffix = (
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
)
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
# does not include old npz
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + self.suffix)
if len(npz_file) == 0:
return None, None
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
return int(w), int(h)
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
# support old .npz
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
if os.path.exists(old_npz_file):
return old_npz_file
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample()
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)
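A quick check of the cache-file naming implemented by `SdSdxlLatentsCachingStrategy` above (a sketch; `/data/img.png` is a hypothetical path with no legacy `.npz` beside it):

```python
from library.strategy_sd import SdSdxlLatentsCachingStrategy

strategy = SdSdxlLatentsCachingStrategy(
    sd=True, cache_to_disk=True, batch_size=4, skip_disk_cache_validity_check=False
)
print(strategy.get_latents_npz_path("/data/img.png", (512, 768)))
# -> /data/img_0512x0768_sd.npz ("_sdxl.npz" with sd=False; a pre-existing
#    old-style /data/img.npz would be reused for backward compatibility)
```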

library/strategy_sd3.py (new file, 229 lines)

@@ -0,0 +1,229 @@
import os
import glob
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from library import sd3_utils, train_util
from library import sd3_models
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
class Sd3TokenizeStrategy(TokenizeStrategy):
def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None:
self.t5xxl_max_length = t5xxl_max_length
self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.clip_g = self._load_tokenizer(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.clip_g.pad_token_id = 0 # use 0 as pad token for clip_g
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
l_tokens = l_tokens["input_ids"]
g_tokens = g_tokens["input_ids"]
t5_tokens = t5_tokens["input_ids"]
return [l_tokens, g_tokens, t5_tokens]
class Sd3TextEncodingStrategy(TextEncodingStrategy):
def __init__(self) -> None:
pass
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
clip_l, clip_g, t5xxl = models
l_tokens, g_tokens, t5_tokens = tokens
if l_tokens is None:
assert g_tokens is None, "g_tokens must be None if l_tokens is None"
lg_out = None
else:
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
l_out, l_pooled = clip_l(l_tokens)
g_out, g_pooled = clip_g(g_tokens)
lg_out = torch.cat([l_out, g_out], dim=-1)
if t5xxl is not None and t5_tokens is not None:
t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096]
else:
t5_out = None
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
return [lg_out, t5_out, lg_pooled]
def concat_encodings(
self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
if t5_out is None:
t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype)
return torch.cat([lg_out, t5_out], dim=-2), lg_pooled
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, abs_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(self.get_outputs_npz_path(abs_path)):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(self.get_outputs_npz_path(abs_path))
if "clip_l" not in npz or "clip_g" not in npz:
return False
if "clip_l_pool" not in npz or "clip_g_pool" not in npz:
return False
# t5xxl is optional
except Exception as e:
logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
lg_out = data["lg_out"]
lg_pooled = data["lg_pooled"]
t5_out = data["t5_out"] if "t5_out" in data else None
return [lg_out, t5_out, lg_pooled]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
captions = [info.caption for info in infos]
clip_l_tokens, clip_g_tokens, t5xxl_tokens = tokenize_strategy.tokenize(captions)
with torch.no_grad():
lg_out, t5_out, lg_pooled = text_encoding_strategy.encode_tokens(
tokenize_strategy, models, [clip_l_tokens, clip_g_tokens, t5xxl_tokens]
)
if lg_out.dtype == torch.bfloat16:
lg_out = lg_out.float()
if lg_pooled.dtype == torch.bfloat16:
lg_pooled = lg_pooled.float()
if t5_out is not None and t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
lg_out = lg_out.cpu().numpy()
lg_pooled = lg_pooled.cpu().numpy()
if t5_out is not None:
t5_out = t5_out.cpu().numpy()
for i, info in enumerate(infos):
lg_out_i = lg_out[i]
t5_out_i = t5_out[i] if t5_out is not None else None
lg_pooled_i = lg_pooled[i]
if self.cache_to_disk:
kwargs = {}
if t5_out is not None:
kwargs["t5_out"] = t5_out_i
np.savez(info.text_encoder_outputs_npz, lg_out=lg_out_i, lg_pooled=lg_pooled_i, **kwargs)
else:
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i)
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX)
if len(npz_file) == 0:
return None, None
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
return int(w), int(h)
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)
if __name__ == "__main__":
# test code for Sd3TokenizeStrategy
# tokenizer = sd3_models.SD3Tokenizer()
strategy = Sd3TokenizeStrategy(256)
text = "hello world"
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
# print(l_tokens.shape)
print(l_tokens)
print(g_tokens)
print(t5_tokens)
texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens_2 = strategy.t5xxl(
texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
print(l_tokens_2)
print(g_tokens_2)
print(t5_tokens_2)
# compare
print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0]))
print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0]))
print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0]))
text = ",".join(["hello world! this is long text"] * 50)
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
print(l_tokens)
print(g_tokens)
print(t5_tokens)
print(f"model max length l: {strategy.clip_l.model_max_length}")
print(f"model max length g: {strategy.clip_g.model_max_length}")
print(f"model max length t5: {strategy.t5xxl.model_max_length}")

library/strategy_sdxl.py (new file, 247 lines)

@@ -0,0 +1,247 @@
import os
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
class SdxlTokenizeStrategy(TokenizeStrategy):
def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2
if max_length is None:
self.max_length = self.tokenizer1.model_max_length
else:
self.max_length = max_length + 2
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
return (
torch.stack([self._get_input_ids(self.tokenizer1, t, self.max_length) for t in text], dim=0),
torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0),
)
class SdxlTextEncodingStrategy(TextEncodingStrategy):
def __init__(self) -> None:
pass
def _pool_workaround(
self, text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
):
r"""
workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output
instead of the hidden states for the EOS token
If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output
Original code from CLIP's pooling function:
\# text_embeds.shape = [batch_size, sequence_length, transformer.width]
\# take features from the eot embedding (eot_token is the highest number in each sequence)
\# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
]
"""
# input_ids: b*n,77
# find index for EOS token
# Following code is not working if one of the input_ids has multiple EOS tokens (very odd case)
# eos_token_index = torch.where(input_ids == eos_token_id)[1]
# eos_token_index = eos_token_index.to(device=last_hidden_state.device)
# Create a mask where the EOS tokens are
eos_token_mask = (input_ids == eos_token_id).int()
# Use argmax to find the last index of the EOS token for each element in the batch
eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine
eos_token_index = eos_token_index.to(device=last_hidden_state.device)
# get hidden states for EOS token
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index
]
# apply projection: projection may be of different dtype than last_hidden_state
pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype))
pooled_output = pooled_output.to(last_hidden_state.dtype)
return pooled_output
def _get_hidden_states_sdxl(
self,
input_ids1: torch.Tensor,
input_ids2: torch.Tensor,
tokenizer1: CLIPTokenizer,
tokenizer2: CLIPTokenizer,
text_encoder1: Union[CLIPTextModel, torch.nn.Module],
text_encoder2: Union[CLIPTextModelWithProjection, torch.nn.Module],
unwrapped_text_encoder2: Optional[CLIPTextModelWithProjection] = None,
):
# input_ids: b,n,77 -> b*n, 77
b_size = input_ids1.size()[0]
max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
input_ids1 = input_ids1.to(text_encoder1.device)
input_ids2 = input_ids2.to(text_encoder2.device)
# text_encoder1
enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
hidden_states1 = enc_out["hidden_states"][11]
# text_encoder2
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
hidden_states2 = enc_out["hidden_states"][-2]  # penultimate layer
# pool2 = enc_out["text_embeds"]
unwrapped_text_encoder2 = unwrapped_text_encoder2 or text_encoder2
pool2 = self._pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
n_size = 1 if max_token_length is None else max_token_length // 75
hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
if max_token_length is not None:
# bs*3, 77, 768 or 1024
# encoder1: rejoin triplets of <BOS>...<EOS> back into <BOS>...<EOS>
states_list = [hidden_states1[:, 0].unsqueeze(1)]  # <BOS>
for i in range(1, max_token_length, tokenizer1.model_max_length):
states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2])  # from after <BOS> to before <EOS>
states_list.append(hidden_states1[:, -1].unsqueeze(1))  # <EOS>
hidden_states1 = torch.cat(states_list, dim=1)
# v2: rejoin triplets of <BOS>...<EOS> <PAD> ... back into <BOS>...<EOS> <PAD> ...; honestly not sure this implementation is right
states_list = [hidden_states2[:, 0].unsqueeze(1)]  # <BOS>
for i in range(1, max_token_length, tokenizer2.model_max_length):
chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2]  # from after <BOS> up to before the last token
# this causes an error:
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# if i > 1:
# for j in range(len(chunk)): # batch_size
# if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
# chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
states_list.append(hidden_states2[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
hidden_states2 = torch.cat(states_list, dim=1)
# pool はnの最初のものを使う
pool2 = pool2[::n_size]
return hidden_states1, hidden_states2, pool2
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
"""
Args:
tokenize_strategy: TokenizeStrategy
models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)]
tokens: List of tokens, for text_encoder1 and text_encoder2
"""
if len(models) == 2:
text_encoder1, text_encoder2 = models
unwrapped_text_encoder2 = None
else:
text_encoder1, text_encoder2, unwrapped_text_encoder2 = models
tokens1, tokens2 = tokens
sdxl_tokenize_strategy = tokenize_strategy # type: SdxlTokenizeStrategy
tokenizer1, tokenizer2 = sdxl_tokenize_strategy.tokenizer1, sdxl_tokenize_strategy.tokenizer2
hidden_states1, hidden_states2, pool2 = self._get_hidden_states_sdxl(
tokens1, tokens2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, unwrapped_text_encoder2
)
return [hidden_states1, hidden_states2, pool2]
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, abs_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(self.get_outputs_npz_path(abs_path)):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(self.get_outputs_npz_path(abs_path))
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
hidden_state1 = data["hidden_state1"]
hidden_state2 = data["hidden_state2"]
pool2 = data["pool2"]
return [hidden_state1, hidden_state2, pool2]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy
captions = [info.caption for info in infos]
tokens1, tokens2 = tokenize_strategy.tokenize(captions)
with torch.no_grad():
hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, [tokens1, tokens2]
)
if hidden_state1.dtype == torch.bfloat16:
hidden_state1 = hidden_state1.float()
if hidden_state2.dtype == torch.bfloat16:
hidden_state2 = hidden_state2.float()
if pool2.dtype == torch.bfloat16:
pool2 = pool2.float()
hidden_state1 = hidden_state1.cpu().numpy()
hidden_state2 = hidden_state2.cpu().numpy()
pool2 = pool2.cpu().numpy()
for i, info in enumerate(infos):
hidden_state1_i = hidden_state1[i]
hidden_state2_i = hidden_state2[i]
pool2_i = pool2[i]
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
hidden_state1=hidden_state1_i,
hidden_state2=hidden_state2_i,
pool2=pool2_i,
)
else:
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]
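A standalone illustration of the indexing trick in `_pool_workaround` above: a 0/1 mask plus `argmax` finds the first EOS position in each row (argmax returns the first maximal index), and those hidden states are gathered as the pooled output:

```python
import torch

eos_token_id = 49407  # CLIP's EOS id
input_ids = torch.tensor([[49406, 320, 49407, 49407],
                          [49406, 1125, 320, 49407]])
last_hidden_state = torch.randn(2, 4, 1280)

eos_token_mask = (input_ids == eos_token_id).int()
eos_token_index = torch.argmax(eos_token_mask, dim=1)  # tensor([2, 3]): first EOS per row
pooled = last_hidden_state[torch.arange(2), eos_token_index]
print(pooled.shape)  # torch.Size([2, 1280])
```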


@@ -12,6 +12,7 @@ import re
import shutil
import time
from typing import (
Any,
Dict,
List,
NamedTuple,
@@ -34,6 +35,7 @@ from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy
init_ipex()
@@ -81,10 +83,6 @@ logger = logging.getLogger(__name__)
# from library.hypernetwork import replace_attentions_for_hypernetwork
from library.original_unet import UNet2DConditionModel
-# Tokenizer: use the pre-provided one instead of loading it from the checkpoint
-TOKENIZER_PATH = "openai/clip-vit-large-patch14"
-V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"  # only the tokenizer is used from here; v2 and v2.1 share the same tokenizer spec
HIGH_VRAM = False
# checkpoint file names
@@ -148,18 +146,24 @@ class ImageInfo:
         self.image_size: Tuple[int, int] = None
         self.resized_size: Tuple[int, int] = None
         self.bucket_reso: Tuple[int, int] = None
-        self.latents: torch.Tensor = None
-        self.latents_flipped: torch.Tensor = None
-        self.latents_npz: str = None
-        self.latents_original_size: Tuple[int, int] = None  # original image size, not latents size
-        self.latents_crop_ltrb: Tuple[int, int] = None  # crop left top right bottom in original pixel size, not latents size
-        self.cond_img_path: str = None
+        self.latents: Optional[torch.Tensor] = None
+        self.latents_flipped: Optional[torch.Tensor] = None
+        self.latents_npz: Optional[str] = None  # set in cache_latents
+        self.latents_original_size: Optional[Tuple[int, int]] = None  # original image size, not latents size
+        self.latents_crop_ltrb: Optional[Tuple[int, int]] = (
+            None  # crop left top right bottom in original pixel size, not latents size
+        )
+        self.cond_img_path: Optional[str] = None
         self.image: Optional[Image.Image] = None  # optional, original PIL Image
-        # SDXL, optional
-        self.text_encoder_outputs_npz: Optional[str] = None
+        self.text_encoder_outputs_npz: Optional[str] = None  # set in cache_text_encoder_outputs
+        # new
+        self.text_encoder_outputs: Optional[List[torch.Tensor]] = None
+        # old
         self.text_encoder_outputs1: Optional[torch.Tensor] = None
         self.text_encoder_outputs2: Optional[torch.Tensor] = None
         self.text_encoder_pool2: Optional[torch.Tensor] = None
         self.alpha_mask: Optional[torch.Tensor] = None  # alpha mask can be flipped in runtime
@@ -359,47 +363,6 @@ class AugHelper:
         return self.color_aug if use_color_aug else None

-class LatentsCachingStrategy:
-    _strategy = None  # strategy instance: actual strategy class
-
-    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
-        self._cache_to_disk = cache_to_disk
-        self._batch_size = batch_size
-        self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
-
-    @classmethod
-    def set_strategy(cls, strategy):
-        if cls._strategy is not None:
-            raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
-        cls._strategy = strategy
-
-    @classmethod
-    def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
-        return cls._strategy
-
-    @property
-    def cache_to_disk(self):
-        return self._cache_to_disk
-
-    @property
-    def batch_size(self):
-        return self._batch_size
-
-    def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
-        raise NotImplementedError
-
-    def get_latents_npz_path(self, absolute_path: str, bucket_reso: Tuple[int, int]) -> str:
-        raise NotImplementedError
-
-    def is_disk_cached_latents_expected(
-        self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
-    ) -> bool:
-        raise NotImplementedError
-
-    def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
-        raise NotImplementedError

 class BaseSubset:
     def __init__(
         self,
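The registry pattern removed here is not gone; it moves into `library/strategy_base.py`, where every strategy family (tokenizing, text encoding, latents caching, text encoder output caching) shares the same class-level singleton. A simplified sketch of that shared pattern, not the verbatim library code:

```
from typing import Optional


class Strategy:
    _strategy = None  # one instance per strategy family and per process

    @classmethod
    def set_strategy(cls, strategy):
        if cls._strategy is not None:
            raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["Strategy"]:
        return cls._strategy
```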
@@ -639,17 +602,12 @@ class ControlNetSubset(BaseSubset):
 class BaseDataset(torch.utils.data.Dataset):
     def __init__(
         self,
-        tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
-        max_token_length: int,
         resolution: Optional[Tuple[int, int]],
         network_multiplier: float,
         debug_dataset: bool,
     ) -> None:
         super().__init__()

-        self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
-        self.max_token_length = max_token_length
         # width/height is used when enable_bucket==False
         self.width, self.height = (None, None) if resolution is None else resolution
         self.network_multiplier = network_multiplier
@@ -670,8 +628,6 @@ class BaseDataset(torch.utils.data.Dataset):
         self.bucket_no_upscale = None
         self.bucket_info = None  # for metadata

-        self.tokenizer_max_length = self.tokenizers[0].model_max_length if max_token_length is None else max_token_length + 2
-
         self.current_epoch: int = 0  # instances seem to be recreated every epoch, so this must be passed in from outside
         self.current_step: int = 0
@@ -690,6 +646,15 @@ class BaseDataset(torch.utils.data.Dataset):
         # caching
         self.caching_mode = None  # None, 'latents', 'text'

+        self.tokenize_strategy = None
+        self.text_encoder_output_caching_strategy = None
+        self.latents_caching_strategy = None
+
+    def set_current_strategies(self):
+        self.tokenize_strategy = TokenizeStrategy.get_strategy()
+        self.text_encoder_output_caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
+        self.latents_caching_strategy = LatentsCachingStrategy.get_strategy()
+
     def set_seed(self, seed):
         self.seed = seed
@@ -979,22 +944,6 @@ class BaseDataset(torch.utils.data.Dataset):
             for batch_index in range(batch_count):
                 self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))

-        # The following was reverted because splitting buckets inflated the per-bucket batch count and caused confusion
-        # During training the step order is random, so the same image appearing in one batch should not hurt much
-        #
-        # # As buckets get more fine-grained, buckets holding only one kind of image become common,
-        # # meaning a single batch can be filled with the same image, which is clearly undesirable,
-        # # so the batch size is limited to the number of distinct image types
-        # # the same image can still land in the same batch, so fewer repeats definitely give better shuffle quality
-        # # TODO a mechanism to use regularization images across epochs
-        # num_of_image_types = len(set(bucket))
-        # bucket_batch_size = min(self.batch_size, num_of_image_types)
-        # batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
-        # # logger.info(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
-        # for batch_index in range(batch_count):
-        #     self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))

         self.shuffle_buckets()
         self._length = len(self.buckets_indices)
@@ -1027,12 +976,13 @@ class BaseDataset(torch.utils.data.Dataset):
             ]
         )

-    def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCachingStrategy):
+    def new_cache_latents(self, model: Any, is_main_process: bool):
         r"""
         a brand new method to cache latents. This method caches latents with caching strategy.
         normal cache_latents method is used by default, but this method is used when caching strategy is specified.
         """
         logger.info("caching latents with caching strategy.")
+        caching_strategy = LatentsCachingStrategy.get_strategy()
         image_infos = list(self.image_data.values())

         # sort by resolution
@@ -1088,7 +1038,7 @@ class BaseDataset(torch.utils.data.Dataset):
         logger.info("caching latents...")
         for batch in tqdm(batches, smoothing=1, total=len(batches)):
             # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
-            caching_strategy.cache_batch_latents(batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
+            caching_strategy.cache_batch_latents(model, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)

     def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
         # multi-GPU is not supported here; use tools/cache_latents.py for that
@@ -1145,6 +1095,56 @@ class BaseDataset(torch.utils.data.Dataset):
         for batch in tqdm(batches, smoothing=1, total=len(batches)):
             cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)

+    def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
+        r"""
+        a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy.
+        """
+        tokenize_strategy = TokenizeStrategy.get_strategy()
+        text_encoding_strategy = TextEncodingStrategy.get_strategy()
+        caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
+        batch_size = caching_strategy.batch_size or self.batch_size
+
+        # if caching to disk, don't cache TE outputs in non-main processes
+        if caching_strategy.cache_to_disk and not is_main_process:
+            return
+
+        logger.info("caching Text Encoder outputs with caching strategy.")
+        image_infos = list(self.image_data.values())
+
+        # split into batches
+        batches = []
+        batch = []
+        logger.info("checking cache validity...")
+        for info in tqdm(image_infos):
+            te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
+
+            # check whether the disk cache exists and is valid
+            if caching_strategy.cache_to_disk:
+                info.text_encoder_outputs_npz = te_out_npz
+                cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
+                if cache_available:  # do not add to batch
+                    continue
+
+            batch.append(info)
+
+            # if the batch has enough items, flush it
+            if len(batch) >= batch_size:
+                batches.append(batch)
+                batch = []
+
+        if len(batch) > 0:
+            batches.append(batch)
+
+        if len(batches) == 0:
+            logger.info("no Text Encoder outputs to cache")
+            return
+
+        # iterate batches
+        logger.info("caching Text Encoder outputs...")
+        for batch in tqdm(batches, smoothing=1, total=len(batches)):
+            caching_strategy.cache_batch_outputs(tokenize_strategy, models, text_encoding_strategy, batch)
+
     # if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype
     # this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset
     # to support SD1/2, it needs a flag for v2, but it is postponed
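With `cache_to_disk` enabled, `get_outputs_npz_path` decides where each cache file lives. For the SDXL strategy above, every image gets a sibling npz file (paths are illustrative):

```
/data/images/0001.png  ->  /data/images/0001_te_outputs.npz
                           keys: hidden_state1, hidden_state2, pool2
```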
@@ -1188,6 +1188,8 @@ class BaseDataset(torch.utils.data.Dataset):
         # multi-GPU is also not supported here; use tools/cache_latents.py for that
         logger.info("caching text encoder outputs.")

+        tokenize_strategy = TokenizeStrategy.get_strategy()
+
         if batch_size is None:
             batch_size = self.batch_size
@@ -1229,7 +1231,7 @@ class BaseDataset(torch.utils.data.Dataset):
                 input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
                 batch.append((info, input_ids1, input_ids2))
             else:
-                l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption)
+                l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(info.caption)
                 batch.append((info, l_tokens, g_tokens, t5_tokens))

             if len(batch) >= batch_size:
@@ -1347,7 +1349,6 @@ class BaseDataset(torch.utils.data.Dataset):
         loss_weights = []
         captions = []
         input_ids_list = []
-        input_ids2_list = []
         latents_list = []
         alpha_mask_list = []
         images = []
@@ -1355,16 +1356,14 @@ class BaseDataset(torch.utils.data.Dataset):
         crop_top_lefts = []
         target_sizes_hw = []
         flippeds = []  # this variable name is a bit awkward
-        text_encoder_outputs1_list = []
-        text_encoder_outputs2_list = []
-        text_encoder_pool2_list = []
+        text_encoder_outputs_list = []

         for image_key in bucket[image_index : image_index + bucket_batch_size]:
             image_info = self.image_data[image_key]
             subset = self.image_to_subset[image_key]
-            loss_weights.append(
-                self.prior_loss_weight if image_info.is_reg else 1.0
-            )  # in case of fine tuning, is_reg is always False
+            loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)  # in case of fine tuning, is_reg is always False

             flipped = subset.flip_aug and random.random() < 0.5  # not flipped or flipped with 50% chance
@@ -1381,7 +1380,9 @@ class BaseDataset(torch.utils.data.Dataset):
                 image = None
             elif image_info.latents_npz is not None:  # FineTuningDataset or when cache_latents_to_disk=True
-                latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz)
+                latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
+                    self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz)
+                )
                 if flipped:
                     latents = flipped_latents
                     alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy()  # copy to avoid negative stride problem
@@ -1470,75 +1471,67 @@ class BaseDataset(torch.utils.data.Dataset):
             # process caption and text encoder outputs
             caption = image_info.caption  # default

-            if image_info.text_encoder_outputs1 is not None:
-                text_encoder_outputs1_list.append(image_info.text_encoder_outputs1)
-                text_encoder_outputs2_list.append(image_info.text_encoder_outputs2)
-                text_encoder_pool2_list.append(image_info.text_encoder_pool2)
-                captions.append(caption)
-            elif image_info.text_encoder_outputs_npz is not None:
-                text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk(
-                    image_info.text_encoder_outputs_npz
-                )
-                text_encoder_outputs1_list.append(text_encoder_outputs1)
-                text_encoder_outputs2_list.append(text_encoder_outputs2)
-                text_encoder_pool2_list.append(text_encoder_pool2)
-                captions.append(caption)
-            else:
-                caption = self.process_caption(subset, image_info.caption)
-                if self.XTI_layers:
-                    caption_layer = []
-                    for layer in self.XTI_layers:
-                        token_strings_from = " ".join(self.token_strings)
-                        token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
-                        caption_ = caption.replace(token_strings_from, token_strings_to)
-                        caption_layer.append(caption_)
-                    captions.append(caption_layer)
-                else:
-                    captions.append(caption)
-
-                if not self.token_padding_disabled:  # this option might be omitted in future
-                    # TODO get_input_ids must support SD3
-                    if self.XTI_layers:
-                        token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
-                    else:
-                        token_caption = self.get_input_ids(caption, self.tokenizers[0])
-                    input_ids_list.append(token_caption)
-
-                    if len(self.tokenizers) > 1:
-                        if self.XTI_layers:
-                            token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
-                        else:
-                            token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
-                        input_ids2_list.append(token_caption2)
+            tokenization_required = (
+                self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial
+            )
+            text_encoder_outputs = None
+            input_ids = None
+
+            if image_info.text_encoder_outputs is not None:
+                # cached in memory
+                text_encoder_outputs = image_info.text_encoder_outputs
+            elif image_info.text_encoder_outputs_npz is not None:
+                # cached on disk
+                text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz(
+                    image_info.text_encoder_outputs_npz
+                )
+            else:
+                tokenization_required = True
+            text_encoder_outputs_list.append(text_encoder_outputs)
+
+            if tokenization_required:
+                caption = self.process_caption(subset, image_info.caption)
+                input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)]  # remove batch dimension
+                # XTI and manual token padding are disabled for now; the previous handling
+                # (see the removed block above) stays commented out in the source
+            input_ids_list.append(input_ids)
+
+            captions.append(caption)

+        def none_or_stack_elements(tensors_list, converter):
+            # [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)]
+            if len(tensors_list) == 0 or tensors_list[0] == None or len(tensors_list[0]) == 0 or tensors_list[0][0] is None:
+                return None
+            return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))]
+
         example = {}
         example["loss_weights"] = torch.FloatTensor(loss_weights)
+        example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor)
+        example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x)

-        if len(text_encoder_outputs1_list) == 0:
-            if self.token_padding_disabled:
-                # padding=True means pad in the batch
-                example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids
-                if len(self.tokenizers) > 1:
-                    example["input_ids2"] = self.tokenizer[1](
-                        captions, padding=True, truncation=True, return_tensors="pt"
-                    ).input_ids
-                else:
-                    example["input_ids2"] = None
-            else:
-                example["input_ids"] = torch.stack(input_ids_list)
-                example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
-            example["text_encoder_outputs1_list"] = None
-            example["text_encoder_outputs2_list"] = None
-            example["text_encoder_pool2_list"] = None
-        else:
-            example["input_ids"] = None
-            example["input_ids2"] = None
-            # # for assertion
-            # example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions])
-            # example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions])
-            example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list)
-            example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
-            example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)

         # if one of alpha_masks is not None, we need to replace None with ones
         none_or_not = [x is None for x in alpha_mask_list]
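To make the new collation concrete: `none_or_stack_elements` turns a list of per-sample output lists into per-encoder batched tensors, or returns None if nothing was cached. A toy example (shapes are illustrative only):

```
import torch

# two samples, each with cached [clip_l, clip_g, t5xxl] outputs
tensors_list = [
    [torch.zeros(77, 768), torch.zeros(77, 1280), torch.zeros(256, 4096)],
    [torch.zeros(77, 768), torch.zeros(77, 1280), torch.zeros(256, 4096)],
]
# this is what the function computes with an identity converter:
stacked = [torch.stack([x[i] for x in tensors_list]) for i in range(3)]
# -> shapes (2, 77, 768), (2, 77, 1280), (2, 256, 4096)
# if the first sample had no outputs (None), the function returns None instead
```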
@@ -1652,8 +1645,6 @@ class DreamBoothDataset(BaseDataset):
         self,
         subsets: Sequence[DreamBoothSubset],
         batch_size: int,
-        tokenizer,
-        max_token_length,
         resolution,
         network_multiplier: float,
         enable_bucket: bool,
@@ -1664,7 +1655,7 @@ class DreamBoothDataset(BaseDataset):
         prior_loss_weight: float,
         debug_dataset: bool,
     ) -> None:
-        super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
+        super().__init__(resolution, network_multiplier, debug_dataset)

         assert resolution is not None, f"resolution is required / resolution（解像度）指定は必須です"
@@ -1750,10 +1741,10 @@ class DreamBoothDataset(BaseDataset):
         # new caching: get image size from cache files
         strategy = LatentsCachingStrategy.get_strategy()
         if strategy is not None:
-            logger.info("get image size from cache files")
+            logger.info("get image size from name of cache files")
             size_set_count = 0
             for i, img_path in enumerate(tqdm(img_paths)):
-                w, h = strategy.get_image_size_from_image_absolute_path(img_path)
+                w, h = strategy.get_image_size_from_disk_cache_path(img_path)
                 if w is not None and h is not None:
                     sizes[i] = [w, h]
                     size_set_count += 1
@@ -1886,8 +1877,6 @@ class FineTuningDataset(BaseDataset):
         self,
         subsets: Sequence[FineTuningSubset],
         batch_size: int,
-        tokenizer,
-        max_token_length,
         resolution,
         network_multiplier: float,
         enable_bucket: bool,
@@ -1897,7 +1886,7 @@ class FineTuningDataset(BaseDataset):
         bucket_no_upscale: bool,
         debug_dataset: bool,
     ) -> None:
-        super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
+        super().__init__(resolution, network_multiplier, debug_dataset)

         self.batch_size = batch_size
@@ -2111,8 +2100,6 @@ class ControlNetDataset(BaseDataset):
         self,
         subsets: Sequence[ControlNetSubset],
         batch_size: int,
-        tokenizer,
-        max_token_length,
         resolution,
         network_multiplier: float,
         enable_bucket: bool,
@@ -2122,7 +2109,7 @@ class ControlNetDataset(BaseDataset):
         bucket_no_upscale: bool,
         debug_dataset: float,
     ) -> None:
-        super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
+        super().__init__(resolution, network_multiplier, debug_dataset)

         db_subsets = []
         for subset in subsets:
@@ -2160,8 +2147,6 @@ class ControlNetDataset(BaseDataset):
         self.dreambooth_dataset_delegate = DreamBoothDataset(
             db_subsets,
             batch_size,
-            tokenizer,
-            max_token_length,
             resolution,
             network_multiplier,
             enable_bucket,
@@ -2221,6 +2206,9 @@ class ControlNetDataset(BaseDataset):
         self.conditioning_image_transforms = IMAGE_TRANSFORMS

+    def set_current_strategies(self):
+        return self.dreambooth_dataset_delegate.set_current_strategies()
+
     def make_buckets(self):
         self.dreambooth_dataset_delegate.make_buckets()
         self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
@@ -2229,6 +2217,12 @@ class ControlNetDataset(BaseDataset):
     def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
         return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)

+    def new_cache_latents(self, model: Any, is_main_process: bool):
+        return self.dreambooth_dataset_delegate.new_cache_latents(model, is_main_process)
+
+    def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
+        return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process)
+
     def __len__(self):
         return self.dreambooth_dataset_delegate.__len__()
@@ -2314,6 +2308,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
         # for dataset in self.datasets:
         #     dataset.make_buckets()

+    def set_text_encoder_output_caching_strategy(self, strategy: TextEncoderOutputsCachingStrategy):
+        """
+        DataLoader is run in multiple processes, so we need to set the strategy manually.
+        """
+        for dataset in self.datasets:
+            dataset.set_text_encoder_output_caching_strategy(strategy)
+
     def enable_XTI(self, *args, **kwargs):
         for dataset in self.datasets:
             dataset.enable_XTI(*args, **kwargs)
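The docstring above matters because the class-level `_strategy` attribute set in the main process is not automatically visible inside DataLoader workers; with the `spawn` start method each worker re-imports the module, so the registry starts out empty there. Copying the strategy onto the dataset lets it travel with the pickled dataset object. A sketch of the failure mode this avoids (assumed `spawn` behavior):

```
# main process
TokenizeStrategy.set_strategy(my_strategy)  # sets a class attribute in this process only

# inside a spawned DataLoader worker, the module is re-imported, so:
TokenizeStrategy.get_strategy()  # -> None, unless the dataset instance carries its own copy
```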
@@ -2323,10 +2324,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
             logger.info(f"[Dataset {i}]")
             dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix)

-    def new_cache_latents(self, is_main_process: bool, strategy: LatentsCachingStrategy):
+    def new_cache_latents(self, model: Any, is_main_process: bool):
         for i, dataset in enumerate(self.datasets):
             logger.info(f"[Dataset {i}]")
-            dataset.new_cache_latents(is_main_process, strategy)
+            dataset.new_cache_latents(model, is_main_process)

     def cache_text_encoder_outputs(
         self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
@@ -2344,6 +2345,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
             tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size
         )

+    def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
+        for i, dataset in enumerate(self.datasets):
+            logger.info(f"[Dataset {i}]")
+            dataset.new_cache_text_encoder_outputs(models, is_main_process)
+
     def set_caching_mode(self, caching_mode):
         for dataset in self.datasets:
             dataset.set_caching_mode(caching_mode)
@@ -2358,6 +2364,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
     def is_text_encoder_output_cacheable(self) -> bool:
         return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets])

+    def set_current_strategies(self):
+        for dataset in self.datasets:
+            dataset.set_current_strategies()
+
     def set_current_epoch(self, epoch):
         for dataset in self.datasets:
             dataset.set_current_epoch(epoch)
@@ -2411,34 +2421,34 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
 # return value: latents_tensor, (original_size width, original_size height), (crop left, crop top)

 # TODO update to use CachingStrategy
-def load_latents_from_disk(
-    npz_path,
-) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
-    npz = np.load(npz_path)
-    if "latents" not in npz:
-        raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
-
-    latents = npz["latents"]
-    original_size = npz["original_size"].tolist()
-    crop_ltrb = npz["crop_ltrb"].tolist()
-    flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
-    alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
-    return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
+# def load_latents_from_disk(
+#     npz_path,
+# ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
+#     npz = np.load(npz_path)
+#     if "latents" not in npz:
+#         raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
+#
+#     latents = npz["latents"]
+#     original_size = npz["original_size"].tolist()
+#     crop_ltrb = npz["crop_ltrb"].tolist()
+#     flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
+#     alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
+#     return latents, original_size, crop_ltrb, flipped_latents, alpha_mask

-def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None):
-    kwargs = {}
-    if flipped_latents_tensor is not None:
-        kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
-    if alpha_mask is not None:
-        kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
-    np.savez(
-        npz_path,
-        latents=latents_tensor.float().cpu().numpy(),
-        original_size=np.array(original_size),
-        crop_ltrb=np.array(crop_ltrb),
-        **kwargs,
-    )
+# def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None):
+#     kwargs = {}
+#     if flipped_latents_tensor is not None:
+#         kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
+#     if alpha_mask is not None:
+#         kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
+#     np.savez(
+#         npz_path,
+#         latents=latents_tensor.float().cpu().numpy(),
+#         original_size=np.array(original_size),
+#         crop_ltrb=np.array(crop_ltrb),
+#         **kwargs,
+#     )
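The reader above is commented out, but the on-disk format it documents is unchanged, so a latents cache file can still be inspected by hand. The keys below come straight from that code (the file name is hypothetical; SD3 caches use the `*_sd3.npz` suffix):

```
import numpy as np

npz = np.load("0001_sd3.npz")  # hypothetical file name
print(npz.files)
# ['latents', 'original_size', 'crop_ltrb'] plus optionally 'latents_flipped', 'alpha_mask'
```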
 def debug_dataset(train_dataset, show_input_ids=False):
@@ -2465,12 +2475,12 @@ def debug_dataset(train_dataset, show_input_ids=False):
         example = train_dataset[idx]
         if example["latents"] is not None:
             logger.info(f"sample has latents from npz file: {example['latents'].size()}")
-        for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate(
+        for j, (ik, cap, lw, orgsz, crptl, trgsz, flpdz) in enumerate(
             zip(
                 example["image_keys"],
                 example["captions"],
                 example["loss_weights"],
-                example["input_ids"],
+                # example["input_ids"],
                 example["original_sizes_hw"],
                 example["crop_top_lefts"],
                 example["target_sizes_hw"],
@@ -2483,10 +2493,10 @@ def debug_dataset(train_dataset, show_input_ids=False):
             if "network_multipliers" in example:
                 print(f"network multiplier: {example['network_multipliers'][j]}")

-            if show_input_ids:
-                logger.info(f"input ids: {iid}")
-                if "input_ids2" in example:
-                    logger.info(f"input ids2: {example['input_ids2'][j]}")
+            # if show_input_ids:
+            #     logger.info(f"input ids: {iid}")
+            #     if "input_ids2" in example:
+            #         logger.info(f"input ids2: {example['input_ids2'][j]}")
             if example["images"] is not None:
                 im = example["images"][j]
                 logger.info(f"image size: {im.size()}")
@@ -2555,8 +2565,8 @@ def glob_images_pathlib(dir_path, recursive):
 class MinimalDataset(BaseDataset):
-    def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False):
-        super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
+    def __init__(self, resolution, network_multiplier, debug_dataset=False):
+        super().__init__(resolution, network_multiplier, debug_dataset)

         self.num_train_images = 0  # update in subclass
         self.num_reg_images = 0  # update in subclass
@@ -2773,14 +2783,15 @@ def cache_batch_latents(
             raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")

         if cache_to_disk:
-            save_latents_to_disk(
-                info.latents_npz,
-                latent,
-                info.latents_original_size,
-                info.latents_crop_ltrb,
-                flipped_latent,
-                alpha_mask,
-            )
+            # save_latents_to_disk(
+            #     info.latents_npz,
+            #     latent,
+            #     info.latents_original_size,
+            #     info.latents_crop_ltrb,
+            #     flipped_latent,
+            #     alpha_mask,
+            # )
+            pass
         else:
             info.latents = latent
             if flip_aug:
@@ -4662,33 +4673,6 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
         )

-def load_tokenizer(args: argparse.Namespace):
-    logger.info("prepare tokenizer")
-    original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH
-
-    tokenizer: CLIPTokenizer = None
-    if args.tokenizer_cache_dir:
-        local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
-        if os.path.exists(local_tokenizer_path):
-            logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
-            tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)  # same for v1 and v2
-
-    if tokenizer is None:
-        if args.v2:
-            tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
-        else:
-            tokenizer = CLIPTokenizer.from_pretrained(original_path)
-
-    if hasattr(args, "max_token_length") and args.max_token_length is not None:
-        logger.info(f"update token length: {args.max_token_length}")
-
-    if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
-        logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
-        tokenizer.save_pretrained(local_tokenizer_path)
-
-    return tokenizer
 def prepare_accelerator(args: argparse.Namespace):
     """
     this function also prepares deepspeed plugin
@@ -5550,6 +5534,7 @@ def sample_images_common(
 ):
     """
     a modified version of StableDiffusionLongPromptWeightingPipeline is used, so clip skip and prompt weighting are supported
+    TODO Use strategies here
     """
     if steps == 0:

View File

@@ -24,7 +24,7 @@ import logging
 logger = logging.getLogger(__name__)

-from library import sd3_models, sd3_utils
+from library import sd3_models, sd3_utils, strategy_sd3


 def get_noise(seed, latent):
@@ -145,6 +145,7 @@ if __name__ == "__main__":
     parser.add_argument("--clip_g", type=str, required=False)
     parser.add_argument("--clip_l", type=str, required=False)
     parser.add_argument("--t5xxl", type=str, required=False)
+    parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77")
     parser.add_argument("--prompt", type=str, default="A photo of a cat")
     # parser.add_argument("--prompt2", type=str, default=None)  # do not support different prompts for text encoders
     parser.add_argument("--negative_prompt", type=str, default="")
@@ -247,7 +248,7 @@ if __name__ == "__main__":
     # load tokenizers
     logger.info("Loading tokenizers...")
-    tokenizer = sd3_models.SD3Tokenizer(use_t5xxl)  # combined tokenizer
+    tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)

     # load models
     # logger.info("Create MMDiT from SD3 checkpoint...")
@@ -320,12 +321,19 @@ if __name__ == "__main__":
     # prepare embeddings
     logger.info("Encoding prompts...")

-    # embeds, pooled_embed
-    lg_out, t5_out, pooled = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl)
-    cond = torch.cat([lg_out, t5_out], dim=-2), pooled
-
-    lg_out, t5_out, pooled = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl)
-    neg_cond = torch.cat([lg_out, t5_out], dim=-2), pooled
+    encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
+
+    l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.prompt)
+    lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
+        tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens]
+    )
+    cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
+
+    l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.negative_prompt)
+    lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
+        tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens]
+    )
+    neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)

     # generate image
     logger.info("Generating image...")

View File

@@ -17,7 +17,7 @@ init_ipex()
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler

-from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils
+from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3
 from library.sdxl_train_util import match_mixed_precision

 # , sdxl_model_util
@@ -69,10 +69,22 @@ def train(args):
     # not args.train_text_encoder
     # ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません"

-    # training without text encoder cache is not supported
-    assert (
-        args.cache_text_encoder_outputs
-    ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません"
+    # # training without text encoder cache is not supported: because T5XXL must be cached
+    # assert (
+    #     args.cache_text_encoder_outputs
+    # ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません"
+
+    assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), (
+        "when training text encoder, text encoder outputs must not be cached (except for T5XXL)"
+        + " / text encoderの学習時はtext encoderの出力はキャッシュできません（t5xxlのみキャッシュすることは可能です）"
+    )
+
+    if args.use_t5xxl_cache_only and not args.cache_text_encoder_outputs:
+        logger.warning(
+            "use_t5xxl_cache_only is enabled, so cache_text_encoder_outputs is automatically enabled."
+            + " / use_t5xxl_cache_onlyが有効なため、cache_text_encoder_outputsも自動的に有効になります"
+        )
+        args.cache_text_encoder_outputs = True

     # if args.block_lr:
     #     block_lrs = [float(lr) for lr in args.block_lr.split(",")]
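Put differently, the new checks allow these flag combinations (a sketch; the flags themselves are added later in this commit):

```
# OK: cache everything, do not train any text encoder
--cache_text_encoder_outputs

# OK: train CLIP-L/G, cache only T5-XXL outputs
--train_text_encoder --use_t5xxl_cache_only
# (cache_text_encoder_outputs is then enabled automatically, per the warning above)

# rejected by the assert: train CLIP-L/G while caching all text encoder outputs
--train_text_encoder --cache_text_encoder_outputs
```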
@@ -88,17 +100,17 @@ def train(args):
     if args.seed is not None:
         set_seed(args.seed)  # initialize the random number sequence

-    # load tokenizer
-    sd3_tokenizer = sd3_models.SD3Tokenizer()
-
-    # prepare caching strategy
-    if args.new_caching:
-        latents_caching_strategy = sd3_train_utils.Sd3LatentsCachingStrategy(
-            args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
-        )
-    else:
-        latents_caching_strategy = None
-    train_util.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
+    # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
+    if args.cache_latents:
+        latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy(
+            args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
+        )
+        strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
+
+    # load tokenizer and prepare tokenize strategy
+    sd3_tokenizer = sd3_models.SD3Tokenizer(t5xxl_max_length=args.t5xxl_max_token_length)
+    sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length)
+    strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy)

     # prepare the dataset
     if args.dataset_class is None:
@@ -153,6 +165,16 @@ def train(args):
     train_dataset_group.verify_bucket_reso_steps(8)  # TODO confirm this is right

     if args.debug_dataset:
+        if args.cache_text_encoder_outputs:
+            strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
+                strategy_sd3.Sd3TextEncoderOutputsCachingStrategy(
+                    args.cache_text_encoder_outputs_to_disk,
+                    args.text_encoder_batch_size,
+                    False,
+                    False,
+                )
+            )
+        train_dataset_group.set_current_strategies()
         train_util.debug_dataset(train_dataset_group, True)
         return
     if len(train_dataset_group) == 0:
@@ -215,19 +237,8 @@ def train(args):
         vae.requires_grad_(False)
         vae.eval()

-        if not args.new_caching:
-            vae_wrapper = sd3_models.VAEWrapper(vae)  # make SD/SDXL compatible
-            with torch.no_grad():
-                train_dataset_group.cache_latents(
-                    vae_wrapper,
-                    args.vae_batch_size,
-                    args.cache_latents_to_disk,
-                    accelerator.is_main_process,
-                    file_suffix="_sd3.npz",
-                )
-        else:
-            latents_caching_strategy.set_vae(vae)
-            train_dataset_group.new_cache_latents(accelerator.is_main_process, latents_caching_strategy)
+        train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)

         vae.to("cpu")  # if no sampling, vae can be deleted
         clean_memory_on_device(accelerator.device)
@@ -246,60 +257,70 @@ def train(args):
     t5xxl = sd3_train_utils.load_target_model("t5xxl", args, sd3_state_dict, accelerator, attn_mode, t5xxl_dtype, device_to_load)
     # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)

+    # should be deleted after caching text encoder outputs when not training text encoder
+    # this strategy should not be used other than this process
+    text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
+    strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
+
     # prepare for training: put the models into the proper state
     train_clip_l = False
     train_clip_g = False
     train_t5xxl = False

-    # if args.train_text_encoder:
-    #     # TODO each option for two text encoders?
-    #     accelerator.print("enable text encoder training")
-    #     if args.gradient_checkpointing:
-    #         text_encoder1.gradient_checkpointing_enable()
-    #         text_encoder2.gradient_checkpointing_enable()
-    #     lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate  # 0 means not train
-    #     lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate  # 0 means not train
-    #     train_clip_l = lr_te1 != 0
-    #     train_clip_g = lr_te2 != 0
-
-    #     # caching one text encoder output is not supported
-    #     if not train_clip_l:
-    #         text_encoder1.to(weight_dtype)
-    #     if not train_clip_g:
-    #         text_encoder2.to(weight_dtype)
-    #     text_encoder1.requires_grad_(train_clip_l)
-    #     text_encoder2.requires_grad_(train_clip_g)
-    #     text_encoder1.train(train_clip_l)
-    #     text_encoder2.train(train_clip_g)
-    # else:
-    clip_l.to(weight_dtype)
-    clip_g.to(weight_dtype)
-    clip_l.requires_grad_(False)
-    clip_g.requires_grad_(False)
-    clip_l.eval()
-    clip_g.eval()
+    if args.train_text_encoder:
+        accelerator.print("enable text encoder training")
+        if args.gradient_checkpointing:
+            clip_l.gradient_checkpointing_enable()
+            clip_g.gradient_checkpointing_enable()
+        lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate  # 0 means not train
+        lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate  # 0 means not train
+        train_clip_l = lr_te1 != 0
+        train_clip_g = lr_te2 != 0
+
+        if not train_clip_l:
+            clip_l.to(weight_dtype)
+        if not train_clip_g:
+            clip_g.to(weight_dtype)
+        clip_l.requires_grad_(train_clip_l)
+        clip_g.requires_grad_(train_clip_g)
+        clip_l.train(train_clip_l)
+        clip_g.train(train_clip_g)
+    else:
+        clip_l.to(weight_dtype)
+        clip_g.to(weight_dtype)
+        clip_l.requires_grad_(False)
+        clip_g.requires_grad_(False)
+        clip_l.eval()
+        clip_g.eval()

     if t5xxl is not None:
         t5xxl.to(t5xxl_dtype)
         t5xxl.requires_grad_(False)
         t5xxl.eval()

     # cache text encoder outputs
     if args.cache_text_encoder_outputs:
-        # Text Encoders are eval and no grad
-        clip_l.to(accelerator.device)
-        clip_g.to(accelerator.device)
-        if t5xxl is not None:
-            t5xxl.to(t5xxl_device)
-
-        with torch.no_grad(), accelerator.autocast():
-            train_dataset_group.cache_text_encoder_outputs_sd3(
-                sd3_tokenizer,
-                (clip_l, clip_g, t5xxl),
-                (accelerator.device, accelerator.device, t5xxl_device),
-                None,
-                (None, None, None),
-                args.cache_text_encoder_outputs_to_disk,
-                accelerator.is_main_process,
-                args.text_encoder_batch_size,
-            )
-
-        # TODO we can delete text encoders after caching
+        # Text Encoders are eval and no grad here
+        text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy(
+            args.cache_text_encoder_outputs_to_disk,
+            args.text_encoder_batch_size,
+            False,
+            train_clip_g or train_clip_l or args.use_t5xxl_cache_only,
+        )
+        strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
+
+        clip_l.to(accelerator.device, dtype=weight_dtype)
+        clip_g.to(accelerator.device, dtype=weight_dtype)
+        if t5xxl is not None:
+            t5xxl.to(t5xxl_device, dtype=t5xxl_dtype)
+
+        with accelerator.autocast():
+            train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process)

     accelerator.wait_for_everyone()

 # load MMDIT
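Note how `is_partial` is derived above: `train_clip_g or train_clip_l or args.use_t5xxl_cache_only`. A partial cache tells the dataset that cached outputs alone are not enough, so captions must still be tokenized for the encoders that run live. This is the consumer side, as it appears in `BaseDataset.__getitem__` earlier in this commit:

```
tokenization_required = (
    self.text_encoder_output_caching_strategy is None
    or self.text_encoder_output_caching_strategy.is_partial
)
```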
@@ -332,11 +353,11 @@ def train(args):
     #     params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs))

     # if train_clip_l:
-    #     training_models.append(text_encoder1)
-    #     params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
+    #     training_models.append(clip_l)
+    #     params_to_optimize.append({"params": list(clip_l.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
     # if train_clip_g:
-    #     training_models.append(text_encoder2)
-    #     params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate})
+    #     training_models.append(clip_g)
+    #     params_to_optimize.append({"params": list(clip_g.parameters()), "lr": args.learning_rate_te2 or args.learning_rate})

     # calculate number of trainable parameters
     n_params = 0
@@ -344,7 +365,7 @@ def train(args):
         for p in group["params"]:
             n_params += p.numel()

-    accelerator.print(f"train mmdit: {train_mmdit}")  # , text_encoder1: {train_clip_l}, text_encoder2: {train_clip_g}")
+    accelerator.print(f"train mmdit: {train_mmdit}")  # , clip_l: {train_clip_l}, clip_g: {train_clip_g}")
     accelerator.print(f"number of models: {len(training_models)}")
     accelerator.print(f"number of trainable parameters: {n_params}")
@@ -398,7 +419,11 @@ def train(args):
     else:
         _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)

     # prepare dataloader
+    # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
+    # some strategies can be None
+    train_dataset_group.set_current_strategies()
+
     # note: a DataLoader with 0 worker processes cannot use persistent_workers
     n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
     train_dataloader = torch.utils.data.DataLoader(
@@ -455,8 +480,8 @@ def train(args):
     # TODO check if this is necessary. SD3 uses pool for clip_l and clip_g
     # # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
     # if train_clip_l:
-    #     text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
-    #     text_encoder1.text_model.final_layer_norm.requires_grad_(False)
+    #     clip_l.text_model.encoder.layers[-1].requires_grad_(False)
+    #     clip_l.text_model.final_layer_norm.requires_grad_(False)

     # when text encoder outputs are cached, they are already obtained, so move the encoders to CPU
     if args.cache_text_encoder_outputs:
@@ -484,9 +509,8 @@ def train(args):
         ds_model = deepspeed_utils.prepare_deepspeed_model(
             args,
             mmdit=mmdit,
-            # mmdie=mmdit if train_mmdit else None,
-            # text_encoder1=text_encoder1 if train_clip_l else None,
-            # text_encoder2=text_encoder2 if train_clip_g else None,
+            clip_l=clip_l if train_clip_l else None,
+            clip_g=clip_g if train_clip_g else None,
         )
         # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
         ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -498,10 +522,10 @@ def train(args):
         # accelerate seems to handle this properly
         if train_mmdit:
             mmdit = accelerator.prepare(mmdit)
-        # if train_clip_l:
-        #     text_encoder1 = accelerator.prepare(text_encoder1)
-        # if train_clip_g:
-        #     text_encoder2 = accelerator.prepare(text_encoder2)
+        if train_clip_l:
+            clip_l = accelerator.prepare(clip_l)
+        if train_clip_g:
+            clip_g = accelerator.prepare(clip_g)
         optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

     # experimental feature: train with fp16 including gradients; patch PyTorch to enable grad scaling with fp16
@@ -613,7 +637,7 @@ def train(args):
     # # For --sample_at_first
     # sd3_train_utils.sample_images(
-    #     accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], mmdit
+    #     accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [clip_l, clip_g], mmdit
    # )

     # following function will be moved to sd3_train_utils
@@ -666,6 +690,7 @@ def train(args):
         return weighting

     loss_recorder = train_util.LossRecorder()
+    epoch = 0  # avoid error when max_train_steps is 0
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
@@ -687,37 +712,45 @@ def train(args):
                     # encode images to latents. images are [-1, 1]
                     latents = vae.encode(batch["images"].to(vae_dtype)).to(weight_dtype)

                     # if NaNs are found, warn and replace them with zeros
                     if torch.any(torch.isnan(latents)):
                         accelerator.print("NaN found in latents, replacing with zeros")
                         latents = torch.nan_to_num(latents, 0, out=latents)
                 # latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
                 latents = sd3_models.SDVAE.process_in(latents)

-                if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
-                    # not cached, get text encoder outputs
-                    # XXX This does not work yet
-                    input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl = batch["input_ids"]
+                text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
+                if text_encoder_outputs_list is not None:
+                    lg_out, t5_out, lg_pooled = text_encoder_outputs_list
+                    if args.use_t5xxl_cache_only:
+                        lg_out = None
+                        lg_pooled = None
+                else:
+                    lg_out = None
+                    t5_out = None
+                    lg_pooled = None
+
+                if lg_out is None or (train_clip_l or train_clip_g):
+                    # not cached or still training, so get the outputs from the text encoders
+                    input_ids_clip_l, input_ids_clip_g, _ = batch["input_ids_list"]
                     with torch.set_grad_enabled(args.train_text_encoder):
                         # TODO support weighted captions
-                        # TODO support length > 75
                         input_ids_clip_l = input_ids_clip_l.to(accelerator.device)
                         input_ids_clip_g = input_ids_clip_g.to(accelerator.device)
-                        input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device)
-
-                        # get text encoder outputs: outputs are concatenated
-                        context, pool = sd3_utils.get_cond_from_tokens(
-                            input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl, clip_l, clip_g, t5xxl
-                        )
-                else:
-                    # encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
-                    # encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
-                    # pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
-                    # TODO this reuses SDXL keys, it should be fixed
-                    lg_out = batch["text_encoder_outputs1_list"]
-                    t5_out = batch["text_encoder_outputs2_list"]
-                    pool = batch["text_encoder_pool2_list"]
-                    context = torch.cat([lg_out, t5_out], dim=-2)
+                        lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens(
+                            sd3_tokenize_strategy, [clip_l, clip_g, None], [input_ids_clip_l, input_ids_clip_g, None]
+                        )
+
+                if t5_out is None:
+                    _, _, input_ids_t5xxl = batch["input_ids_list"]
+                    with torch.no_grad():
+                        input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None
+                        _, t5_out, _ = text_encoding_strategy.encode_tokens(
+                            sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl]
+                        )
+
+                context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled)

                 # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
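The removed branch shows what `concat_encodings` has to produce: the CLIP-L/G and T5-XXL hidden states concatenated along the token axis, plus the pooled CLIP output. A hedged sketch of the expected behavior (the real implementation lives in `library/strategy_sd3.py`; this assumes the hidden dimensions already match):

```
import torch

def concat_encodings_sketch(lg_out, t5_out, lg_pooled):
    # mirrors the removed code: context = torch.cat([lg_out, t5_out], dim=-2)
    return torch.cat([lg_out, t5_out], dim=-2), lg_pooled
```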
@@ -748,13 +781,13 @@ def train(args):
             if torch.any(torch.isnan(context)):
                 accelerator.print("NaN found in context, replacing with zeros")
                 context = torch.nan_to_num(context, 0, out=context)
-            if torch.any(torch.isnan(pool)):
+            if torch.any(torch.isnan(lg_pooled)):
                 accelerator.print("NaN found in pool, replacing with zeros")
-                pool = torch.nan_to_num(pool, 0, out=pool)
+                lg_pooled = torch.nan_to_num(lg_pooled, 0, out=lg_pooled)

             # call model
             with accelerator.autocast():
-                model_pred = mmdit(noisy_model_input, timesteps, context=context, y=pool)
+                model_pred = mmdit(noisy_model_input, timesteps, context=context, y=lg_pooled)

             # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
             # Preconditioning of the model outputs.
@@ -806,7 +839,7 @@ def train(args):
                     #     accelerator.device,
                     #     vae,
                     #     [tokenizer1, tokenizer2],
-                    #     [text_encoder1, text_encoder2],
+                    #     [clip_l, clip_g],
                     #     mmdit,
                     # )
@@ -875,7 +908,7 @@ def train(args):
            #     accelerator.device,
            #     vae,
            #     [tokenizer1, tokenizer2],
-           #     [text_encoder1, text_encoder2],
+           #     [clip_l, clip_g],
            #     mmdit,
            # )
@@ -924,7 +957,19 @@ def setup_parser() -> argparse.ArgumentParser:
    custom_train_functions.add_custom_train_arguments(parser)
    sd3_train_utils.add_sd3_training_arguments(parser)

-    # parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
+    parser.add_argument(
+        "--train_text_encoder", action="store_true", help="train text encoder (CLIP-L and G) / text encoderも学習する"
+    )
+    # parser.add_argument("--train_t5xxl", action="store_true", help="train T5-XXL / T5-XXLも学習する")
+    parser.add_argument(
+        "--use_t5xxl_cache_only", action="store_true", help="cache T5-XXL outputs only / T5-XXLの出力のみキャッシュする"
+    )
+    parser.add_argument(
+        "--t5xxl_max_token_length",
+        type=int,
+        default=None,
+        help="maximum token length for T5-XXL. 256 if omitted / T5-XXLの最大トークン数。省略時は256",
+    )

    # TE training is disabled temporarily
    # parser.add_argument(
@@ -962,7 +1007,6 @@ def setup_parser() -> argparse.ArgumentParser:
        help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
    )
-    parser.add_argument("--new_caching", action="store_true", help="use new caching method / 新しいキャッシング方法を使う")
    parser.add_argument(
        "--skip_latents_validity_check",
        action="store_true",

View File

@@ -17,7 +17,7 @@ init_ipex()
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler
-from library import deepspeed_utils, sdxl_model_util
+from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl

 import library.train_util as train_util
@@ -124,7 +124,16 @@ def train(args):
    if args.seed is not None:
        set_seed(args.seed)  # 乱数系列を初期化する

-    tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
+    tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
+    strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
+    tokenizers = [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2]  # will be removed in the future
+
+    # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
+    if args.cache_latents:
+        latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
+            False, args.cache_latents_to_disk, args.vae_batch_size, False
+        )
+        strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

    # データセットを準備する
    if args.dataset_class is None:
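Note: as the comment in the hunk above says, the caching strategy must be registered before the dataset is built, because the dataset may look it up during initialization. A self-contained sketch of this class-level registry pattern (a simplified stand-in, not the actual `strategy_base` implementation):

```python
from typing import Optional


class LatentsCachingStrategy:
    """Simplified stand-in for the class-level strategy registry."""

    _strategy: Optional["LatentsCachingStrategy"] = None

    @classmethod
    def set_strategy(cls, strategy: "LatentsCachingStrategy") -> None:
        if cls._strategy is not None:
            raise RuntimeError(f"{cls.__name__} is already set")
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
        return cls._strategy


class Dataset:
    def __init__(self):
        # the dataset reads the registry at init time, which is why
        # set_strategy() has to run before the dataset is constructed
        self.latents_caching_strategy = LatentsCachingStrategy.get_strategy()


LatentsCachingStrategy.set_strategy(LatentsCachingStrategy())
assert Dataset().latents_caching_strategy is not None
```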
@@ -166,10 +175,10 @@ def train(args):
            ]
        }

-        blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
+        blueprint = blueprint_generator.generate(user_config, args)
        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
-        train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2])
+        train_dataset_group = train_util.load_arbitrary_dataset(args)

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
@@ -262,8 +271,9 @@ def train(args):
        vae.to(accelerator.device, dtype=vae_dtype)
        vae.requires_grad_(False)
        vae.eval()
-        with torch.no_grad():
-            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
+
+        train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
+
        vae.to("cpu")
        clean_memory_on_device(accelerator.device)
@@ -276,6 +286,9 @@ def train(args):
    train_text_encoder1 = False
    train_text_encoder2 = False

+    text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
+    strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
+
    if args.train_text_encoder:
        # TODO each option for two text encoders?
        accelerator.print("enable text encoder training")
@@ -307,16 +320,17 @@ def train(args):
    # TextEncoderの出力をキャッシュする
    if args.cache_text_encoder_outputs:
        # Text Encodes are eval and no grad
-        with torch.no_grad(), accelerator.autocast():
-            train_dataset_group.cache_text_encoder_outputs(
-                (tokenizer1, tokenizer2),
-                (text_encoder1, text_encoder2),
-                accelerator.device,
-                None,
-                args.cache_text_encoder_outputs_to_disk,
-                accelerator.is_main_process,
-            )
+        text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
+            args.cache_text_encoder_outputs_to_disk, None, False
+        )
+        strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
+
+        text_encoder1.to(accelerator.device)
+        text_encoder2.to(accelerator.device)
+        with accelerator.autocast():
+            train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process)
+
        accelerator.wait_for_everyone()

    if not cache_latents:
        vae.requires_grad_(False)
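Note: with `--cache_text_encoder_outputs_to_disk`, the outputs computed once by `new_cache_text_encoder_outputs` are persisted and reloaded on later runs. A rough sketch of what such a per-sample disk cache amounts to; the file name, keys, and dtypes here are illustrative, not the library's actual on-disk format:

```python
import numpy as np
import torch


def save_te_outputs(path, hidden_states1, hidden_states2, pool2):
    # store one sample's text encoder outputs as compact float16 arrays
    np.savez(
        path,
        hidden_states1=hidden_states1.to(torch.float16).cpu().numpy(),
        hidden_states2=hidden_states2.to(torch.float16).cpu().numpy(),
        pool2=pool2.to(torch.float16).cpu().numpy(),
    )


def load_te_outputs(path):
    with np.load(path) as npz:
        return tuple(torch.from_numpy(npz[key]) for key in ("hidden_states1", "hidden_states2", "pool2"))


save_te_outputs("sample_te_outputs.npz", torch.zeros(77, 768), torch.zeros(77, 1280), torch.zeros(1280))
hs1, hs2, p2 = load_te_outputs("sample_te_outputs.npz")
print(hs1.shape, hs2.shape, p2.shape)
```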
@@ -403,7 +417,11 @@ def train(args):
    else:
        _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)

-    # dataloaderを準備する
+    # prepare dataloader
+    # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
+    # some strategies can be None
+    train_dataset_group.set_current_strategies()
+
    # DataLoaderのプロセス数0 は persistent_workers が使えないので注意
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
    train_dataloader = torch.utils.data.DataLoader(
@@ -597,7 +615,7 @@ def train(args):
    # For --sample_at_first
    sdxl_train_util.sample_images(
-        accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
+        accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, [text_encoder1, text_encoder2], unet
    )

    loss_recorder = train_util.LossRecorder()
@@ -628,9 +646,15 @@ def train(args):
                    latents = torch.nan_to_num(latents, 0, out=latents)
                latents = latents * sdxl_model_util.VAE_SCALE_FACTOR

-                if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
-                    input_ids1 = batch["input_ids"]
-                    input_ids2 = batch["input_ids2"]
+                text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
+                if text_encoder_outputs_list is not None:
+                    # Text Encoder outputs are cached
+                    encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list
+                    encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype)
+                    encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype)
+                    pool2 = pool2.to(accelerator.device, dtype=weight_dtype)
+                else:
+                    input_ids1, input_ids2 = batch["input_ids_list"]
                    with torch.set_grad_enabled(args.train_text_encoder):
                        # Get the text embedding for conditioning
                        # TODO support weighted captions
@@ -646,39 +670,13 @@ def train(args):
                        # else:
                        input_ids1 = input_ids1.to(accelerator.device)
                        input_ids2 = input_ids2.to(accelerator.device)
-                        # unwrap_model is fine for models not wrapped by accelerator
-                        encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
-                            args.max_token_length,
-                            input_ids1,
-                            input_ids2,
-                            tokenizer1,
-                            tokenizer2,
-                            text_encoder1,
-                            text_encoder2,
-                            None if not args.full_fp16 else weight_dtype,
-                            accelerator=accelerator,
-                        )
-                else:
-                    encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
-                    encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
-                    pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
-
-                    # # verify that the text encoder outputs are correct
-                    # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
-                    #     args.max_token_length,
-                    #     batch["input_ids"].to(text_encoder1.device),
-                    #     batch["input_ids2"].to(text_encoder1.device),
-                    #     tokenizer1,
-                    #     tokenizer2,
-                    #     text_encoder1,
-                    #     text_encoder2,
-                    #     None if not args.full_fp16 else weight_dtype,
-                    # )
-                    # b_size = encoder_hidden_states1.shape[0]
-                    # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
-                    # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
-                    # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
-                    # logger.info("text encoder outputs verified")
+                        encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
+                            tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2]
+                        )
+                        if args.full_fp16:
+                            encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
+                            encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype)
+                            pool2 = pool2.to(weight_dtype)

                # get size embeddings
                orig_size = batch["original_sizes_hw"]
@@ -765,7 +763,7 @@ def train(args):
                        global_step,
                        accelerator.device,
                        vae,
-                        [tokenizer1, tokenizer2],
+                        tokenizers,
                        [text_encoder1, text_encoder2],
                        unet,
                    )
@@ -847,7 +845,7 @@ def train(args):
                global_step,
                accelerator.device,
                vae,
-                [tokenizer1, tokenizer2],
+                tokenizers,
                [text_encoder1, text_encoder2],
                unet,
            )

View File

@@ -23,7 +23,16 @@ from accelerate.utils import set_seed
 import accelerate
 from diffusers import DDPMScheduler, ControlNetModel
 from safetensors.torch import load_file
-from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
+from library import (
+    deepspeed_utils,
+    sai_model_spec,
+    sdxl_model_util,
+    sdxl_original_unet,
+    sdxl_train_util,
+    strategy_base,
+    strategy_sd,
+    strategy_sdxl,
+)

 import library.model_util as model_util
 import library.train_util as train_util
@@ -79,7 +88,14 @@ def train(args):
        args.seed = random.randint(0, 2**32)
    set_seed(args.seed)

-    tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
+    tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
+    strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
+
+    # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
+    latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
+        False, args.cache_latents_to_disk, args.vae_batch_size, False
+    )
+    strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

    # データセットを準備する
    blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
@@ -106,7 +122,7 @@ def train(args):
            ]
        }

-    blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
+    blueprint = blueprint_generator.generate(user_config, args)
    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    current_epoch = Value("i", 0)
@@ -164,30 +180,30 @@ def train(args):
        vae.to(accelerator.device, dtype=vae_dtype)
        vae.requires_grad_(False)
        vae.eval()
-        with torch.no_grad():
-            train_dataset_group.cache_latents(
-                vae,
-                args.vae_batch_size,
-                args.cache_latents_to_disk,
-                accelerator.is_main_process,
-            )
+
+        train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
+
        vae.to("cpu")
        clean_memory_on_device(accelerator.device)

        accelerator.wait_for_everyone()

+    text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
+    strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
+
    # TextEncoderの出力をキャッシュする
    if args.cache_text_encoder_outputs:
        # Text Encodes are eval and no grad
-        with torch.no_grad():
-            train_dataset_group.cache_text_encoder_outputs(
-                (tokenizer1, tokenizer2),
-                (text_encoder1, text_encoder2),
-                accelerator.device,
-                None,
-                args.cache_text_encoder_outputs_to_disk,
-                accelerator.is_main_process,
-            )
+        text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
+            args.cache_text_encoder_outputs_to_disk, None, False
+        )
+        strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
+
+        text_encoder1.to(accelerator.device)
+        text_encoder2.to(accelerator.device)
+        with accelerator.autocast():
+            train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process)
+
        accelerator.wait_for_everyone()
# prepare ControlNet-LLLite # prepare ControlNet-LLLite
@@ -242,7 +258,11 @@ def train(args):
        _, _, optimizer = train_util.get_optimizer(args, trainable_params)

-    # dataloaderを準備する
+    # prepare dataloader
+    # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
+    # some strategies can be None
+    train_dataset_group.set_current_strategies()
+
    # DataLoaderのプロセス数0 は persistent_workers が使えないので注意
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
@@ -290,7 +310,7 @@ def train(args):
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
    if isinstance(unet, DDP):
        unet._set_static_graph()  # avoid error for multiple use of the parameter

    if args.gradient_checkpointing:
        unet.train()  # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
@@ -357,7 +377,9 @@ def train(args):
    if args.log_tracker_config is not None:
        init_kwargs = toml.load(args.log_tracker_config)
    accelerator.init_trackers(
-        "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
+        "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name,
+        config=train_util.get_sanitized_config_or_none(args),
+        init_kwargs=init_kwargs,
    )

    loss_recorder = train_util.LossRecorder()
@@ -409,27 +431,26 @@ def train(args):
                    latents = torch.nan_to_num(latents, 0, out=latents)
                latents = latents * sdxl_model_util.VAE_SCALE_FACTOR

-                if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
-                    input_ids1 = batch["input_ids"]
-                    input_ids2 = batch["input_ids2"]
+                text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
+                if text_encoder_outputs_list is not None:
+                    # Text Encoder outputs are cached
+                    encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list
+                    encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype)
+                    encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype)
+                    pool2 = pool2.to(accelerator.device, dtype=weight_dtype)
+                else:
+                    input_ids1, input_ids2 = batch["input_ids_list"]
                    with torch.no_grad():
-                        # Get the text embedding for conditioning
                        input_ids1 = input_ids1.to(accelerator.device)
                        input_ids2 = input_ids2.to(accelerator.device)
-                        encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
-                            args.max_token_length,
-                            input_ids1,
-                            input_ids2,
-                            tokenizer1,
-                            tokenizer2,
-                            text_encoder1,
-                            text_encoder2,
-                            None if not args.full_fp16 else weight_dtype,
+                        encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
+                            tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2]
                        )
-                else:
-                    encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
-                    encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
-                    pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
+                        if args.full_fp16:
+                            encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
+                            encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype)
+                            pool2 = pool2.to(weight_dtype)

                # get size embeddings
                orig_size = batch["original_sizes_hw"]

View File

@@ -1,16 +1,21 @@
 import argparse

 import torch
+from accelerate import Accelerator

 from library.device_utils import init_ipex, clean_memory_on_device

 init_ipex()

-from library import sdxl_model_util, sdxl_train_util, train_util
+from library import sdxl_model_util, sdxl_train_util, strategy_base, strategy_sd, strategy_sdxl, train_util

 import train_network
 from library.utils import setup_logging

 setup_logging()
 import logging

 logger = logging.getLogger(__name__)


 class SdxlNetworkTrainer(train_network.NetworkTrainer):
     def __init__(self):
         super().__init__()
@@ -49,15 +54,32 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
        return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet

-    def load_tokenizer(self, args):
-        tokenizer = sdxl_train_util.load_tokenizers(args)
-        return tokenizer
+    def get_tokenize_strategy(self, args):
+        return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)

-    def is_text_encoder_outputs_cached(self, args):
-        return args.cache_text_encoder_outputs
+    def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy):
+        return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2]
+
+    def get_latents_caching_strategy(self, args):
+        latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
+            False, args.cache_latents_to_disk, args.vae_batch_size, False
+        )
+        return latents_caching_strategy
+
+    def get_text_encoding_strategy(self, args):
+        return strategy_sdxl.SdxlTextEncodingStrategy()
+
+    def get_models_for_text_encoding(self, args, accelerator, text_encoders):
+        return text_encoders + [accelerator.unwrap_model(text_encoders[-1])]
+
+    def get_text_encoder_outputs_caching_strategy(self, args):
+        if args.cache_text_encoder_outputs:
+            return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False)
+        else:
+            return None

    def cache_text_encoder_outputs_if_needed(
-        self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
+        self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
    ):
        if args.cache_text_encoder_outputs:
            if not args.lowram:
@@ -70,15 +92,13 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
                clean_memory_on_device(accelerator.device)

            # When TE is not be trained, it will not be prepared so we need to use explicit autocast
+            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
+            text_encoders[1].to(accelerator.device, dtype=weight_dtype)
            with accelerator.autocast():
-                dataset.cache_text_encoder_outputs(
-                    tokenizers,
-                    text_encoders,
-                    accelerator.device,
-                    weight_dtype,
-                    args.cache_text_encoder_outputs_to_disk,
-                    accelerator.is_main_process,
+                dataset.new_cache_text_encoder_outputs(
+                    text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator.is_main_process
                )
+            accelerator.wait_for_everyone()

            text_encoders[0].to("cpu", dtype=torch.float32)  # Text Encoder doesn't work with fp16 on CPU
            text_encoders[1].to("cpu", dtype=torch.float32)
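Note: the hunks above swap the old `load_tokenizer` / `is_text_encoder_outputs_cached` hooks for per-architecture strategy getters that the base trainer calls. A condensed sketch of this hook pattern with stub classes (not the real `NetworkTrainer`), showing how an SDXL-like subclass overrides the getters:

```python
from typing import Any, List


class TokenizeStrategy: ...
class TextEncodingStrategy: ...


class NetworkTrainerBase:
    """Stand-in for train_network.NetworkTrainer; subclasses override the getters."""

    def get_tokenize_strategy(self, args) -> TokenizeStrategy:
        raise NotImplementedError

    def get_text_encoding_strategy(self, args) -> TextEncodingStrategy:
        raise NotImplementedError

    def get_models_for_text_encoding(self, args, accelerator, text_encoders) -> List[Any]:
        return text_encoders  # SD1/2 default: encode with the encoders as-is


class SdxlLikeTrainer(NetworkTrainerBase):
    def get_tokenize_strategy(self, args) -> TokenizeStrategy:
        return TokenizeStrategy()  # the real code builds SdxlTokenizeStrategy here

    def get_text_encoding_strategy(self, args) -> TextEncodingStrategy:
        return TextEncodingStrategy()  # the real code builds SdxlTextEncodingStrategy

    def get_models_for_text_encoding(self, args, accelerator, text_encoders) -> List[Any]:
        # mirrors text_encoders + [accelerator.unwrap_model(text_encoders[-1])]:
        # SDXL appends an unwrapped copy of the last encoder for the pooled output
        return text_encoders + [text_encoders[-1]]


trainer: NetworkTrainerBase = SdxlLikeTrainer()
print(type(trainer.get_tokenize_strategy(None)).__name__)  # TokenizeStrategy
```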

View File

@@ -5,10 +5,10 @@ import regex
 import torch
 from library.device_utils import init_ipex

 init_ipex()

-from library import sdxl_model_util, sdxl_train_util, train_util
+from library import sdxl_model_util, sdxl_train_util, strategy_sd, strategy_sdxl, train_util

 import train_textual_inversion
@@ -41,28 +41,20 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
        return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet

-    def load_tokenizer(self, args):
-        tokenizer = sdxl_train_util.load_tokenizers(args)
-        return tokenizer
+    def get_tokenize_strategy(self, args):
+        return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)

-    def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
-        input_ids1 = batch["input_ids"]
-        input_ids2 = batch["input_ids2"]
-        with torch.enable_grad():
-            input_ids1 = input_ids1.to(accelerator.device)
-            input_ids2 = input_ids2.to(accelerator.device)
-            encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
-                args.max_token_length,
-                input_ids1,
-                input_ids2,
-                tokenizers[0],
-                tokenizers[1],
-                text_encoders[0],
-                text_encoders[1],
-                None if not args.full_fp16 else weight_dtype,
-                accelerator=accelerator,
-            )
-        return encoder_hidden_states1, encoder_hidden_states2, pool2
+    def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy):
+        return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2]
+
+    def get_latents_caching_strategy(self, args):
+        latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
+            False, args.cache_latents_to_disk, args.vae_batch_size, False
+        )
+        return latents_caching_strategy
+
+    def get_text_encoding_strategy(self, args):
+        return strategy_sdxl.SdxlTextEncodingStrategy()

    def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
        noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
@@ -81,9 +73,11 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
            noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
        return noise_pred

-    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
+    def sample_images(
+        self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
+    ):
        sdxl_train_util.sample_images(
-            accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
+            accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
        )

    def save_weights(self, file, updated_embs, save_dtype, metadata):
@@ -122,8 +116,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
 def setup_parser() -> argparse.ArgumentParser:
     parser = train_textual_inversion.setup_parser()
-    # don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching
-    # sdxl_train_util.add_sdxl_training_arguments(parser)
+    sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False)
     return parser

View File

@@ -11,7 +11,7 @@ import toml
 from tqdm import tqdm

 import torch
-from library import deepspeed_utils
+from library import deepspeed_utils, strategy_base

 from library.device_utils import init_ipex, clean_memory_on_device
@@ -38,6 +38,7 @@ from library.custom_train_functions import (
     apply_masked_loss,
 )
 from library.utils import setup_logging, add_logging_arguments
+import library.strategy_sd as strategy_sd

 setup_logging()
 import logging
@@ -58,7 +59,14 @@ def train(args):
    if args.seed is not None:
        set_seed(args.seed)  # 乱数系列を初期化する

-    tokenizer = train_util.load_tokenizer(args)
+    tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
+    strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
+
+    # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
+    latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
+        False, args.cache_latents_to_disk, args.vae_batch_size, False
+    )
+    strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

    # データセットを準備する
    if args.dataset_class is None:
] ]
} }
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else: else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) train_dataset_group = train_util.load_arbitrary_dataset(args)
current_epoch = Value("i", 0) current_epoch = Value("i", 0)
current_step = Value("i", 0) current_step = Value("i", 0)
@@ -145,13 +153,17 @@ def train(args):
        vae.to(accelerator.device, dtype=vae_dtype)
        vae.requires_grad_(False)
        vae.eval()
-        with torch.no_grad():
-            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
+
+        train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
+
        vae.to("cpu")
        clean_memory_on_device(accelerator.device)

        accelerator.wait_for_everyone()

+    text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
+    strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
+
    # 学習を準備する:モデルを適切な状態にする
    train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
    unet.requires_grad_(True)  # 念のため追加
@@ -184,8 +196,11 @@ def train(args):
        _, _, optimizer = train_util.get_optimizer(args, trainable_params)

-    # dataloaderを準備する
+    # prepare dataloader
+    # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
+    # some strategies can be None
+    train_dataset_group.set_current_strategies()
    # DataLoaderのプロセス数0 は persistent_workers が使えないので注意
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
@@ -290,10 +305,16 @@ def train(args):
        init_kwargs["wandb"] = {"name": args.wandb_run_name}
    if args.log_tracker_config is not None:
        init_kwargs = toml.load(args.log_tracker_config)
-    accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
+    accelerator.init_trackers(
+        "dreambooth" if args.log_tracker_name is None else args.log_tracker_name,
+        config=train_util.get_sanitized_config_or_none(args),
+        init_kwargs=init_kwargs,
+    )

    # For --sample_at_first
-    train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+    train_util.sample_images(
+        accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
+    )

    loss_recorder = train_util.LossRecorder()
    for epoch in range(num_train_epochs):
@@ -331,7 +352,7 @@ def train(args):
                with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
                    if args.weighted_captions:
                        encoder_hidden_states = get_weighted_text_embeddings(
-                            tokenizer,
+                            tokenize_strategy.tokenizer,
                            text_encoder,
                            batch["captions"],
                            accelerator.device,
@@ -339,14 +360,18 @@ def train(args):
                            clip_skip=args.clip_skip,
                        )
                    else:
-                        input_ids = batch["input_ids"].to(accelerator.device)
-                        encoder_hidden_states = train_util.get_hidden_states(
-                            args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
-                        )
+                        input_ids = batch["input_ids_list"][0].to(accelerator.device)
+                        encoder_hidden_states = text_encoding_strategy.encode_tokens(
+                            tokenize_strategy, [text_encoder], [input_ids]
+                        )[0]
+                        if args.full_fp16:
+                            encoder_hidden_states = encoder_hidden_states.to(weight_dtype)

                # Sample noise, sample a random timestep for each image, and add noise to the latents,
                # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )

                # Predict the noise residual
                with accelerator.autocast():
@@ -358,7 +383,9 @@ def train(args):
                else:
                    target = noise

-                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                loss = train_util.conditional_loss(
+                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                )
                if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                    loss = apply_masked_loss(loss, batch)
                loss = loss.mean([1, 2, 3])
@@ -393,7 +420,7 @@ def train(args):
                    global_step += 1

                    train_util.sample_images(
-                        accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
+                        accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
                    )

                    # 指定ステップごとにモデルを保存
@@ -457,7 +484,9 @@ def train(args):
                    vae,
                )

-        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+        train_util.sample_images(
+            accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
+        )

    is_main_process = accelerator.is_main_process
    if is_main_process:

View File

@@ -7,6 +7,7 @@ import random
 import time
 import json
 from multiprocessing import Value
+from typing import Any, List
 import toml

 from tqdm import tqdm
@@ -18,7 +19,7 @@ init_ipex()
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler
-from library import deepspeed_utils, model_util
+from library import deepspeed_utils, model_util, strategy_base, strategy_sd

 import library.train_util as train_util
 from library.train_util import DreamBoothDataset
@@ -101,19 +102,31 @@ class NetworkTrainer:
        text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
        return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet

-    def load_tokenizer(self, args):
-        tokenizer = train_util.load_tokenizer(args)
-        return tokenizer
+    def get_tokenize_strategy(self, args):
+        return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)

-    def is_text_encoder_outputs_cached(self, args):
-        return False
+    def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]:
+        return [tokenize_strategy.tokenizer]
+
+    def get_latents_caching_strategy(self, args):
+        latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
+            True, args.cache_latents_to_disk, args.vae_batch_size, False
+        )
+        return latents_caching_strategy
+
+    def get_text_encoding_strategy(self, args):
+        return strategy_sd.SdTextEncodingStrategy(args.clip_skip)
+
+    def get_text_encoder_outputs_caching_strategy(self, args):
+        return None
+
+    def get_models_for_text_encoding(self, args, accelerator, text_encoders):
+        return text_encoders

    def is_train_text_encoder(self, args):
-        return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
+        return not args.network_train_unet_only

-    def cache_text_encoder_outputs_if_needed(
-        self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
-    ):
+    def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype):
        for t_enc in text_encoders:
            t_enc.to(accelerator.device, dtype=weight_dtype)
@@ -123,7 +136,7 @@ class NetworkTrainer:
        return encoder_hidden_states

    def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
-        noise_pred = unet(noisy_latents, timesteps, text_conds).sample
+        noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
        return noise_pred

    def all_reduce_network(self, accelerator, network):
@@ -131,8 +144,8 @@ class NetworkTrainer:
            if param.grad is not None:
                param.grad = accelerator.reduce(param.grad, reduction="mean")

-    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
-        train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
+    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet):
+        train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet)

    def train(self, args):
        session_id = random.randint(0, 2**32)
@@ -150,9 +163,13 @@ class NetworkTrainer:
            args.seed = random.randint(0, 2**32)
        set_seed(args.seed)

-        # tokenizerは単体またはリスト、tokenizersは必ずリスト既存のコードとの互換性のため
-        tokenizer = self.load_tokenizer(args)
-        tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
+        tokenize_strategy = self.get_tokenize_strategy(args)
+        strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
+        tokenizers = self.get_tokenizers(tokenize_strategy)  # will be removed after sample_image is refactored
+
+        # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
+        latents_caching_strategy = self.get_latents_caching_strategy(args)
+        strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

        # データセットを準備する
        if args.dataset_class is None:
@@ -194,11 +211,11 @@ class NetworkTrainer:
                ]
            }

-            blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+            blueprint = blueprint_generator.generate(user_config, args)
            train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
        else:
            # use arbitrary dataset class
-            train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
+            train_dataset_group = train_util.load_arbitrary_dataset(args)

        current_epoch = Value("i", 0)
        current_step = Value("i", 0)
@@ -268,8 +285,9 @@ class NetworkTrainer:
            vae.to(accelerator.device, dtype=vae_dtype)
            vae.requires_grad_(False)
            vae.eval()
-            with torch.no_grad():
-                train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
+
+            train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
+
            vae.to("cpu")
            clean_memory_on_device(accelerator.device)
@@ -277,9 +295,13 @@ class NetworkTrainer:
        # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
        # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
-        self.cache_text_encoder_outputs_if_needed(
-            args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
-        )
+        text_encoding_strategy = self.get_text_encoding_strategy(args)
+        strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
+
+        text_encoder_outputs_caching_strategy = self.get_text_encoder_outputs_caching_strategy(args)
+        if text_encoder_outputs_caching_strategy is not None:
+            strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy)
+        self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype)

        # prepare network
        net_kwargs = {}
@@ -366,7 +388,11 @@ class NetworkTrainer:
            optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)

-        # dataloaderを準備する
+        # prepare dataloader
+        # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
+        # some strategies can be None
+        train_dataset_group.set_current_strategies()
+
        # DataLoaderのプロセス数0 は persistent_workers が使えないので注意
        n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
@@ -878,7 +904,7 @@ class NetworkTrainer:
                os.remove(old_ckpt_file)

        # For --sample_at_first
-        self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+        self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)

        # training loop
        if initial_step > 0:  # only if skip_until_initial_step is specified
@@ -933,21 +959,31 @@ class NetworkTrainer:
                    # print(f"set multiplier: {multipliers}")
                    accelerator.unwrap_model(network).set_multiplier(multipliers)

-                with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
-                    # Get the text embedding for conditioning
-                    if args.weighted_captions:
-                        text_encoder_conds = get_weighted_text_embeddings(
-                            tokenizer,
-                            text_encoder,
-                            batch["captions"],
-                            accelerator.device,
-                            args.max_token_length // 75 if args.max_token_length else 1,
-                            clip_skip=args.clip_skip,
-                        )
-                    else:
-                        text_encoder_conds = self.get_text_cond(
-                            args, accelerator, batch, tokenizers, text_encoders, weight_dtype
-                        )
+                text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
+                if text_encoder_outputs_list is not None:
+                    text_encoder_conds = text_encoder_outputs_list  # List of text encoder outputs
+                else:
+                    with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
+                        # Get the text embedding for conditioning
+                        if args.weighted_captions:
+                            # SD only
+                            text_encoder_conds = get_weighted_text_embeddings(
+                                tokenizers[0],
+                                text_encoder,
+                                batch["captions"],
+                                accelerator.device,
+                                args.max_token_length // 75 if args.max_token_length else 1,
+                                clip_skip=args.clip_skip,
+                            )
+                        else:
+                            input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
+                            text_encoder_conds = text_encoding_strategy.encode_tokens(
+                                tokenize_strategy,
+                                self.get_models_for_text_encoding(args, accelerator, text_encoders),
+                                input_ids,
+                            )
+                            if args.full_fp16:
+                                text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]

                # Sample noise, sample a random timestep for each image, and add noise to the latents,
                # with noise offset and/or multires noise if specified
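Note: the hunk above prefers the collated `text_encoder_outputs_list` and only tokenize-encodes when nothing is cached. The same per-batch decision, reduced to a self-contained helper over a plain batch dict (the encode function and shapes are stand-ins):

```python
import torch


def get_text_encoder_conds(batch, encode_fn, device):
    # use cached outputs when the collator provides them ...
    cached = batch.get("text_encoder_outputs_list", None)
    if cached is not None:
        return cached
    # ... otherwise move the token ids to the device and encode on the fly
    input_ids = [ids.to(device) for ids in batch["input_ids_list"]]
    return encode_fn(input_ids)


encode_fn = lambda ids_list: [torch.zeros(ids.shape[0], 77, 768) for ids in ids_list]
batch = {"input_ids_list": [torch.zeros(2, 77, dtype=torch.long)]}
conds = get_text_encoder_conds(batch, encode_fn, "cpu")
print([c.shape for c in conds])  # [torch.Size([2, 77, 768])]
```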
@@ -1026,7 +1062,9 @@ class NetworkTrainer:
                    progress_bar.update(1)
                    global_step += 1

-                    self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+                    self.sample_images(
+                        accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
+                    )

                    # 指定ステップごとにモデルを保存
                    if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -1082,7 +1120,7 @@ class NetworkTrainer:
                if args.save_state:
                    train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)

-            self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+            self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)

            # end of epoch
View File

@@ -2,6 +2,7 @@ import argparse
 import math
 import os
 from multiprocessing import Value
+from typing import Any, List
 import toml

 from tqdm import tqdm
@@ -15,7 +16,7 @@ init_ipex()
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler
 from transformers import CLIPTokenizer
-from library import deepspeed_utils, model_util
+from library import deepspeed_utils, model_util, strategy_base, strategy_sd

 import library.train_util as train_util
 import library.huggingface_util as huggingface_util
@@ -103,28 +104,38 @@ class TextualInversionTrainer:
    def load_target_model(self, args, weight_dtype, accelerator):
        text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
-        return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
+        return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet

-    def load_tokenizer(self, args):
-        tokenizer = train_util.load_tokenizer(args)
-        return tokenizer
+    def get_tokenize_strategy(self, args):
+        return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
+
+    def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]:
+        return [tokenize_strategy.tokenizer]
+
+    def get_latents_caching_strategy(self, args):
+        latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
+            True, args.cache_latents_to_disk, args.vae_batch_size, False
+        )
+        return latents_caching_strategy

    def assert_token_string(self, token_string, tokenizers: CLIPTokenizer):
        pass

-    def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
-        with torch.enable_grad():
-            input_ids = batch["input_ids"].to(accelerator.device)
-            encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None)
-        return encoder_hidden_states
+    def get_text_encoding_strategy(self, args):
+        return strategy_sd.SdTextEncodingStrategy(args.clip_skip)
+
+    def get_models_for_text_encoding(self, args, accelerator, text_encoders) -> List[Any]:
+        return text_encoders

    def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
-        noise_pred = unet(noisy_latents, timesteps, text_conds).sample
+        noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
        return noise_pred

-    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
+    def sample_images(
+        self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
+    ):
        train_util.sample_images(
-            accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
+            accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoders[0], unet, prompt_replacement
        )

    def save_weights(self, file, updated_embs, save_dtype, metadata):
@@ -182,8 +193,13 @@ class TextualInversionTrainer:
        if args.seed is not None:
            set_seed(args.seed)

-        tokenizer_or_list = self.load_tokenizer(args)  # list of tokenizer or tokenizer
-        tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
+        tokenize_strategy = self.get_tokenize_strategy(args)
+        strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
+        tokenizers = self.get_tokenizers(tokenize_strategy)  # will be removed after sample_image is refactored
+
+        # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
+        latents_caching_strategy = self.get_latents_caching_strategy(args)
+        strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

        # acceleratorを準備する
        logger.info("prepare accelerator")
@@ -194,14 +210,7 @@ class TextualInversionTrainer:
        vae_dtype = torch.float32 if args.no_half_vae else weight_dtype

        # モデルを読み込む
-        model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
-        text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list
-
-        if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1:
-            accelerator.print(
-                "accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / "
-                + "accelerateでは複数のモデルテキストエンコーダーのgradient_accumulation_stepsはサポートされていないようです"
-            )
+        model_version, text_encoders, vae, unet = self.load_target_model(args, weight_dtype, accelerator)

        # Convert the init_word to token_id
        init_token_ids_list = []
@@ -310,10 +319,10 @@ class TextualInversionTrainer:
                ]
            }

-            blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list)
+            blueprint = blueprint_generator.generate(user_config, args)
            train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
        else:
-            train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list)
+            train_dataset_group = train_util.load_arbitrary_dataset(args)

        self.assert_extra_args(args, train_dataset_group)
@@ -368,11 +377,10 @@ class TextualInversionTrainer:
            vae.to(accelerator.device, dtype=vae_dtype)
            vae.requires_grad_(False)
            vae.eval()
-            with torch.no_grad():
-                train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
-            vae.to("cpu")
-            clean_memory_on_device(accelerator.device)

+            train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
+            clean_memory_on_device(accelerator.device)
+
            accelerator.wait_for_everyone()

        if args.gradient_checkpointing:
@@ -387,7 +395,11 @@ class TextualInversionTrainer:
                trainable_params += text_encoder.get_input_embeddings().parameters()
        _, _, optimizer = train_util.get_optimizer(args, trainable_params)

-        # dataloaderを準備する
+        # prepare dataloader
+        # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
+        # some strategies can be None
+        train_dataset_group.set_current_strategies()
+
        # DataLoaderのプロセス数0 は persistent_workers が使えないので注意
        n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
        train_dataloader = torch.utils.data.DataLoader(
@@ -415,20 +427,8 @@ class TextualInversionTrainer:
        lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

        # acceleratorがなんかよろしくやってくれるらしい
-        if len(text_encoders) == 1:
-            text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
-            )
-
-        elif len(text_encoders) == 2:
-            text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
-            )
-
-            text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
-
-        else:
-            raise NotImplementedError()
+        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
+        text_encoders = [accelerator.prepare(text_encoder) for text_encoder in text_encoders]

        index_no_updates_list = []
        orig_embeds_params_list = []
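Note: preparing the encoders one by one in a list comprehension replaces the old one/two-encoder special cases and works for any number of text encoders. A hedged, minimal demonstration with dummy modules (requires `accelerate`; the real script prepares the actual text encoders):

```python
import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator(cpu=True)  # cpu=True keeps the demo hardware-independent

# two dummy "text encoders"; the count no longer matters
text_encoders = [nn.Linear(4, 4), nn.Linear(4, 4)]
text_encoders = [accelerator.prepare(text_encoder) for text_encoder in text_encoders]
print(len(text_encoders))  # 2
```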
@@ -456,6 +456,9 @@ class TextualInversionTrainer:
        else:
            unet.eval()

+        text_encoding_strategy = self.get_text_encoding_strategy(args)
+        strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
+
        if not cache_latents:  # キャッシュしない場合はVAEを使うのでVAEを準備する
            vae.requires_grad_(False)
            vae.eval()
@@ -510,7 +513,9 @@ class TextualInversionTrainer:
        if args.log_tracker_config is not None:
            init_kwargs = toml.load(args.log_tracker_config)
        accelerator.init_trackers(
-            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
+            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name,
+            config=train_util.get_sanitized_config_or_none(args),
+            init_kwargs=init_kwargs,
        )

        # function for saving/removing
@@ -540,8 +545,8 @@ class TextualInversionTrainer:
            global_step,
            accelerator.device,
            vae,
-            tokenizer_or_list,
-            text_encoder_or_list,
+            tokenizers,
+            text_encoders,
            unet,
            prompt_replacement,
        )
@@ -568,7 +573,12 @@ class TextualInversionTrainer:
                    latents = latents * self.vae_scale_factor

                # Get the text embedding for conditioning
-                text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype)
+                input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
+                text_encoder_conds = text_encoding_strategy.encode_tokens(
+                    tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids
+                )
+                if args.full_fp16:
+                    text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]

                # Sample noise, sample a random timestep for each image, and add noise to the latents,
                # with noise offset and/or multires noise if specified
@@ -588,7 +598,9 @@ class TextualInversionTrainer:
                else:
                    target = noise

-                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                loss = train_util.conditional_loss(
+                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                )
                if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                    loss = apply_masked_loss(loss, batch)
                loss = loss.mean([1, 2, 3])
@@ -639,8 +651,8 @@ class TextualInversionTrainer:
                            global_step,
                            accelerator.device,
                            vae,
-                            tokenizer_or_list,
-                            text_encoder_or_list,
+                            tokenizers,
+                            text_encoders,
                            unet,
                            prompt_replacement,
                        )
@@ -722,8 +734,8 @@ class TextualInversionTrainer:
                    global_step,
                    accelerator.device,
                    vae,
-                    tokenizer_or_list,
-                    text_encoder_or_list,
+                    tokenizers,
+                    text_encoders,
                    unet,
                    prompt_replacement,
                )