mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Refactor caching mechanism for latents and text encoder outputs, etc.
This commit is contained in:
21
README.md
21
README.md
@@ -4,9 +4,16 @@ This repository contains training, generation and utility scripts for Stable Dif
|
|||||||
|
|
||||||
SD3 training is done with `sd3_train.py`.
|
SD3 training is done with `sd3_train.py`.
|
||||||
|
|
||||||
__Jul 11, 2024__: Fixed to work t5xxl with `fp16`. If you change the dtype to `fp16` for t5xxl, please remove existing latents cache files (`*_sd3.npz`). The shift in `sd3_minimum_inference.py` is fixed to 3.0. Thanks to araleza!
|
__Jul 27, 2024__:
|
||||||
|
- Latents and text encoder outputs caching mechanism is refactored significantly.
|
||||||
|
- Existing cache files for SD3 need to be recreated. Please delete the previous cache files.
|
||||||
|
- With this change, dataset initialization is significantly faster, especially for large datasets.
|
||||||
|
|
||||||
Jun 29, 2024: Fixed mixed precision training with fp16 is not working. Fixed the model is in bf16 dtype even without `--full_bf16` option (this could worsen the training result).
|
- Architecture-dependent parts are extracted from the dataset (`train_util.py`). This is expected to make it easier to add future architectures.
|
||||||
|
|
||||||
|
- Architecture-dependent parts including the cache mechanism for SD1/2/SDXL are also extracted. The basic operation of SD1/2/SDXL training on the sd3 branch has been confirmed, but there may be bugs. Please use the main or dev branch for SD1/2/SDXL training.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
`fp16` and `bf16` are available for mixed precision training. We are not sure which is better.
|
`fp16` and `bf16` are available for mixed precision training. We are not sure which is better.
|
||||||
|
|
||||||
@@ -14,7 +21,7 @@ Jun 29, 2024: Fixed mixed precision training with fp16 is not working. Fixed the
|
|||||||
|
|
||||||
`clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them.
|
`clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them.
|
||||||
|
|
||||||
~~t5xxl doesn't seem to work with `fp16`, so 1) use`bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`. ~~ t5xxl works with `fp16` now.
|
t5xxl works with `fp16` now.
|
||||||
|
|
||||||
There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype.
|
There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype.
|
||||||
|
|
||||||
@@ -32,6 +39,14 @@ cache_latents = true
|
|||||||
cache_latents_to_disk = true
|
cache_latents_to_disk = true
|
||||||
```
|
```
|
||||||
|
|
||||||
|
__2024/7/27:__
|
||||||
|
|
||||||
|
Latents およびテキストエンコーダ出力のキャッシュの仕組みを大きくリファクタリングしました。SD3 用の既存のキャッシュファイルの再作成が必要になりますが、ご了承ください(以前のキャッシュファイルは削除してください)。これにより、特にデータセットの規模が大きい場合のデータセット初期化が大幅に高速化されます。
|
||||||
|
|
||||||
|
データセット (`train_util.py`) からアーキテクチャ依存の部分を切り出しました。これにより将来的なアーキテクチャ追加が容易になると期待しています。
|
||||||
|
|
||||||
|
SD1/2/SDXL のキャッシュ機構を含むアーキテクチャ依存の部分も切り出しました。sd3 ブランチの SD1/2/SDXL 学習について、基本的な動作は確認していますが、不具合があるかもしれません。SD1/2/SDXL の学習には main または dev ブランチをお使いください。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
[__Change History__](#change-history) is moved to the bottom of the page.
|
[__Change History__](#change-history) is moved to the bottom of the page.
|
||||||
|
|||||||
54
fine_tune.py
54
fine_tune.py
@@ -10,7 +10,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library import deepspeed_utils
|
from library import deepspeed_utils, strategy_base
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
@@ -39,6 +39,7 @@ from library.custom_train_functions import (
|
|||||||
scale_v_prediction_loss_like_noise_prediction,
|
scale_v_prediction_loss_like_noise_prediction,
|
||||||
apply_debiased_estimation,
|
apply_debiased_estimation,
|
||||||
)
|
)
|
||||||
|
import library.strategy_sd as strategy_sd
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
@@ -52,7 +53,15 @@ def train(args):
|
|||||||
if args.seed is not None:
|
if args.seed is not None:
|
||||||
set_seed(args.seed) # 乱数系列を初期化する
|
set_seed(args.seed) # 乱数系列を初期化する
|
||||||
|
|
||||||
tokenizer = train_util.load_tokenizer(args)
|
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
|
||||||
|
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
||||||
|
|
||||||
|
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
||||||
|
if cache_latents:
|
||||||
|
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
||||||
|
False, args.cache_latents_to_disk, args.vae_batch_size, False
|
||||||
|
)
|
||||||
|
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
@@ -81,10 +90,10 @@ def train(args):
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
else:
|
else:
|
||||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
@@ -165,8 +174,9 @@ def train(args):
|
|||||||
vae.to(accelerator.device, dtype=vae_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
|
||||||
|
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
@@ -192,6 +202,9 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
text_encoder.eval()
|
text_encoder.eval()
|
||||||
|
|
||||||
|
text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
|
||||||
|
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
||||||
|
|
||||||
if not cache_latents:
|
if not cache_latents:
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
@@ -214,7 +227,11 @@ def train(args):
|
|||||||
accelerator.print("prepare optimizer, data loader etc.")
|
accelerator.print("prepare optimizer, data loader etc.")
|
||||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
|
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
|
||||||
|
|
||||||
# dataloaderを準備する
|
# prepare dataloader
|
||||||
|
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||||
|
# some strategies can be None
|
||||||
|
train_dataset_group.set_current_strategies()
|
||||||
|
|
||||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
@@ -317,7 +334,9 @@ def train(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# For --sample_at_first
|
# For --sample_at_first
|
||||||
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
train_util.sample_images(
|
||||||
|
accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
|
||||||
|
)
|
||||||
|
|
||||||
loss_recorder = train_util.LossRecorder()
|
loss_recorder = train_util.LossRecorder()
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
@@ -342,8 +361,9 @@ def train(args):
|
|||||||
with torch.set_grad_enabled(args.train_text_encoder):
|
with torch.set_grad_enabled(args.train_text_encoder):
|
||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
if args.weighted_captions:
|
if args.weighted_captions:
|
||||||
|
# TODO move to strategy_sd.py
|
||||||
encoder_hidden_states = get_weighted_text_embeddings(
|
encoder_hidden_states = get_weighted_text_embeddings(
|
||||||
tokenizer,
|
tokenize_strategy.tokenizer,
|
||||||
text_encoder,
|
text_encoder,
|
||||||
batch["captions"],
|
batch["captions"],
|
||||||
accelerator.device,
|
accelerator.device,
|
||||||
@@ -351,10 +371,12 @@ def train(args):
|
|||||||
clip_skip=args.clip_skip,
|
clip_skip=args.clip_skip,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
input_ids = batch["input_ids"].to(accelerator.device)
|
input_ids = batch["input_ids_list"][0].to(accelerator.device)
|
||||||
encoder_hidden_states = train_util.get_hidden_states(
|
encoder_hidden_states = text_encoding_strategy.encode_tokens(
|
||||||
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
|
tokenize_strategy, [text_encoder], [input_ids]
|
||||||
)
|
)[0]
|
||||||
|
if args.full_fp16:
|
||||||
|
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
|
||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
@@ -409,7 +431,7 @@ def train(args):
|
|||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
train_util.sample_images(
|
train_util.sample_images(
|
||||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
|
||||||
)
|
)
|
||||||
|
|
||||||
# 指定ステップごとにモデルを保存
|
# 指定ステップごとにモデルを保存
|
||||||
@@ -472,7 +494,9 @@ def train(args):
|
|||||||
vae,
|
vae,
|
||||||
)
|
)
|
||||||
|
|
||||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
train_util.sample_images(
|
||||||
|
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
|
||||||
|
)
|
||||||
|
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
|
|||||||
@@ -104,8 +104,6 @@ class ControlNetSubsetParams(BaseSubsetParams):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseDatasetParams:
|
class BaseDatasetParams:
|
||||||
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
|
|
||||||
max_token_length: int = None
|
|
||||||
resolution: Optional[Tuple[int, int]] = None
|
resolution: Optional[Tuple[int, int]] = None
|
||||||
network_multiplier: float = 1.0
|
network_multiplier: float = 1.0
|
||||||
debug_dataset: bool = False
|
debug_dataset: bool = False
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ class SDTokenizer:
|
|||||||
サブクラスで各種の設定を行ってる。このクラスはその設定に基づき重み付きのトークン化を行うようだ。
|
サブクラスで各種の設定を行ってる。このクラスはその設定に基づき重み付きのトークン化を行うようだ。
|
||||||
Some settings are done in subclasses. This class seems to perform tokenization with weights based on those settings.
|
Some settings are done in subclasses. This class seems to perform tokenization with weights based on those settings.
|
||||||
"""
|
"""
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer: CLIPTokenizer = tokenizer
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.min_length = min_length
|
self.min_length = min_length
|
||||||
empty = self.tokenizer("")["input_ids"]
|
empty = self.tokenizer("")["input_ids"]
|
||||||
@@ -56,6 +56,19 @@ class SDTokenizer:
|
|||||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||||
self.max_word_length = 8
|
self.max_word_length = 8
|
||||||
|
|
||||||
|
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Tokenize the text without weights.
|
||||||
|
"""
|
||||||
|
if type(text) == str:
|
||||||
|
text = [text]
|
||||||
|
batch_tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt")
|
||||||
|
# return tokens["input_ids"]
|
||||||
|
|
||||||
|
pad_token = self.end_token if self.pad_with_end else 0
|
||||||
|
for tokens in batch_tokens["input_ids"]:
|
||||||
|
assert tokens[0] == self.start_token, f"tokens[0]: {tokens[0]}, start_token: {self.start_token}"
|
||||||
|
|
||||||
def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None):
|
def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None):
|
||||||
"""Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.
|
"""Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.
|
||||||
The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3."""
|
The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3."""
|
||||||
@@ -75,13 +88,14 @@ class SDTokenizer:
|
|||||||
for word in to_tokenize:
|
for word in to_tokenize:
|
||||||
batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]])
|
batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]])
|
||||||
batch.append((self.end_token, 1.0))
|
batch.append((self.end_token, 1.0))
|
||||||
|
print(len(batch), self.max_length, self.min_length)
|
||||||
if self.pad_to_max_length:
|
if self.pad_to_max_length:
|
||||||
batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
|
batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
|
||||||
if self.min_length is not None and len(batch) < self.min_length:
|
if self.min_length is not None and len(batch) < self.min_length:
|
||||||
batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
|
batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
|
||||||
|
|
||||||
# truncate to max_length
|
# truncate to max_length
|
||||||
# print(f"batch: {batch}, truncate: {truncate}, len(batch): {len(batch)}, max_length: {self.max_length}")
|
print(f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}")
|
||||||
if truncate_to_max_length and len(batch) > self.max_length:
|
if truncate_to_max_length and len(batch) > self.max_length:
|
||||||
batch = batch[: self.max_length]
|
batch = batch[: self.max_length]
|
||||||
if truncate_length is not None and len(batch) > truncate_length:
|
if truncate_length is not None and len(batch) > truncate_length:
|
||||||
@@ -110,27 +124,38 @@ class SDXLClipGTokenizer(SDTokenizer):
|
|||||||
|
|
||||||
|
|
||||||
class SD3Tokenizer:
|
class SD3Tokenizer:
|
||||||
def __init__(self, t5xxl=True):
|
def __init__(self, t5xxl=True, t5xxl_max_length: Optional[int] = 256):
|
||||||
|
if t5xxl_max_length is None:
|
||||||
|
t5xxl_max_length = 256
|
||||||
|
|
||||||
# TODO cache tokenizer settings locally or hold them in the repo like ComfyUI
|
# TODO cache tokenizer settings locally or hold them in the repo like ComfyUI
|
||||||
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||||
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
|
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
|
||||||
self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
|
self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
|
||||||
|
# self.clip_l = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||||
|
# self.clip_g = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
|
||||||
self.t5xxl = T5XXLTokenizer() if t5xxl else None
|
self.t5xxl = T5XXLTokenizer() if t5xxl else None
|
||||||
# t5xxl has 99999999 max length, clip has 77
|
# t5xxl has 99999999 max length, clip has 77
|
||||||
self.model_max_length = self.clip_l.max_length # 77
|
self.t5xxl_max_length = t5xxl_max_length
|
||||||
|
|
||||||
def tokenize_with_weights(self, text: str):
|
def tokenize_with_weights(self, text: str):
|
||||||
# temporary truncate to max_length even for t5xxl
|
|
||||||
return (
|
return (
|
||||||
self.clip_l.tokenize_with_weights(text),
|
self.clip_l.tokenize_with_weights(text),
|
||||||
self.clip_g.tokenize_with_weights(text),
|
self.clip_g.tokenize_with_weights(text),
|
||||||
(
|
(
|
||||||
self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.model_max_length)
|
self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.t5xxl_max_length)
|
||||||
if self.t5xxl is not None
|
if self.t5xxl is not None
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def tokenize(self, text: str):
|
||||||
|
return (
|
||||||
|
self.clip_l.tokenize(text),
|
||||||
|
self.clip_g.tokenize(text),
|
||||||
|
(self.t5xxl.tokenize(text) if self.t5xxl is not None else None),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
@@ -1474,7 +1499,10 @@ class ClipTokenWeightEncoder:
|
|||||||
tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0]
|
tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0]
|
||||||
list_of_tokens.append(tokens)
|
list_of_tokens.append(tokens)
|
||||||
else:
|
else:
|
||||||
list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]]
|
if isinstance(list_of_token_weight_pairs[0], torch.Tensor):
|
||||||
|
list_of_tokens = [list(list_of_token_weight_pairs[0])]
|
||||||
|
else:
|
||||||
|
list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]]
|
||||||
|
|
||||||
out, pooled = self(list_of_tokens)
|
out, pooled = self(list_of_tokens)
|
||||||
if has_batch:
|
if has_batch:
|
||||||
@@ -1614,9 +1642,9 @@ class T5XXLModel(SDClipModel):
|
|||||||
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
|
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
|
||||||
#################################################################################################
|
#################################################################################################
|
||||||
|
|
||||||
|
"""
|
||||||
class T5XXLTokenizer(SDTokenizer):
|
class T5XXLTokenizer(SDTokenizer):
|
||||||
"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
|
""Wraps the T5 Tokenizer from HF into the SDTokenizer interface""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -1627,6 +1655,7 @@ class T5XXLTokenizer(SDTokenizer):
|
|||||||
max_length=99999999,
|
max_length=99999999,
|
||||||
min_length=77,
|
min_length=77,
|
||||||
)
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class T5LayerNorm(torch.nn.Module):
|
class T5LayerNorm(torch.nn.Module):
|
||||||
|
|||||||
@@ -280,111 +280,6 @@ def sample_images(*args, **kwargs):
|
|||||||
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy):
|
|
||||||
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
|
|
||||||
|
|
||||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
|
||||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
|
||||||
self.vae = None
|
|
||||||
|
|
||||||
def set_vae(self, vae: sd3_models.SDVAE):
|
|
||||||
self.vae = vae
|
|
||||||
|
|
||||||
def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
|
|
||||||
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX)
|
|
||||||
if len(npz_file) == 0:
|
|
||||||
return None, None
|
|
||||||
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
|
|
||||||
return int(w), int(h)
|
|
||||||
|
|
||||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
|
||||||
return (
|
|
||||||
os.path.splitext(absolute_path)[0]
|
|
||||||
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
|
|
||||||
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
|
|
||||||
)
|
|
||||||
|
|
||||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
|
||||||
if not self.cache_to_disk:
|
|
||||||
return False
|
|
||||||
if not os.path.exists(npz_path):
|
|
||||||
return False
|
|
||||||
if self.skip_disk_cache_validity_check:
|
|
||||||
return True
|
|
||||||
|
|
||||||
expected_latents_size = (bucket_reso[1] // 8, bucket_reso[0] // 8) # bucket_reso is (W, H)
|
|
||||||
|
|
||||||
try:
|
|
||||||
npz = np.load(npz_path)
|
|
||||||
if npz["latents"].shape[1:3] != expected_latents_size:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if flip_aug:
|
|
||||||
if "latents_flipped" not in npz:
|
|
||||||
return False
|
|
||||||
if npz["latents_flipped"].shape[1:3] != expected_latents_size:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if alpha_mask:
|
|
||||||
if "alpha_mask" not in npz:
|
|
||||||
return False
|
|
||||||
if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
if "alpha_mask" in npz:
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error loading file: {npz_path}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
|
||||||
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
|
|
||||||
image_infos, alpha_mask, random_crop
|
|
||||||
)
|
|
||||||
img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
latents_tensors = self.vae.encode(img_tensor).to("cpu")
|
|
||||||
if flip_aug:
|
|
||||||
img_tensor = torch.flip(img_tensor, dims=[3])
|
|
||||||
with torch.no_grad():
|
|
||||||
flipped_latents = self.vae.encode(img_tensor).to("cpu")
|
|
||||||
else:
|
|
||||||
flipped_latents = [None] * len(latents_tensors)
|
|
||||||
|
|
||||||
# for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks):
|
|
||||||
for i in range(len(image_infos)):
|
|
||||||
info = image_infos[i]
|
|
||||||
latents = latents_tensors[i]
|
|
||||||
flipped_latent = flipped_latents[i]
|
|
||||||
alpha_mask = alpha_masks[i]
|
|
||||||
original_size = original_sizes[i]
|
|
||||||
crop_ltrb = crop_ltrbs[i]
|
|
||||||
|
|
||||||
if self.cache_to_disk:
|
|
||||||
kwargs = {}
|
|
||||||
if flipped_latent is not None:
|
|
||||||
kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy()
|
|
||||||
if alpha_mask is not None:
|
|
||||||
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
|
|
||||||
np.savez(
|
|
||||||
info.latents_npz,
|
|
||||||
latents=latents.float().cpu().numpy(),
|
|
||||||
original_size=np.array(original_size),
|
|
||||||
crop_ltrb=np.array(crop_ltrb),
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
info.latents = latents
|
|
||||||
if flip_aug:
|
|
||||||
info.latents_flipped = flipped_latent
|
|
||||||
info.alpha_mask = alpha_mask
|
|
||||||
|
|
||||||
if not train_util.HIGH_VRAM:
|
|
||||||
clean_memory_on_device(self.vae.device)
|
|
||||||
|
|
||||||
|
|
||||||
# region Diffusers
|
# region Diffusers
|
||||||
|
|
||||||
|
|||||||
@@ -384,6 +384,7 @@ def get_cond(
|
|||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
):
|
):
|
||||||
l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt)
|
l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt)
|
||||||
|
print(t5_tokens)
|
||||||
return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype)
|
return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -327,7 +327,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
||||||
)
|
)
|
||||||
|
|||||||
328
library/strategy_base.py
Normal file
328
library/strategy_base.py
Normal file
@@ -0,0 +1,328 @@
|
|||||||
|
# base class for platform strategies. this file defines the interface for strategies
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
||||||
|
|
||||||
|
|
||||||
|
# TODO remove circular import by moving ImageInfo to a separate file
|
||||||
|
# from library.train_util import ImageInfo
|
||||||
|
|
||||||
|
from library.utils import setup_logging
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizeStrategy:
|
||||||
|
_strategy = None # strategy instance: actual strategy class
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_strategy(cls, strategy):
|
||||||
|
if cls._strategy is not None:
|
||||||
|
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
|
||||||
|
cls._strategy = strategy
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_strategy(cls) -> Optional["TokenizeStrategy"]:
|
||||||
|
return cls._strategy
|
||||||
|
|
||||||
|
def _load_tokenizer(
|
||||||
|
self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None
|
||||||
|
) -> Any:
|
||||||
|
tokenizer = None
|
||||||
|
if tokenizer_cache_dir:
|
||||||
|
local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_"))
|
||||||
|
if os.path.exists(local_tokenizer_path):
|
||||||
|
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
|
||||||
|
tokenizer = model_class.from_pretrained(local_tokenizer_path) # same for v1 and v2
|
||||||
|
|
||||||
|
if tokenizer is None:
|
||||||
|
tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder)
|
||||||
|
|
||||||
|
if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
||||||
|
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
|
||||||
|
tokenizer.save_pretrained(local_tokenizer_path)
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
for SD1.5/2.0/SDXL
|
||||||
|
TODO support batch input
|
||||||
|
"""
|
||||||
|
if max_length is None:
|
||||||
|
max_length = tokenizer.model_max_length - 2
|
||||||
|
|
||||||
|
input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids
|
||||||
|
|
||||||
|
if max_length > tokenizer.model_max_length:
|
||||||
|
input_ids = input_ids.squeeze(0)
|
||||||
|
iids_list = []
|
||||||
|
if tokenizer.pad_token_id == tokenizer.eos_token_id:
|
||||||
|
# v1
|
||||||
|
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
|
||||||
|
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
|
||||||
|
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): # (1, 152, 75)
|
||||||
|
ids_chunk = (
|
||||||
|
input_ids[0].unsqueeze(0),
|
||||||
|
input_ids[i : i + tokenizer.model_max_length - 2],
|
||||||
|
input_ids[-1].unsqueeze(0),
|
||||||
|
)
|
||||||
|
ids_chunk = torch.cat(ids_chunk)
|
||||||
|
iids_list.append(ids_chunk)
|
||||||
|
else:
|
||||||
|
# v2 or SDXL
|
||||||
|
# 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
|
||||||
|
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
|
||||||
|
ids_chunk = (
|
||||||
|
input_ids[0].unsqueeze(0), # BOS
|
||||||
|
input_ids[i : i + tokenizer.model_max_length - 2],
|
||||||
|
input_ids[-1].unsqueeze(0),
|
||||||
|
) # PAD or EOS
|
||||||
|
ids_chunk = torch.cat(ids_chunk)
|
||||||
|
|
||||||
|
# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
|
||||||
|
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変える(x <EOS> なら結果的に変化なし)
|
||||||
|
if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id:
|
||||||
|
ids_chunk[-1] = tokenizer.eos_token_id
|
||||||
|
# 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
|
||||||
|
if ids_chunk[1] == tokenizer.pad_token_id:
|
||||||
|
ids_chunk[1] = tokenizer.eos_token_id
|
||||||
|
|
||||||
|
iids_list.append(ids_chunk)
|
||||||
|
|
||||||
|
input_ids = torch.stack(iids_list) # 3,77
|
||||||
|
return input_ids
|
||||||
|
|
||||||
|
|
||||||
|
class TextEncodingStrategy:
|
||||||
|
_strategy = None # strategy instance: actual strategy class
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_strategy(cls, strategy):
|
||||||
|
if cls._strategy is not None:
|
||||||
|
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
|
||||||
|
cls._strategy = strategy
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_strategy(cls) -> Optional["TextEncodingStrategy"]:
|
||||||
|
return cls._strategy
|
||||||
|
|
||||||
|
def encode_tokens(
|
||||||
|
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Encode tokens into embeddings and outputs.
|
||||||
|
:param tokens: list of token tensors for each TextModel
|
||||||
|
:return: list of output embeddings for each architecture
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class TextEncoderOutputsCachingStrategy:
    """Abstract base class for caching text encoder outputs, on disk or in memory.

    A single process-wide instance is registered with :meth:`set_strategy` and
    looked up with :meth:`get_strategy`. Subclasses implement the npz naming,
    loading, validation and batch encoding.
    """

    # the one registered strategy instance (a concrete subclass)
    _strategy = None

    def __init__(
        self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
    ) -> None:
        # is_partial: presumably set when only part of the dataset/encoders is cached
        # — confirm against concrete subclasses; stored as given by the caller
        self._cache_to_disk = cache_to_disk
        self._batch_size = batch_size
        self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
        self._is_partial = is_partial

    @classmethod
    def set_strategy(cls, strategy):
        """Register the concrete strategy. Raises if one is already registered."""
        if cls._strategy is None:
            cls._strategy = strategy
            return
        raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")

    @classmethod
    def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
        """Return the registered strategy, or ``None`` when none has been set."""
        return cls._strategy

    @property
    def cache_to_disk(self):
        # True: outputs go to npz files; False: outputs are kept in memory
        return self._cache_to_disk

    @property
    def batch_size(self):
        # number of captions encoded per text encoder call
        return self._batch_size

    @property
    def is_partial(self):
        return self._is_partial

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        """Return the npz cache path derived from the image path."""
        raise NotImplementedError

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        """Load the cached outputs from an npz file."""
        raise NotImplementedError

    def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
        """Return True if a valid disk cache already exists."""
        raise NotImplementedError

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List
    ):
        """Encode a batch and store the outputs (to disk or in memory)."""
        raise NotImplementedError
||||||
|
class LatentsCachingStrategy:
    """Abstract base class for caching VAE latents, either in memory (on ImageInfo)
    or on disk as .npz files. Concrete subclasses supply the architecture-specific
    npz naming and the VAE encoding callable."""

    # TODO commonize utillity functions to this class, such as npz handling etc.

    _strategy = None  # strategy instance: actual strategy class

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        # cache_to_disk: save latents to .npz files instead of keeping them on ImageInfo
        # batch_size: number of images encoded per VAE call
        # skip_disk_cache_validity_check: trust existing npz files without inspecting contents
        self._cache_to_disk = cache_to_disk
        self._batch_size = batch_size
        self.skip_disk_cache_validity_check = skip_disk_cache_validity_check

    @classmethod
    def set_strategy(cls, strategy):
        """Register the process-wide strategy instance. May only be called once."""
        if cls._strategy is not None:
            raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
        """Return the registered strategy, or None if none has been set."""
        return cls._strategy

    @property
    def cache_to_disk(self):
        return self._cache_to_disk

    @property
    def batch_size(self):
        return self._batch_size

    def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
        """Recover (width, height) from an existing cache file name, or (None, None)."""
        raise NotImplementedError

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        """Build the npz cache path for an image of the given size."""
        raise NotImplementedError

    def is_disk_cached_latents_expected(
        self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
    ) -> bool:
        """Return True if a valid disk cache already exists for these settings."""
        raise NotImplementedError

    def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
        """Encode a batch of images with the VAE and store the latents."""
        raise NotImplementedError

    # NOTE: "defualt" is a typo, kept because existing subclasses call this name
    def _defualt_is_disk_cached_latents_expected(
        self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
    ):
        """Shared validity check: True when npz_path exists and its contents match
        the expected latent shape and the flip_aug / alpha_mask settings.
        Raises (after logging) if the file cannot be read."""
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)  # bucket_reso is (W, H)

        try:
            # FIX: use a context manager so the npz archive handle is closed promptly;
            # np.load keeps the file open for as long as the NpzFile object lives
            with np.load(npz_path) as npz:
                if npz["latents"].shape[1:3] != expected_latents_size:
                    return False

                if flip_aug:
                    if "latents_flipped" not in npz:
                        return False
                    if npz["latents_flipped"].shape[1:3] != expected_latents_size:
                        return False

                if alpha_mask:
                    if "alpha_mask" not in npz:
                        return False
                    # alpha mask is stored at pixel resolution, not latent resolution
                    if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
                        return False
                else:
                    # a stale mask means the cache was built with different settings
                    if "alpha_mask" in npz:
                        return False
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            raise e

        return True

    # TODO remove circular dependency for ImageInfo
    def _default_cache_batch_latents(
        self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool
    ):
        """
        Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
        """
        from library import train_util  # import here to avoid circular import

        img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
            image_infos, alpha_mask, random_crop
        )
        img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)

        with torch.no_grad():
            latents_tensors = encode_by_vae(img_tensor).to("cpu")
        if flip_aug:
            # encode a horizontally flipped copy as well (flip along the width axis)
            img_tensor = torch.flip(img_tensor, dims=[3])
            with torch.no_grad():
                flipped_latents = encode_by_vae(img_tensor).to("cpu")
        else:
            flipped_latents = [None] * len(latents_tensors)

        for i in range(len(image_infos)):
            info = image_infos[i]
            latents = latents_tensors[i]
            flipped_latent = flipped_latents[i]
            alpha_mask = alpha_masks[i]
            original_size = original_sizes[i]
            crop_ltrb = crop_ltrbs[i]

            if self.cache_to_disk:
                self.save_latents_to_disk(info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask)
            else:
                info.latents_original_size = original_size
                info.latents_crop_ltrb = crop_ltrb
                info.latents = latents
                if flip_aug:
                    info.latents_flipped = flipped_latent
                info.alpha_mask = alpha_mask

    def load_latents_from_disk(
        self, npz_path: str
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        """Load (latents, original_size, crop_ltrb, flipped_latents, alpha_mask) from npz_path.

        Raises ValueError for old-format files that lack the "latents" key."""
        # FIX: read all members inside the context manager — NpzFile loads arrays
        # lazily and should be closed before returning
        with np.load(npz_path) as npz:
            if "latents" not in npz:
                raise ValueError(f"error: npz is old format. please re-generate {npz_path}")

            latents = npz["latents"]
            original_size = npz["original_size"].tolist()
            crop_ltrb = npz["crop_ltrb"].tolist()
            flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
            alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
        return latents, original_size, crop_ltrb, flipped_latents, alpha_mask

    def save_latents_to_disk(
        self, npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None
    ):
        """Save latents (and optional flipped latents / alpha mask) to npz_path.

        Tensors are converted to float32 numpy arrays on CPU before saving."""
        kwargs = {}
        if flipped_latents_tensor is not None:
            kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
        if alpha_mask is not None:
            kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
        np.savez(
            npz_path,
            latents=latents_tensor.float().cpu().numpy(),
            original_size=np.array(original_size),
            crop_ltrb=np.array(crop_ltrb),
            **kwargs,
        )
|
||||||
139
library/strategy_sd.py
Normal file
139
library/strategy_sd.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
import glob
|
||||||
|
import os
|
||||||
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import CLIPTokenizer
|
||||||
|
from library import train_util
|
||||||
|
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
|
||||||
|
from library.utils import setup_logging
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
TOKENIZER_ID = "openai/clip-vit-large-patch14"
|
||||||
|
V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
|
||||||
|
|
||||||
|
|
||||||
|
class SdTokenizeStrategy(TokenizeStrategy):
    """Tokenization for SD1.x / SD2.x using a single CLIP tokenizer."""

    def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
        """
        max_length does not include <BOS> and <EOS> (None, 75, 150, 225)
        """
        logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer")
        if not v2:
            self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
        else:
            self.tokenizer = self._load_tokenizer(
                CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir
            )

        # +2 accounts for the <BOS>/<EOS> tokens not counted in max_length
        self.max_length = self.tokenizer.model_max_length if max_length is None else max_length + 2

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        """Tokenize one prompt or a list of prompts; returns a single-element list."""
        texts = [text] if isinstance(text, str) else text
        ids_per_text = [self._get_input_ids(self.tokenizer, t, self.max_length) for t in texts]
        return [torch.stack(ids_per_text, dim=0)]
||||||
|
class SdTextEncodingStrategy(TextEncodingStrategy):
    """Text encoding for SD1.x / SD2.x: a single CLIP text encoder, with optional
    clip-skip and long-prompt (multi-chunk, 150/225 token) re-assembly."""

    def __init__(self, clip_skip: Optional[int] = None) -> None:
        # clip_skip: take the hidden states N layers before the final layer
        # (followed by the final layer norm), as in the original SD training scripts
        self.clip_skip = clip_skip

    def encode_tokens(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """
        Encode tokenized prompts with the CLIP text encoder.

        :param models: [text_encoder]
        :param tokens: [tensor of shape (b, n, 77)] where n is the number of 75-token chunks
        :return: [encoder hidden states]; for long prompts the per-chunk encodings are
            re-joined into a single sequence
        """
        text_encoder = models[0]
        tokens = tokens[0]
        sd_tokenize_strategy = tokenize_strategy  # type: SdTokenizeStrategy

        # tokens: b,n,77
        b_size = tokens.size()[0]
        max_token_length = tokens.size()[1] * tokens.size()[2]
        model_max_length = sd_tokenize_strategy.tokenizer.model_max_length
        tokens = tokens.reshape((-1, model_max_length))  # batch_size*3, 77

        if self.clip_skip is None:
            encoder_hidden_states = text_encoder(tokens)[0]
        else:
            enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
            encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip]
            encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)

        # bs*3, 77, 768 or 1024
        encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))

        if max_token_length != model_max_length:
            # in v1 the pad token equals the EOS token; v2 has a distinct pad token
            v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id
            if not v1:
                # v2: rejoin the "<BOS>...<EOS> <PAD> ..." triplets back into one sequence
                # (original author note: unsure whether this implementation is ideal)
                states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]  # <BOS>
                for i in range(1, max_token_length, model_max_length):
                    chunk = encoder_hidden_states[:, i : i + model_max_length - 2]  # after <BOS> up to before the last token
                    # (the previous "if i > 0" guard was removed: the range starts at 1, so it was always true)
                    for j in range(len(chunk)):
                        # FIX: compare against eos_token_id (an int); the previous comparison
                        # against eos_token (a string) could never match, so empty chunks
                        # were never patched as intended
                        if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token_id:
                            # empty chunk, i.e. the "<BOS> <EOS> <PAD> ..." pattern
                            chunk[j, 0] = chunk[j, 1]  # copy the value of the following <PAD>
                    states_list.append(chunk)  # after <BOS> up to before <EOS>
                states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))  # either <EOS> or <PAD>
                encoder_hidden_states = torch.cat(states_list, dim=1)
            else:
                # v1: rejoin the "<BOS>...<EOS>" triplets back into one sequence
                states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]  # <BOS>
                for i in range(1, max_token_length, model_max_length):
                    states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2])  # after <BOS>, before <EOS>
                states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))  # <EOS>
                encoder_hidden_states = torch.cat(states_list, dim=1)

        return [encoder_hidden_states]
|
class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
    # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
    # and we keep the old npz for the backward compatibility.

    SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
    SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
    SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"

    def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
        self.sd = sd
        # pick the architecture-specific suffix once, up front
        if sd:
            self.suffix = SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX
        else:
            self.suffix = SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX

    def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
        # does not include old npz
        candidates = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + self.suffix)
        if not candidates:
            return None, None
        # cache names end with "_{W:04d}x{H:04d}{suffix}"
        size_token = os.path.splitext(candidates[0])[0].split("_")[-2]
        w, h = size_token.split("x")
        return int(w), int(h)

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        # support old .npz
        old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
        if os.path.exists(old_npz_file):
            return old_npz_file
        return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        # SD/SDXL VAEs downsample by a factor of 8
        return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)

    # TODO remove circular dependency for ImageInfo
    def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
        def encode_by_vae(img_tensor):
            return vae.encode(img_tensor).latent_dist.sample()

        self._default_cache_batch_latents(
            encode_by_vae, vae.device, vae.dtype, image_infos, flip_aug, alpha_mask, random_crop
        )

        if not train_util.HIGH_VRAM:
            train_util.clean_memory_on_device(vae.device)
|
||||||
229
library/strategy_sd3.py
Normal file
229
library/strategy_sd3.py
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
import os
|
||||||
|
import glob
|
||||||
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||||
|
|
||||||
|
from library import sd3_utils, train_util
|
||||||
|
from library import sd3_models
|
||||||
|
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
||||||
|
|
||||||
|
from library.utils import setup_logging
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
|
||||||
|
CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||||
|
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
|
||||||
|
|
||||||
|
|
||||||
|
class Sd3TokenizeStrategy(TokenizeStrategy):
    """Tokenizes a prompt with the three SD3 tokenizers: CLIP-L, CLIP-G and T5-XXL."""

    def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None:
        self.t5xxl_max_length = t5xxl_max_length
        self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
        self.clip_g = self._load_tokenizer(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
        self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
        self.clip_g.pad_token_id = 0  # use 0 as pad token for clip_g

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        """Return [clip_l_ids, clip_g_ids, t5xxl_ids] for one prompt or a list of prompts."""
        texts = [text] if isinstance(text, str) else text

        encodings = [
            self.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt"),
            self.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt"),
            self.t5xxl(texts, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt"),
        ]
        return [enc["input_ids"] for enc in encodings]
|
|
||||||
|
|
||||||
|
class Sd3TextEncodingStrategy(TextEncodingStrategy):
    """Encodes SD3 tokens with CLIP-L, CLIP-G and (optionally) T5-XXL."""

    def __init__(self) -> None:
        pass

    def encode_tokens(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """Return [lg_out, t5_out, lg_pooled]; entries are None when the
        corresponding tokens/model are not provided."""
        clip_l, clip_g, t5xxl = models
        l_tokens, g_tokens, t5_tokens = tokens

        lg_pooled = None
        if l_tokens is None:
            assert g_tokens is None, "g_tokens must be None if l_tokens is None"
            lg_out = None
        else:
            assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
            l_out, l_pooled = clip_l(l_tokens)
            g_out, g_pooled = clip_g(g_tokens)
            # concatenate the two CLIP encoders' outputs along the feature axis
            lg_out = torch.cat([l_out, g_out], dim=-1)
            lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1)

        t5_out = None
        if t5xxl is not None and t5_tokens is not None:
            t5_out, _ = t5xxl(t5_tokens)  # t5_out is [1, max length, 4096]

        return [lg_out, t5_out, lg_pooled]

    def concat_encodings(
        self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pad the CLIP features to T5 width (4096) and append the T5 sequence;
        a zero tensor stands in for T5 when it is absent."""
        padded_lg = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
        if t5_out is None:
            t5_out = torch.zeros((padded_lg.shape[0], 77, 4096), device=padded_lg.device, dtype=padded_lg.dtype)
        return torch.cat([padded_lg, t5_out], dim=-2), lg_pooled
|
||||||
|
|
||||||
|
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    """Caches SD3 text encoder outputs: the concatenated CLIP-L/G outputs
    ("lg_out"), the pooled CLIP outputs ("lg_pooled") and, optionally, the
    T5-XXL outputs ("t5_out")."""

    SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"

    def __init__(
        self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
    ) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        """Return the npz cache path derived from the image path."""
        return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX

    def is_disk_cached_outputs_expected(self, abs_path: str):
        """Return True if a valid cache npz already exists for abs_path."""
        if not self.cache_to_disk:
            return False
        npz_path = self.get_outputs_npz_path(abs_path)
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        try:
            with np.load(npz_path) as npz:
                # FIX: validate the keys this strategy actually writes (see
                # cache_batch_outputs / load_outputs_npz). The previous check looked
                # for "clip_l"/"clip_g"/"clip_l_pool"/"clip_g_pool", which are never
                # saved, so every existing cache was treated as invalid and re-encoded.
                if "lg_out" not in npz or "lg_pooled" not in npz:
                    return False
                # t5xxl is optional
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            raise e

        return True

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        """Load [lg_out, t5_out (or None), lg_pooled] from an npz cache file."""
        data = np.load(npz_path)
        lg_out = data["lg_out"]
        lg_pooled = data["lg_pooled"]
        t5_out = data["t5_out"] if "t5_out" in data else None
        return [lg_out, t5_out, lg_pooled]

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
    ):
        """Encode the captions of a batch of ImageInfo and store the outputs,
        either to npz files (cache_to_disk) or on the infos themselves."""
        captions = [info.caption for info in infos]

        clip_l_tokens, clip_g_tokens, t5xxl_tokens = tokenize_strategy.tokenize(captions)
        with torch.no_grad():
            lg_out, t5_out, lg_pooled = text_encoding_strategy.encode_tokens(
                tokenize_strategy, models, [clip_l_tokens, clip_g_tokens, t5xxl_tokens]
            )

        # numpy cannot represent bfloat16; upcast to float32 before .numpy()
        if lg_out.dtype == torch.bfloat16:
            lg_out = lg_out.float()
        if lg_pooled.dtype == torch.bfloat16:
            lg_pooled = lg_pooled.float()
        if t5_out is not None and t5_out.dtype == torch.bfloat16:
            t5_out = t5_out.float()

        lg_out = lg_out.cpu().numpy()
        lg_pooled = lg_pooled.cpu().numpy()
        if t5_out is not None:
            t5_out = t5_out.cpu().numpy()

        for i, info in enumerate(infos):
            lg_out_i = lg_out[i]
            t5_out_i = t5_out[i] if t5_out is not None else None
            lg_pooled_i = lg_pooled[i]

            if self.cache_to_disk:
                kwargs = {}
                if t5_out is not None:
                    kwargs["t5_out"] = t5_out_i
                np.savez(info.text_encoder_outputs_npz, lg_out=lg_out_i, lg_pooled=lg_pooled_i, **kwargs)
            else:
                info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i)
|
||||||
|
|
||||||
|
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
    """Latents caching for SD3; cache files are suffixed with "_sd3.npz"."""

    SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)

    def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
        """Recover (width, height) from an existing cache name ("..._{W:04d}x{H:04d}_sd3.npz")."""
        npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX)
        if len(npz_file) == 0:
            return None, None
        w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
        return int(w), int(h)

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        """Build the cache path for an image of the given (width, height)."""
        return (
            os.path.splitext(absolute_path)[0]
            + f"_{image_size[0]:04d}x{image_size[1]:04d}"
            + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
        )

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        # SD3 VAE downsamples by a factor of 8, same stride as SD/SDXL
        return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)

    # TODO remove circular dependency for ImageInfo
    def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
        """Encode a batch of images with the SD3 VAE and cache the latents."""
        # FIX: removed the redundant ".to('cpu')" from the closure —
        # _default_cache_batch_latents already moves the encoded latents to CPU,
        # and SdSdxlLatentsCachingStrategy's closure does not transfer either
        encode_by_vae = lambda img_tensor: vae.encode(img_tensor)
        vae_device = vae.device
        vae_dtype = vae.dtype

        self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)

        if not train_util.HIGH_VRAM:
            train_util.clean_memory_on_device(vae.device)
|
|
||||||
|
|
||||||
|
# Manual smoke test for Sd3TokenizeStrategy: run this module directly.
# NOTE(review): this downloads the three tokenizers from the Hugging Face hub,
# so it needs network access — it is a print-and-eyeball check, not a unit test.
if __name__ == "__main__":
    # test code for Sd3TokenizeStrategy
    # tokenizer = sd3_models.SD3Tokenizer()
    strategy = Sd3TokenizeStrategy(256)
    text = "hello world"

    # single prompt -> one input-id tensor per tokenizer (clip_l, clip_g, t5xxl)
    l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
    # print(l_tokens.shape)
    print(l_tokens)
    print(g_tokens)
    print(t5_tokens)

    # tokenize a batch via the underlying tokenizers directly, for comparison
    texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
    l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
    g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
    t5_tokens_2 = strategy.t5xxl(
        texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt"
    )
    print(l_tokens_2)
    print(g_tokens_2)
    print(t5_tokens_2)

    # compare
    # the first batch entry should match the single-prompt result above (all True expected)
    print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0]))
    print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0]))
    print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0]))

    # long prompt: tokenize() passes truncation=True, so the output should be
    # capped at each tokenizer's max length (77 / 77 / t5xxl_max_length)
    text = ",".join(["hello world! this is long text"] * 50)
    l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
    print(l_tokens)
    print(g_tokens)
    print(t5_tokens)

    print(f"model max length l: {strategy.clip_l.model_max_length}")
    print(f"model max length g: {strategy.clip_g.model_max_length}")
    print(f"model max length t5: {strategy.t5xxl.model_max_length}")
||||||
247
library/strategy_sdxl.py
Normal file
247
library/strategy_sdxl.py
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
||||||
|
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
|
||||||
|
|
||||||
|
|
||||||
|
from library.utils import setup_logging
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
|
||||||
|
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||||
|
|
||||||
|
|
||||||
|
class SdxlTokenizeStrategy(TokenizeStrategy):
    """Tokenization for SDXL with its two CLIP tokenizers (ViT-L and ViT-bigG)."""

    def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
        """
        max_length does not include <BOS> and <EOS> (None, 75, 150, 225)
        """
        self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
        self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
        self.tokenizer2.pad_token_id = 0  # use 0 as pad token for tokenizer2

        if max_length is None:
            self.max_length = self.tokenizer1.model_max_length
        else:
            # +2 accounts for the <BOS>/<EOS> tokens not counted in max_length
            self.max_length = max_length + 2

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        """Tokenize with both tokenizers and return [ids_for_tokenizer1, ids_for_tokenizer2]."""
        text = [text] if isinstance(text, str) else text
        # FIX: return a list to match the declared List[torch.Tensor] return type
        # (previously a tuple; unpacking/indexing by callers is unaffected)
        return [
            torch.stack([self._get_input_ids(self.tokenizer1, t, self.max_length) for t in text], dim=0),
            torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0),
        ]
|
|
||||||
|
|
||||||
|
class SdxlTextEncodingStrategy(TextEncodingStrategy):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _pool_workaround(
|
||||||
|
self, text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output
|
||||||
|
instead of the hidden states for the EOS token
|
||||||
|
If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output
|
||||||
|
|
||||||
|
Original code from CLIP's pooling function:
|
||||||
|
|
||||||
|
\# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
||||||
|
\# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||||
|
\# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||||
|
pooled_output = last_hidden_state[
|
||||||
|
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
||||||
|
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
|
||||||
|
# input_ids: b*n,77
|
||||||
|
# find index for EOS token
|
||||||
|
|
||||||
|
# Following code is not working if one of the input_ids has multiple EOS tokens (very odd case)
|
||||||
|
# eos_token_index = torch.where(input_ids == eos_token_id)[1]
|
||||||
|
# eos_token_index = eos_token_index.to(device=last_hidden_state.device)
|
||||||
|
|
||||||
|
# Create a mask where the EOS tokens are
|
||||||
|
eos_token_mask = (input_ids == eos_token_id).int()
|
||||||
|
|
||||||
|
# Use argmax to find the last index of the EOS token for each element in the batch
|
||||||
|
eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine
|
||||||
|
eos_token_index = eos_token_index.to(device=last_hidden_state.device)
|
||||||
|
|
||||||
|
# get hidden states for EOS token
|
||||||
|
pooled_output = last_hidden_state[
|
||||||
|
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index
|
||||||
|
]
|
||||||
|
|
||||||
|
# apply projection: projection may be of different dtype than last_hidden_state
|
||||||
|
pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype))
|
||||||
|
pooled_output = pooled_output.to(last_hidden_state.dtype)
|
||||||
|
|
||||||
|
return pooled_output
|
||||||
|
|
||||||
|
def _get_hidden_states_sdxl(
|
||||||
|
self,
|
||||||
|
input_ids1: torch.Tensor,
|
||||||
|
input_ids2: torch.Tensor,
|
||||||
|
tokenizer1: CLIPTokenizer,
|
||||||
|
tokenizer2: CLIPTokenizer,
|
||||||
|
text_encoder1: Union[CLIPTextModel, torch.nn.Module],
|
||||||
|
text_encoder2: Union[CLIPTextModelWithProjection, torch.nn.Module],
|
||||||
|
unwrapped_text_encoder2: Optional[CLIPTextModelWithProjection] = None,
|
||||||
|
):
|
||||||
|
# input_ids: b,n,77 -> b*n, 77
|
||||||
|
b_size = input_ids1.size()[0]
|
||||||
|
max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
|
||||||
|
input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
|
||||||
|
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
|
||||||
|
input_ids1 = input_ids1.to(text_encoder1.device)
|
||||||
|
input_ids2 = input_ids2.to(text_encoder2.device)
|
||||||
|
|
||||||
|
# text_encoder1
|
||||||
|
enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
|
||||||
|
hidden_states1 = enc_out["hidden_states"][11]
|
||||||
|
|
||||||
|
# text_encoder2
|
||||||
|
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
|
||||||
|
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
|
||||||
|
|
||||||
|
# pool2 = enc_out["text_embeds"]
|
||||||
|
unwrapped_text_encoder2 = unwrapped_text_encoder2 or text_encoder2
|
||||||
|
pool2 = self._pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
|
||||||
|
|
||||||
|
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
|
||||||
|
n_size = 1 if max_token_length is None else max_token_length // 75
|
||||||
|
hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
|
||||||
|
hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
|
||||||
|
|
||||||
|
if max_token_length is not None:
|
||||||
|
# bs*3, 77, 768 or 1024
|
||||||
|
# encoder1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
|
||||||
|
states_list = [hidden_states1[:, 0].unsqueeze(1)] # <BOS>
|
||||||
|
for i in range(1, max_token_length, tokenizer1.model_max_length):
|
||||||
|
states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
|
||||||
|
states_list.append(hidden_states1[:, -1].unsqueeze(1)) # <EOS>
|
||||||
|
hidden_states1 = torch.cat(states_list, dim=1)
|
||||||
|
|
||||||
|
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
|
||||||
|
states_list = [hidden_states2[:, 0].unsqueeze(1)] # <BOS>
|
||||||
|
for i in range(1, max_token_length, tokenizer2.model_max_length):
|
||||||
|
chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # <BOS> の後から 最後の前まで
|
||||||
|
# this causes an error:
|
||||||
|
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
|
||||||
|
# if i > 1:
|
||||||
|
# for j in range(len(chunk)): # batch_size
|
||||||
|
# if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
|
||||||
|
# chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
|
||||||
|
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
|
||||||
|
states_list.append(hidden_states2[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
|
||||||
|
hidden_states2 = torch.cat(states_list, dim=1)
|
||||||
|
|
||||||
|
# pool はnの最初のものを使う
|
||||||
|
pool2 = pool2[::n_size]
|
||||||
|
|
||||||
|
return hidden_states1, hidden_states2, pool2
|
||||||
|
|
||||||
|
def encode_tokens(
|
||||||
|
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
tokenize_strategy: TokenizeStrategy
|
||||||
|
models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)]
|
||||||
|
tokens: List of tokens, for text_encoder1 and text_encoder2
|
||||||
|
"""
|
||||||
|
if len(models) == 2:
|
||||||
|
text_encoder1, text_encoder2 = models
|
||||||
|
unwrapped_text_encoder2 = None
|
||||||
|
else:
|
||||||
|
text_encoder1, text_encoder2, unwrapped_text_encoder2 = models
|
||||||
|
tokens1, tokens2 = tokens
|
||||||
|
sdxl_tokenize_strategy = tokenize_strategy # type: SdxlTokenizeStrategy
|
||||||
|
tokenizer1, tokenizer2 = sdxl_tokenize_strategy.tokenizer1, sdxl_tokenize_strategy.tokenizer2
|
||||||
|
|
||||||
|
hidden_states1, hidden_states2, pool2 = self._get_hidden_states_sdxl(
|
||||||
|
tokens1, tokens2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, unwrapped_text_encoder2
|
||||||
|
)
|
||||||
|
return [hidden_states1, hidden_states2, pool2]
|
||||||
|
|
||||||
|
|
||||||
|
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||||
|
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
|
||||||
|
) -> None:
|
||||||
|
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
||||||
|
|
||||||
|
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||||
|
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||||
|
|
||||||
|
def is_disk_cached_outputs_expected(self, abs_path: str):
|
||||||
|
if not self.cache_to_disk:
|
||||||
|
return False
|
||||||
|
if not os.path.exists(self.get_outputs_npz_path(abs_path)):
|
||||||
|
return False
|
||||||
|
if self.skip_disk_cache_validity_check:
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
npz = np.load(self.get_outputs_npz_path(abs_path))
|
||||||
|
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||||
|
data = np.load(npz_path)
|
||||||
|
hidden_state1 = data["hidden_state1"]
|
||||||
|
hidden_state2 = data["hidden_state2"]
|
||||||
|
pool2 = data["pool2"]
|
||||||
|
return [hidden_state1, hidden_state2, pool2]
|
||||||
|
|
||||||
|
def cache_batch_outputs(
|
||||||
|
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
||||||
|
):
|
||||||
|
sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy
|
||||||
|
captions = [info.caption for info in infos]
|
||||||
|
|
||||||
|
tokens1, tokens2 = tokenize_strategy.tokenize(captions)
|
||||||
|
with torch.no_grad():
|
||||||
|
hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens(
|
||||||
|
tokenize_strategy, models, [tokens1, tokens2]
|
||||||
|
)
|
||||||
|
if hidden_state1.dtype == torch.bfloat16:
|
||||||
|
hidden_state1 = hidden_state1.float()
|
||||||
|
if hidden_state2.dtype == torch.bfloat16:
|
||||||
|
hidden_state2 = hidden_state2.float()
|
||||||
|
if pool2.dtype == torch.bfloat16:
|
||||||
|
pool2 = pool2.float()
|
||||||
|
|
||||||
|
hidden_state1 = hidden_state1.cpu().numpy()
|
||||||
|
hidden_state2 = hidden_state2.cpu().numpy()
|
||||||
|
pool2 = pool2.cpu().numpy()
|
||||||
|
|
||||||
|
for i, info in enumerate(infos):
|
||||||
|
hidden_state1_i = hidden_state1[i]
|
||||||
|
hidden_state2_i = hidden_state2[i]
|
||||||
|
pool2_i = pool2[i]
|
||||||
|
|
||||||
|
if self.cache_to_disk:
|
||||||
|
np.savez(
|
||||||
|
info.text_encoder_outputs_npz,
|
||||||
|
hidden_state1=hidden_state1_i,
|
||||||
|
hidden_state2=hidden_state2_i,
|
||||||
|
pool2=pool2_i,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]
|
||||||
@@ -12,6 +12,7 @@ import re
|
|||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
from typing import (
|
from typing import (
|
||||||
|
Any,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
NamedTuple,
|
NamedTuple,
|
||||||
@@ -34,6 +35,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
@@ -81,10 +83,6 @@ logger = logging.getLogger(__name__)
|
|||||||
# from library.hypernetwork import replace_attentions_for_hypernetwork
|
# from library.hypernetwork import replace_attentions_for_hypernetwork
|
||||||
from library.original_unet import UNet2DConditionModel
|
from library.original_unet import UNet2DConditionModel
|
||||||
|
|
||||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
|
||||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
|
||||||
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
|
|
||||||
|
|
||||||
HIGH_VRAM = False
|
HIGH_VRAM = False
|
||||||
|
|
||||||
# checkpointファイル名
|
# checkpointファイル名
|
||||||
@@ -148,18 +146,24 @@ class ImageInfo:
|
|||||||
self.image_size: Tuple[int, int] = None
|
self.image_size: Tuple[int, int] = None
|
||||||
self.resized_size: Tuple[int, int] = None
|
self.resized_size: Tuple[int, int] = None
|
||||||
self.bucket_reso: Tuple[int, int] = None
|
self.bucket_reso: Tuple[int, int] = None
|
||||||
self.latents: torch.Tensor = None
|
self.latents: Optional[torch.Tensor] = None
|
||||||
self.latents_flipped: torch.Tensor = None
|
self.latents_flipped: Optional[torch.Tensor] = None
|
||||||
self.latents_npz: str = None
|
self.latents_npz: Optional[str] = None # set in cache_latents
|
||||||
self.latents_original_size: Tuple[int, int] = None # original image size, not latents size
|
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
|
||||||
self.latents_crop_ltrb: Tuple[int, int] = None # crop left top right bottom in original pixel size, not latents size
|
self.latents_crop_ltrb: Optional[Tuple[int, int]] = (
|
||||||
self.cond_img_path: str = None
|
None # crop left top right bottom in original pixel size, not latents size
|
||||||
|
)
|
||||||
|
self.cond_img_path: Optional[str] = None
|
||||||
self.image: Optional[Image.Image] = None # optional, original PIL Image
|
self.image: Optional[Image.Image] = None # optional, original PIL Image
|
||||||
# SDXL, optional
|
self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs
|
||||||
self.text_encoder_outputs_npz: Optional[str] = None
|
|
||||||
|
# new
|
||||||
|
self.text_encoder_outputs: Optional[List[torch.Tensor]] = None
|
||||||
|
# old
|
||||||
self.text_encoder_outputs1: Optional[torch.Tensor] = None
|
self.text_encoder_outputs1: Optional[torch.Tensor] = None
|
||||||
self.text_encoder_outputs2: Optional[torch.Tensor] = None
|
self.text_encoder_outputs2: Optional[torch.Tensor] = None
|
||||||
self.text_encoder_pool2: Optional[torch.Tensor] = None
|
self.text_encoder_pool2: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||||
|
|
||||||
|
|
||||||
@@ -359,47 +363,6 @@ class AugHelper:
|
|||||||
return self.color_aug if use_color_aug else None
|
return self.color_aug if use_color_aug else None
|
||||||
|
|
||||||
|
|
||||||
class LatentsCachingStrategy:
|
|
||||||
_strategy = None # strategy instance: actual strategy class
|
|
||||||
|
|
||||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
|
||||||
self._cache_to_disk = cache_to_disk
|
|
||||||
self._batch_size = batch_size
|
|
||||||
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def set_strategy(cls, strategy):
|
|
||||||
if cls._strategy is not None:
|
|
||||||
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
|
|
||||||
cls._strategy = strategy
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
|
|
||||||
return cls._strategy
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cache_to_disk(self):
|
|
||||||
return self._cache_to_disk
|
|
||||||
|
|
||||||
@property
|
|
||||||
def batch_size(self):
|
|
||||||
return self._batch_size
|
|
||||||
|
|
||||||
def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def get_latents_npz_path(self, absolute_path: str, bucket_reso: Tuple[int, int]) -> str:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def is_disk_cached_latents_expected(
|
|
||||||
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
|
|
||||||
) -> bool:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class BaseSubset:
|
class BaseSubset:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -639,17 +602,12 @@ class ControlNetSubset(BaseSubset):
|
|||||||
class BaseDataset(torch.utils.data.Dataset):
|
class BaseDataset(torch.utils.data.Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
|
|
||||||
max_token_length: int,
|
|
||||||
resolution: Optional[Tuple[int, int]],
|
resolution: Optional[Tuple[int, int]],
|
||||||
network_multiplier: float,
|
network_multiplier: float,
|
||||||
debug_dataset: bool,
|
debug_dataset: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
|
|
||||||
|
|
||||||
self.max_token_length = max_token_length
|
|
||||||
# width/height is used when enable_bucket==False
|
# width/height is used when enable_bucket==False
|
||||||
self.width, self.height = (None, None) if resolution is None else resolution
|
self.width, self.height = (None, None) if resolution is None else resolution
|
||||||
self.network_multiplier = network_multiplier
|
self.network_multiplier = network_multiplier
|
||||||
@@ -670,8 +628,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
self.bucket_no_upscale = None
|
self.bucket_no_upscale = None
|
||||||
self.bucket_info = None # for metadata
|
self.bucket_info = None # for metadata
|
||||||
|
|
||||||
self.tokenizer_max_length = self.tokenizers[0].model_max_length if max_token_length is None else max_token_length + 2
|
|
||||||
|
|
||||||
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
||||||
|
|
||||||
self.current_step: int = 0
|
self.current_step: int = 0
|
||||||
@@ -690,6 +646,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
# caching
|
# caching
|
||||||
self.caching_mode = None # None, 'latents', 'text'
|
self.caching_mode = None # None, 'latents', 'text'
|
||||||
|
|
||||||
|
self.tokenize_strategy = None
|
||||||
|
self.text_encoder_output_caching_strategy = None
|
||||||
|
self.latents_caching_strategy = None
|
||||||
|
|
||||||
|
def set_current_strategies(self):
|
||||||
|
self.tokenize_strategy = TokenizeStrategy.get_strategy()
|
||||||
|
self.text_encoder_output_caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
|
||||||
|
self.latents_caching_strategy = LatentsCachingStrategy.get_strategy()
|
||||||
|
|
||||||
def set_seed(self, seed):
|
def set_seed(self, seed):
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
@@ -979,22 +944,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
for batch_index in range(batch_count):
|
for batch_index in range(batch_count):
|
||||||
self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
|
self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
|
||||||
|
|
||||||
# ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す
|
|
||||||
# 学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる
|
|
||||||
#
|
|
||||||
# # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
|
|
||||||
# # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
|
|
||||||
# # そのためバッチサイズを画像種類までに制限する
|
|
||||||
# # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない?
|
|
||||||
# # TO DO 正則化画像をepochまたがりで利用する仕組み
|
|
||||||
# num_of_image_types = len(set(bucket))
|
|
||||||
# bucket_batch_size = min(self.batch_size, num_of_image_types)
|
|
||||||
# batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
|
|
||||||
# # logger.info(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
|
|
||||||
# for batch_index in range(batch_count):
|
|
||||||
# self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
|
|
||||||
# ↑ここまで
|
|
||||||
|
|
||||||
self.shuffle_buckets()
|
self.shuffle_buckets()
|
||||||
self._length = len(self.buckets_indices)
|
self._length = len(self.buckets_indices)
|
||||||
|
|
||||||
@@ -1027,12 +976,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCachingStrategy):
|
def new_cache_latents(self, model: Any, is_main_process: bool):
|
||||||
r"""
|
r"""
|
||||||
a brand new method to cache latents. This method caches latents with caching strategy.
|
a brand new method to cache latents. This method caches latents with caching strategy.
|
||||||
normal cache_latents method is used by default, but this method is used when caching strategy is specified.
|
normal cache_latents method is used by default, but this method is used when caching strategy is specified.
|
||||||
"""
|
"""
|
||||||
logger.info("caching latents with caching strategy.")
|
logger.info("caching latents with caching strategy.")
|
||||||
|
caching_strategy = LatentsCachingStrategy.get_strategy()
|
||||||
image_infos = list(self.image_data.values())
|
image_infos = list(self.image_data.values())
|
||||||
|
|
||||||
# sort by resolution
|
# sort by resolution
|
||||||
@@ -1088,7 +1038,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
logger.info("caching latents...")
|
logger.info("caching latents...")
|
||||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||||
# cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
# cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||||
caching_strategy.cache_batch_latents(batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
caching_strategy.cache_batch_latents(model, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||||
|
|
||||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
|
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
|
||||||
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||||
@@ -1145,6 +1095,56 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||||
cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||||
|
|
||||||
|
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
|
||||||
|
r"""
|
||||||
|
a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy.
|
||||||
|
"""
|
||||||
|
tokenize_strategy = TokenizeStrategy.get_strategy()
|
||||||
|
text_encoding_strategy = TextEncodingStrategy.get_strategy()
|
||||||
|
caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
|
||||||
|
batch_size = caching_strategy.batch_size or self.batch_size
|
||||||
|
|
||||||
|
# if cache to disk, don't cache TE outputs in non-main process
|
||||||
|
if caching_strategy.cache_to_disk and not is_main_process:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("caching Text Encoder outputs with caching strategy.")
|
||||||
|
image_infos = list(self.image_data.values())
|
||||||
|
|
||||||
|
# split by resolution
|
||||||
|
batches = []
|
||||||
|
batch = []
|
||||||
|
logger.info("checking cache validity...")
|
||||||
|
for info in tqdm(image_infos):
|
||||||
|
te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
|
||||||
|
|
||||||
|
# check disk cache exists and size of latents
|
||||||
|
if caching_strategy.cache_to_disk:
|
||||||
|
info.text_encoder_outputs_npz = te_out_npz
|
||||||
|
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
|
||||||
|
if cache_available: # do not add to batch
|
||||||
|
continue
|
||||||
|
|
||||||
|
batch.append(info)
|
||||||
|
|
||||||
|
# if number of data in batch is enough, flush the batch
|
||||||
|
if len(batch) >= batch_size:
|
||||||
|
batches.append(batch)
|
||||||
|
batch = []
|
||||||
|
|
||||||
|
if len(batch) > 0:
|
||||||
|
batches.append(batch)
|
||||||
|
|
||||||
|
if len(batches) == 0:
|
||||||
|
logger.info("no Text Encoder outputs to cache")
|
||||||
|
return
|
||||||
|
|
||||||
|
# iterate batches
|
||||||
|
logger.info("caching Text Encoder outputs...")
|
||||||
|
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||||
|
# cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||||
|
caching_strategy.cache_batch_outputs(tokenize_strategy, models, text_encoding_strategy, batch)
|
||||||
|
|
||||||
# if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype
|
# if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype
|
||||||
# this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset
|
# this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset
|
||||||
# to support SD1/2, it needs a flag for v2, but it is postponed
|
# to support SD1/2, it needs a flag for v2, but it is postponed
|
||||||
@@ -1188,6 +1188,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
# またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
# またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||||
logger.info("caching text encoder outputs.")
|
logger.info("caching text encoder outputs.")
|
||||||
|
|
||||||
|
tokenize_strategy = TokenizeStrategy.get_strategy()
|
||||||
|
|
||||||
if batch_size is None:
|
if batch_size is None:
|
||||||
batch_size = self.batch_size
|
batch_size = self.batch_size
|
||||||
|
|
||||||
@@ -1229,7 +1231,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
|
input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
|
||||||
batch.append((info, input_ids1, input_ids2))
|
batch.append((info, input_ids1, input_ids2))
|
||||||
else:
|
else:
|
||||||
l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption)
|
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(info.caption)
|
||||||
batch.append((info, l_tokens, g_tokens, t5_tokens))
|
batch.append((info, l_tokens, g_tokens, t5_tokens))
|
||||||
|
|
||||||
if len(batch) >= batch_size:
|
if len(batch) >= batch_size:
|
||||||
@@ -1347,7 +1349,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
loss_weights = []
|
loss_weights = []
|
||||||
captions = []
|
captions = []
|
||||||
input_ids_list = []
|
input_ids_list = []
|
||||||
input_ids2_list = []
|
|
||||||
latents_list = []
|
latents_list = []
|
||||||
alpha_mask_list = []
|
alpha_mask_list = []
|
||||||
images = []
|
images = []
|
||||||
@@ -1355,16 +1356,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
crop_top_lefts = []
|
crop_top_lefts = []
|
||||||
target_sizes_hw = []
|
target_sizes_hw = []
|
||||||
flippeds = [] # 変数名が微妙
|
flippeds = [] # 変数名が微妙
|
||||||
text_encoder_outputs1_list = []
|
text_encoder_outputs_list = []
|
||||||
text_encoder_outputs2_list = []
|
|
||||||
text_encoder_pool2_list = []
|
|
||||||
|
|
||||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||||
image_info = self.image_data[image_key]
|
image_info = self.image_data[image_key]
|
||||||
subset = self.image_to_subset[image_key]
|
subset = self.image_to_subset[image_key]
|
||||||
loss_weights.append(
|
|
||||||
self.prior_loss_weight if image_info.is_reg else 1.0
|
# in case of fine tuning, is_reg is always False
|
||||||
) # in case of fine tuning, is_reg is always False
|
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
||||||
|
|
||||||
flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
|
flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
|
||||||
|
|
||||||
@@ -1381,7 +1380,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
image = None
|
image = None
|
||||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||||
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz)
|
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
|
||||||
|
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz)
|
||||||
|
)
|
||||||
if flipped:
|
if flipped:
|
||||||
latents = flipped_latents
|
latents = flipped_latents
|
||||||
alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem
|
alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem
|
||||||
@@ -1470,75 +1471,67 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
# captionとtext encoder outputを処理する
|
# captionとtext encoder outputを処理する
|
||||||
caption = image_info.caption # default
|
caption = image_info.caption # default
|
||||||
if image_info.text_encoder_outputs1 is not None:
|
|
||||||
text_encoder_outputs1_list.append(image_info.text_encoder_outputs1)
|
tokenization_required = (
|
||||||
text_encoder_outputs2_list.append(image_info.text_encoder_outputs2)
|
self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial
|
||||||
text_encoder_pool2_list.append(image_info.text_encoder_pool2)
|
)
|
||||||
captions.append(caption)
|
text_encoder_outputs = None
|
||||||
|
input_ids = None
|
||||||
|
|
||||||
|
if image_info.text_encoder_outputs is not None:
|
||||||
|
# cached
|
||||||
|
text_encoder_outputs = image_info.text_encoder_outputs
|
||||||
elif image_info.text_encoder_outputs_npz is not None:
|
elif image_info.text_encoder_outputs_npz is not None:
|
||||||
text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk(
|
# on disk
|
||||||
|
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz(
|
||||||
image_info.text_encoder_outputs_npz
|
image_info.text_encoder_outputs_npz
|
||||||
)
|
)
|
||||||
text_encoder_outputs1_list.append(text_encoder_outputs1)
|
|
||||||
text_encoder_outputs2_list.append(text_encoder_outputs2)
|
|
||||||
text_encoder_pool2_list.append(text_encoder_pool2)
|
|
||||||
captions.append(caption)
|
|
||||||
else:
|
else:
|
||||||
|
tokenization_required = True
|
||||||
|
text_encoder_outputs_list.append(text_encoder_outputs)
|
||||||
|
|
||||||
|
if tokenization_required:
|
||||||
caption = self.process_caption(subset, image_info.caption)
|
caption = self.process_caption(subset, image_info.caption)
|
||||||
if self.XTI_layers:
|
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension
|
||||||
caption_layer = []
|
# if self.XTI_layers:
|
||||||
for layer in self.XTI_layers:
|
# caption_layer = []
|
||||||
token_strings_from = " ".join(self.token_strings)
|
# for layer in self.XTI_layers:
|
||||||
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
# token_strings_from = " ".join(self.token_strings)
|
||||||
caption_ = caption.replace(token_strings_from, token_strings_to)
|
# token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
||||||
caption_layer.append(caption_)
|
# caption_ = caption.replace(token_strings_from, token_strings_to)
|
||||||
captions.append(caption_layer)
|
# caption_layer.append(caption_)
|
||||||
else:
|
# captions.append(caption_layer)
|
||||||
captions.append(caption)
|
# else:
|
||||||
|
# captions.append(caption)
|
||||||
|
|
||||||
if not self.token_padding_disabled: # this option might be omitted in future
|
# if not self.token_padding_disabled: # this option might be omitted in future
|
||||||
# TODO get_input_ids must support SD3
|
# # TODO get_input_ids must support SD3
|
||||||
if self.XTI_layers:
|
# if self.XTI_layers:
|
||||||
token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
|
# token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
|
||||||
else:
|
# else:
|
||||||
token_caption = self.get_input_ids(caption, self.tokenizers[0])
|
# token_caption = self.get_input_ids(caption, self.tokenizers[0])
|
||||||
input_ids_list.append(token_caption)
|
# input_ids_list.append(token_caption)
|
||||||
|
|
||||||
if len(self.tokenizers) > 1:
|
# if len(self.tokenizers) > 1:
|
||||||
if self.XTI_layers:
|
# if self.XTI_layers:
|
||||||
token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
|
# token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
|
||||||
else:
|
# else:
|
||||||
token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
|
# token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
|
||||||
input_ids2_list.append(token_caption2)
|
# input_ids2_list.append(token_caption2)
|
||||||
|
|
||||||
|
input_ids_list.append(input_ids)
|
||||||
|
captions.append(caption)
|
||||||
|
|
||||||
|
def none_or_stack_elements(tensors_list, converter):
|
||||||
|
# [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)]
|
||||||
|
if len(tensors_list) == 0 or tensors_list[0] == None or len(tensors_list[0]) == 0 or tensors_list[0][0] is None:
|
||||||
|
return None
|
||||||
|
return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))]
|
||||||
|
|
||||||
example = {}
|
example = {}
|
||||||
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
||||||
|
example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor)
|
||||||
if len(text_encoder_outputs1_list) == 0:
|
example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x)
|
||||||
if self.token_padding_disabled:
|
|
||||||
# padding=True means pad in the batch
|
|
||||||
example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
|
||||||
if len(self.tokenizers) > 1:
|
|
||||||
example["input_ids2"] = self.tokenizer[1](
|
|
||||||
captions, padding=True, truncation=True, return_tensors="pt"
|
|
||||||
).input_ids
|
|
||||||
else:
|
|
||||||
example["input_ids2"] = None
|
|
||||||
else:
|
|
||||||
example["input_ids"] = torch.stack(input_ids_list)
|
|
||||||
example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
|
|
||||||
example["text_encoder_outputs1_list"] = None
|
|
||||||
example["text_encoder_outputs2_list"] = None
|
|
||||||
example["text_encoder_pool2_list"] = None
|
|
||||||
else:
|
|
||||||
example["input_ids"] = None
|
|
||||||
example["input_ids2"] = None
|
|
||||||
# # for assertion
|
|
||||||
# example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions])
|
|
||||||
# example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions])
|
|
||||||
example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list)
|
|
||||||
example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
|
|
||||||
example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)
|
|
||||||
|
|
||||||
# if one of alpha_masks is not None, we need to replace None with ones
|
# if one of alpha_masks is not None, we need to replace None with ones
|
||||||
none_or_not = [x is None for x in alpha_mask_list]
|
none_or_not = [x is None for x in alpha_mask_list]
|
||||||
@@ -1652,8 +1645,6 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
self,
|
self,
|
||||||
subsets: Sequence[DreamBoothSubset],
|
subsets: Sequence[DreamBoothSubset],
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
tokenizer,
|
|
||||||
max_token_length,
|
|
||||||
resolution,
|
resolution,
|
||||||
network_multiplier: float,
|
network_multiplier: float,
|
||||||
enable_bucket: bool,
|
enable_bucket: bool,
|
||||||
@@ -1664,7 +1655,7 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
prior_loss_weight: float,
|
prior_loss_weight: float,
|
||||||
debug_dataset: bool,
|
debug_dataset: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||||
|
|
||||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||||
|
|
||||||
@@ -1750,10 +1741,10 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
# new caching: get image size from cache files
|
# new caching: get image size from cache files
|
||||||
strategy = LatentsCachingStrategy.get_strategy()
|
strategy = LatentsCachingStrategy.get_strategy()
|
||||||
if strategy is not None:
|
if strategy is not None:
|
||||||
logger.info("get image size from cache files")
|
logger.info("get image size from name of cache files")
|
||||||
size_set_count = 0
|
size_set_count = 0
|
||||||
for i, img_path in enumerate(tqdm(img_paths)):
|
for i, img_path in enumerate(tqdm(img_paths)):
|
||||||
w, h = strategy.get_image_size_from_image_absolute_path(img_path)
|
w, h = strategy.get_image_size_from_disk_cache_path(img_path)
|
||||||
if w is not None and h is not None:
|
if w is not None and h is not None:
|
||||||
sizes[i] = [w, h]
|
sizes[i] = [w, h]
|
||||||
size_set_count += 1
|
size_set_count += 1
|
||||||
@@ -1886,8 +1877,6 @@ class FineTuningDataset(BaseDataset):
|
|||||||
self,
|
self,
|
||||||
subsets: Sequence[FineTuningSubset],
|
subsets: Sequence[FineTuningSubset],
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
tokenizer,
|
|
||||||
max_token_length,
|
|
||||||
resolution,
|
resolution,
|
||||||
network_multiplier: float,
|
network_multiplier: float,
|
||||||
enable_bucket: bool,
|
enable_bucket: bool,
|
||||||
@@ -1897,7 +1886,7 @@ class FineTuningDataset(BaseDataset):
|
|||||||
bucket_no_upscale: bool,
|
bucket_no_upscale: bool,
|
||||||
debug_dataset: bool,
|
debug_dataset: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||||
|
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
|
||||||
@@ -2111,8 +2100,6 @@ class ControlNetDataset(BaseDataset):
|
|||||||
self,
|
self,
|
||||||
subsets: Sequence[ControlNetSubset],
|
subsets: Sequence[ControlNetSubset],
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
tokenizer,
|
|
||||||
max_token_length,
|
|
||||||
resolution,
|
resolution,
|
||||||
network_multiplier: float,
|
network_multiplier: float,
|
||||||
enable_bucket: bool,
|
enable_bucket: bool,
|
||||||
@@ -2122,7 +2109,7 @@ class ControlNetDataset(BaseDataset):
|
|||||||
bucket_no_upscale: bool,
|
bucket_no_upscale: bool,
|
||||||
debug_dataset: float,
|
debug_dataset: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||||
|
|
||||||
db_subsets = []
|
db_subsets = []
|
||||||
for subset in subsets:
|
for subset in subsets:
|
||||||
@@ -2160,8 +2147,6 @@ class ControlNetDataset(BaseDataset):
|
|||||||
self.dreambooth_dataset_delegate = DreamBoothDataset(
|
self.dreambooth_dataset_delegate = DreamBoothDataset(
|
||||||
db_subsets,
|
db_subsets,
|
||||||
batch_size,
|
batch_size,
|
||||||
tokenizer,
|
|
||||||
max_token_length,
|
|
||||||
resolution,
|
resolution,
|
||||||
network_multiplier,
|
network_multiplier,
|
||||||
enable_bucket,
|
enable_bucket,
|
||||||
@@ -2221,6 +2206,9 @@ class ControlNetDataset(BaseDataset):
|
|||||||
|
|
||||||
self.conditioning_image_transforms = IMAGE_TRANSFORMS
|
self.conditioning_image_transforms = IMAGE_TRANSFORMS
|
||||||
|
|
||||||
|
def set_current_strategies(self):
|
||||||
|
return self.dreambooth_dataset_delegate.set_current_strategies()
|
||||||
|
|
||||||
def make_buckets(self):
|
def make_buckets(self):
|
||||||
self.dreambooth_dataset_delegate.make_buckets()
|
self.dreambooth_dataset_delegate.make_buckets()
|
||||||
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
|
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
|
||||||
@@ -2229,6 +2217,12 @@ class ControlNetDataset(BaseDataset):
|
|||||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
||||||
return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
|
return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
|
||||||
|
|
||||||
|
def new_cache_latents(self, model: Any, is_main_process: bool):
|
||||||
|
return self.dreambooth_dataset_delegate.new_cache_latents(model, is_main_process)
|
||||||
|
|
||||||
|
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
|
||||||
|
return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.dreambooth_dataset_delegate.__len__()
|
return self.dreambooth_dataset_delegate.__len__()
|
||||||
|
|
||||||
@@ -2314,6 +2308,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
|||||||
# for dataset in self.datasets:
|
# for dataset in self.datasets:
|
||||||
# dataset.make_buckets()
|
# dataset.make_buckets()
|
||||||
|
|
||||||
|
def set_text_encoder_output_caching_strategy(self, strategy: TextEncoderOutputsCachingStrategy):
|
||||||
|
"""
|
||||||
|
DataLoader is run in multiple processes, so we need to set the strategy manually.
|
||||||
|
"""
|
||||||
|
for dataset in self.datasets:
|
||||||
|
dataset.set_text_encoder_output_caching_strategy(strategy)
|
||||||
|
|
||||||
def enable_XTI(self, *args, **kwargs):
|
def enable_XTI(self, *args, **kwargs):
|
||||||
for dataset in self.datasets:
|
for dataset in self.datasets:
|
||||||
dataset.enable_XTI(*args, **kwargs)
|
dataset.enable_XTI(*args, **kwargs)
|
||||||
@@ -2323,10 +2324,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
|||||||
logger.info(f"[Dataset {i}]")
|
logger.info(f"[Dataset {i}]")
|
||||||
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix)
|
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix)
|
||||||
|
|
||||||
def new_cache_latents(self, is_main_process: bool, strategy: LatentsCachingStrategy):
|
def new_cache_latents(self, model: Any, is_main_process: bool):
|
||||||
for i, dataset in enumerate(self.datasets):
|
for i, dataset in enumerate(self.datasets):
|
||||||
logger.info(f"[Dataset {i}]")
|
logger.info(f"[Dataset {i}]")
|
||||||
dataset.new_cache_latents(is_main_process, strategy)
|
dataset.new_cache_latents(model, is_main_process)
|
||||||
|
|
||||||
def cache_text_encoder_outputs(
|
def cache_text_encoder_outputs(
|
||||||
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
|
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
|
||||||
@@ -2344,6 +2345,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
|||||||
tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size
|
tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
|
||||||
|
for i, dataset in enumerate(self.datasets):
|
||||||
|
logger.info(f"[Dataset {i}]")
|
||||||
|
dataset.new_cache_text_encoder_outputs(models, is_main_process)
|
||||||
|
|
||||||
def set_caching_mode(self, caching_mode):
|
def set_caching_mode(self, caching_mode):
|
||||||
for dataset in self.datasets:
|
for dataset in self.datasets:
|
||||||
dataset.set_caching_mode(caching_mode)
|
dataset.set_caching_mode(caching_mode)
|
||||||
@@ -2358,6 +2364,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
|||||||
def is_text_encoder_output_cacheable(self) -> bool:
|
def is_text_encoder_output_cacheable(self) -> bool:
|
||||||
return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets])
|
return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets])
|
||||||
|
|
||||||
|
def set_current_strategies(self):
|
||||||
|
for dataset in self.datasets:
|
||||||
|
dataset.set_current_strategies()
|
||||||
|
|
||||||
def set_current_epoch(self, epoch):
|
def set_current_epoch(self, epoch):
|
||||||
for dataset in self.datasets:
|
for dataset in self.datasets:
|
||||||
dataset.set_current_epoch(epoch)
|
dataset.set_current_epoch(epoch)
|
||||||
@@ -2411,34 +2421,34 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
|
|||||||
|
|
||||||
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
|
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
|
||||||
# TODO update to use CachingStrategy
|
# TODO update to use CachingStrategy
|
||||||
def load_latents_from_disk(
|
# def load_latents_from_disk(
|
||||||
npz_path,
|
# npz_path,
|
||||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
# ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||||
npz = np.load(npz_path)
|
# npz = np.load(npz_path)
|
||||||
if "latents" not in npz:
|
# if "latents" not in npz:
|
||||||
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
|
# raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
|
||||||
|
|
||||||
latents = npz["latents"]
|
# latents = npz["latents"]
|
||||||
original_size = npz["original_size"].tolist()
|
# original_size = npz["original_size"].tolist()
|
||||||
crop_ltrb = npz["crop_ltrb"].tolist()
|
# crop_ltrb = npz["crop_ltrb"].tolist()
|
||||||
flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
|
# flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
|
||||||
alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
|
# alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
|
||||||
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
|
# return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
|
||||||
|
|
||||||
|
|
||||||
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None):
|
# def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None):
|
||||||
kwargs = {}
|
# kwargs = {}
|
||||||
if flipped_latents_tensor is not None:
|
# if flipped_latents_tensor is not None:
|
||||||
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
|
# kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
|
||||||
if alpha_mask is not None:
|
# if alpha_mask is not None:
|
||||||
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
|
# kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
|
||||||
np.savez(
|
# np.savez(
|
||||||
npz_path,
|
# npz_path,
|
||||||
latents=latents_tensor.float().cpu().numpy(),
|
# latents=latents_tensor.float().cpu().numpy(),
|
||||||
original_size=np.array(original_size),
|
# original_size=np.array(original_size),
|
||||||
crop_ltrb=np.array(crop_ltrb),
|
# crop_ltrb=np.array(crop_ltrb),
|
||||||
**kwargs,
|
# **kwargs,
|
||||||
)
|
# )
|
||||||
|
|
||||||
|
|
||||||
def debug_dataset(train_dataset, show_input_ids=False):
|
def debug_dataset(train_dataset, show_input_ids=False):
|
||||||
@@ -2465,12 +2475,12 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
|||||||
example = train_dataset[idx]
|
example = train_dataset[idx]
|
||||||
if example["latents"] is not None:
|
if example["latents"] is not None:
|
||||||
logger.info(f"sample has latents from npz file: {example['latents'].size()}")
|
logger.info(f"sample has latents from npz file: {example['latents'].size()}")
|
||||||
for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate(
|
for j, (ik, cap, lw, orgsz, crptl, trgsz, flpdz) in enumerate(
|
||||||
zip(
|
zip(
|
||||||
example["image_keys"],
|
example["image_keys"],
|
||||||
example["captions"],
|
example["captions"],
|
||||||
example["loss_weights"],
|
example["loss_weights"],
|
||||||
example["input_ids"],
|
# example["input_ids"],
|
||||||
example["original_sizes_hw"],
|
example["original_sizes_hw"],
|
||||||
example["crop_top_lefts"],
|
example["crop_top_lefts"],
|
||||||
example["target_sizes_hw"],
|
example["target_sizes_hw"],
|
||||||
@@ -2483,10 +2493,10 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
|||||||
if "network_multipliers" in example:
|
if "network_multipliers" in example:
|
||||||
print(f"network multiplier: {example['network_multipliers'][j]}")
|
print(f"network multiplier: {example['network_multipliers'][j]}")
|
||||||
|
|
||||||
if show_input_ids:
|
# if show_input_ids:
|
||||||
logger.info(f"input ids: {iid}")
|
# logger.info(f"input ids: {iid}")
|
||||||
if "input_ids2" in example:
|
# if "input_ids2" in example:
|
||||||
logger.info(f"input ids2: {example['input_ids2'][j]}")
|
# logger.info(f"input ids2: {example['input_ids2'][j]}")
|
||||||
if example["images"] is not None:
|
if example["images"] is not None:
|
||||||
im = example["images"][j]
|
im = example["images"][j]
|
||||||
logger.info(f"image size: {im.size()}")
|
logger.info(f"image size: {im.size()}")
|
||||||
@@ -2555,8 +2565,8 @@ def glob_images_pathlib(dir_path, recursive):
|
|||||||
|
|
||||||
|
|
||||||
class MinimalDataset(BaseDataset):
|
class MinimalDataset(BaseDataset):
|
||||||
def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False):
|
def __init__(self, resolution, network_multiplier, debug_dataset=False):
|
||||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||||
|
|
||||||
self.num_train_images = 0 # update in subclass
|
self.num_train_images = 0 # update in subclass
|
||||||
self.num_reg_images = 0 # update in subclass
|
self.num_reg_images = 0 # update in subclass
|
||||||
@@ -2773,14 +2783,15 @@ def cache_batch_latents(
|
|||||||
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
|
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
|
||||||
|
|
||||||
if cache_to_disk:
|
if cache_to_disk:
|
||||||
save_latents_to_disk(
|
# save_latents_to_disk(
|
||||||
info.latents_npz,
|
# info.latents_npz,
|
||||||
latent,
|
# latent,
|
||||||
info.latents_original_size,
|
# info.latents_original_size,
|
||||||
info.latents_crop_ltrb,
|
# info.latents_crop_ltrb,
|
||||||
flipped_latent,
|
# flipped_latent,
|
||||||
alpha_mask,
|
# alpha_mask,
|
||||||
)
|
# )
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
info.latents = latent
|
info.latents = latent
|
||||||
if flip_aug:
|
if flip_aug:
|
||||||
@@ -4662,33 +4673,6 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(args: argparse.Namespace):
|
|
||||||
logger.info("prepare tokenizer")
|
|
||||||
original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH
|
|
||||||
|
|
||||||
tokenizer: CLIPTokenizer = None
|
|
||||||
if args.tokenizer_cache_dir:
|
|
||||||
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
|
|
||||||
if os.path.exists(local_tokenizer_path):
|
|
||||||
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
|
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2
|
|
||||||
|
|
||||||
if tokenizer is None:
|
|
||||||
if args.v2:
|
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
|
|
||||||
else:
|
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(original_path)
|
|
||||||
|
|
||||||
if hasattr(args, "max_token_length") and args.max_token_length is not None:
|
|
||||||
logger.info(f"update token length: {args.max_token_length}")
|
|
||||||
|
|
||||||
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
|
||||||
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
|
|
||||||
tokenizer.save_pretrained(local_tokenizer_path)
|
|
||||||
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_accelerator(args: argparse.Namespace):
|
def prepare_accelerator(args: argparse.Namespace):
|
||||||
"""
|
"""
|
||||||
this function also prepares deepspeed plugin
|
this function also prepares deepspeed plugin
|
||||||
@@ -5550,6 +5534,7 @@ def sample_images_common(
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
|
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
|
||||||
|
TODO Use strategies here
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if steps == 0:
|
if steps == 0:
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ import logging
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from library import sd3_models, sd3_utils
|
from library import sd3_models, sd3_utils, strategy_sd3
|
||||||
|
|
||||||
|
|
||||||
def get_noise(seed, latent):
|
def get_noise(seed, latent):
|
||||||
@@ -145,6 +145,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--clip_g", type=str, required=False)
|
parser.add_argument("--clip_g", type=str, required=False)
|
||||||
parser.add_argument("--clip_l", type=str, required=False)
|
parser.add_argument("--clip_l", type=str, required=False)
|
||||||
parser.add_argument("--t5xxl", type=str, required=False)
|
parser.add_argument("--t5xxl", type=str, required=False)
|
||||||
|
parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77")
|
||||||
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
||||||
# parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders
|
# parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders
|
||||||
parser.add_argument("--negative_prompt", type=str, default="")
|
parser.add_argument("--negative_prompt", type=str, default="")
|
||||||
@@ -247,7 +248,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# load tokenizers
|
# load tokenizers
|
||||||
logger.info("Loading tokenizers...")
|
logger.info("Loading tokenizers...")
|
||||||
tokenizer = sd3_models.SD3Tokenizer(use_t5xxl) # combined tokenizer
|
tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)
|
||||||
|
|
||||||
# load models
|
# load models
|
||||||
# logger.info("Create MMDiT from SD3 checkpoint...")
|
# logger.info("Create MMDiT from SD3 checkpoint...")
|
||||||
@@ -320,12 +321,19 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# prepare embeddings
|
# prepare embeddings
|
||||||
logger.info("Encoding prompts...")
|
logger.info("Encoding prompts...")
|
||||||
# embeds, pooled_embed
|
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
|
||||||
lg_out, t5_out, pooled = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl)
|
|
||||||
cond = torch.cat([lg_out, t5_out], dim=-2), pooled
|
|
||||||
|
|
||||||
lg_out, t5_out, pooled = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl)
|
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.prompt)
|
||||||
neg_cond = torch.cat([lg_out, t5_out], dim=-2), pooled
|
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
|
||||||
|
tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens]
|
||||||
|
)
|
||||||
|
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
||||||
|
|
||||||
|
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.negative_prompt)
|
||||||
|
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
|
||||||
|
tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens]
|
||||||
|
)
|
||||||
|
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
||||||
|
|
||||||
# generate image
|
# generate image
|
||||||
logger.info("Generating image...")
|
logger.info("Generating image...")
|
||||||
|
|||||||
270
sd3_train.py
270
sd3_train.py
@@ -17,7 +17,7 @@ init_ipex()
|
|||||||
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils
|
from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3
|
||||||
from library.sdxl_train_util import match_mixed_precision
|
from library.sdxl_train_util import match_mixed_precision
|
||||||
|
|
||||||
# , sdxl_model_util
|
# , sdxl_model_util
|
||||||
@@ -69,10 +69,22 @@ def train(args):
|
|||||||
# not args.train_text_encoder
|
# not args.train_text_encoder
|
||||||
# ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません"
|
# ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません"
|
||||||
|
|
||||||
# training without text encoder cache is not supported
|
# # training without text encoder cache is not supported: because T5XXL must be cached
|
||||||
assert (
|
# assert (
|
||||||
args.cache_text_encoder_outputs
|
# args.cache_text_encoder_outputs
|
||||||
), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません"
|
# ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません"
|
||||||
|
|
||||||
|
assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), (
|
||||||
|
"when training text encoder, text encoder outputs must not be cached (except for T5XXL)"
|
||||||
|
+ " / text encoderの学習時はtext encoderの出力はキャッシュできません(t5xxlのみキャッシュすることは可能です)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.use_t5xxl_cache_only and not args.cache_text_encoder_outputs:
|
||||||
|
logger.warning(
|
||||||
|
"use_t5xxl_cache_only is enabled, so cache_text_encoder_outputs is automatically enabled."
|
||||||
|
+ " / use_t5xxl_cache_onlyが有効なため、cache_text_encoder_outputsも自動的に有効になります"
|
||||||
|
)
|
||||||
|
args.cache_text_encoder_outputs = True
|
||||||
|
|
||||||
# if args.block_lr:
|
# if args.block_lr:
|
||||||
# block_lrs = [float(lr) for lr in args.block_lr.split(",")]
|
# block_lrs = [float(lr) for lr in args.block_lr.split(",")]
|
||||||
@@ -88,17 +100,17 @@ def train(args):
|
|||||||
if args.seed is not None:
|
if args.seed is not None:
|
||||||
set_seed(args.seed) # 乱数系列を初期化する
|
set_seed(args.seed) # 乱数系列を初期化する
|
||||||
|
|
||||||
# load tokenizer
|
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
||||||
sd3_tokenizer = sd3_models.SD3Tokenizer()
|
if args.cache_latents:
|
||||||
|
latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy(
|
||||||
# prepare caching strategy
|
|
||||||
if args.new_caching:
|
|
||||||
latents_caching_strategy = sd3_train_utils.Sd3LatentsCachingStrategy(
|
|
||||||
args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
|
args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
|
||||||
)
|
)
|
||||||
else:
|
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
||||||
latents_caching_strategy = None
|
|
||||||
train_util.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
# load tokenizer and prepare tokenize strategy
|
||||||
|
sd3_tokenizer = sd3_models.SD3Tokenizer(t5xxl_max_length=args.t5xxl_max_token_length)
|
||||||
|
sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length)
|
||||||
|
strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy)
|
||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
@@ -153,6 +165,16 @@ def train(args):
|
|||||||
train_dataset_group.verify_bucket_reso_steps(8) # TODO これでいいか確認
|
train_dataset_group.verify_bucket_reso_steps(8) # TODO これでいいか確認
|
||||||
|
|
||||||
if args.debug_dataset:
|
if args.debug_dataset:
|
||||||
|
if args.cache_text_encoder_outputs:
|
||||||
|
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
|
||||||
|
strategy_sd3.Sd3TextEncoderOutputsCachingStrategy(
|
||||||
|
args.cache_text_encoder_outputs_to_disk,
|
||||||
|
args.text_encoder_batch_size,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
train_dataset_group.set_current_strategies()
|
||||||
train_util.debug_dataset(train_dataset_group, True)
|
train_util.debug_dataset(train_dataset_group, True)
|
||||||
return
|
return
|
||||||
if len(train_dataset_group) == 0:
|
if len(train_dataset_group) == 0:
|
||||||
@@ -215,19 +237,8 @@ def train(args):
|
|||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
|
|
||||||
if not args.new_caching:
|
train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
|
||||||
vae_wrapper = sd3_models.VAEWrapper(vae) # make SD/SDXL compatible
|
|
||||||
with torch.no_grad():
|
|
||||||
train_dataset_group.cache_latents(
|
|
||||||
vae_wrapper,
|
|
||||||
args.vae_batch_size,
|
|
||||||
args.cache_latents_to_disk,
|
|
||||||
accelerator.is_main_process,
|
|
||||||
file_suffix="_sd3.npz",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
latents_caching_strategy.set_vae(vae)
|
|
||||||
train_dataset_group.new_cache_latents(accelerator.is_main_process, latents_caching_strategy)
|
|
||||||
vae.to("cpu") # if no sampling, vae can be deleted
|
vae.to("cpu") # if no sampling, vae can be deleted
|
||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
@@ -246,60 +257,70 @@ def train(args):
|
|||||||
t5xxl = sd3_train_utils.load_target_model("t5xxl", args, sd3_state_dict, accelerator, attn_mode, t5xxl_dtype, device_to_load)
|
t5xxl = sd3_train_utils.load_target_model("t5xxl", args, sd3_state_dict, accelerator, attn_mode, t5xxl_dtype, device_to_load)
|
||||||
# logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
|
# logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
# should be deleted after caching text encoder outputs when not training text encoder
|
||||||
|
# this strategy should not be used other than this process
|
||||||
|
text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
|
||||||
|
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
||||||
|
|
||||||
# 学習を準備する:モデルを適切な状態にする
|
# 学習を準備する:モデルを適切な状態にする
|
||||||
train_clip_l = False
|
train_clip_l = False
|
||||||
train_clip_g = False
|
train_clip_g = False
|
||||||
train_t5xxl = False
|
train_t5xxl = False
|
||||||
|
|
||||||
# if args.train_text_encoder:
|
if args.train_text_encoder:
|
||||||
# # TODO each option for two text encoders?
|
accelerator.print("enable text encoder training")
|
||||||
# accelerator.print("enable text encoder training")
|
if args.gradient_checkpointing:
|
||||||
# if args.gradient_checkpointing:
|
clip_l.gradient_checkpointing_enable()
|
||||||
# text_encoder1.gradient_checkpointing_enable()
|
clip_g.gradient_checkpointing_enable()
|
||||||
# text_encoder2.gradient_checkpointing_enable()
|
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
|
||||||
# lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
|
lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train
|
||||||
# lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train
|
train_clip_l = lr_te1 != 0
|
||||||
# train_clip_l = lr_te1 != 0
|
train_clip_g = lr_te2 != 0
|
||||||
# train_clip_g = lr_te2 != 0
|
|
||||||
|
if not train_clip_l:
|
||||||
|
clip_l.to(weight_dtype)
|
||||||
|
if not train_clip_g:
|
||||||
|
clip_g.to(weight_dtype)
|
||||||
|
clip_l.requires_grad_(train_clip_l)
|
||||||
|
clip_g.requires_grad_(train_clip_g)
|
||||||
|
clip_l.train(train_clip_l)
|
||||||
|
clip_g.train(train_clip_g)
|
||||||
|
else:
|
||||||
|
clip_l.to(weight_dtype)
|
||||||
|
clip_g.to(weight_dtype)
|
||||||
|
clip_l.requires_grad_(False)
|
||||||
|
clip_g.requires_grad_(False)
|
||||||
|
clip_l.eval()
|
||||||
|
clip_g.eval()
|
||||||
|
|
||||||
# # caching one text encoder output is not supported
|
|
||||||
# if not train_clip_l:
|
|
||||||
# text_encoder1.to(weight_dtype)
|
|
||||||
# if not train_clip_g:
|
|
||||||
# text_encoder2.to(weight_dtype)
|
|
||||||
# text_encoder1.requires_grad_(train_clip_l)
|
|
||||||
# text_encoder2.requires_grad_(train_clip_g)
|
|
||||||
# text_encoder1.train(train_clip_l)
|
|
||||||
# text_encoder2.train(train_clip_g)
|
|
||||||
# else:
|
|
||||||
clip_l.to(weight_dtype)
|
|
||||||
clip_g.to(weight_dtype)
|
|
||||||
clip_l.requires_grad_(False)
|
|
||||||
clip_g.requires_grad_(False)
|
|
||||||
clip_l.eval()
|
|
||||||
clip_g.eval()
|
|
||||||
if t5xxl is not None:
|
if t5xxl is not None:
|
||||||
t5xxl.to(t5xxl_dtype)
|
t5xxl.to(t5xxl_dtype)
|
||||||
t5xxl.requires_grad_(False)
|
t5xxl.requires_grad_(False)
|
||||||
t5xxl.eval()
|
t5xxl.eval()
|
||||||
|
|
||||||
# TextEncoderの出力をキャッシュする
|
# cache text encoder outputs
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
# Text Encodes are eval and no grad
|
# Text Encodes are eval and no grad here
|
||||||
|
clip_l.to(accelerator.device)
|
||||||
|
clip_g.to(accelerator.device)
|
||||||
|
if t5xxl is not None:
|
||||||
|
t5xxl.to(t5xxl_device)
|
||||||
|
|
||||||
with torch.no_grad(), accelerator.autocast():
|
text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy(
|
||||||
train_dataset_group.cache_text_encoder_outputs_sd3(
|
args.cache_text_encoder_outputs_to_disk,
|
||||||
sd3_tokenizer,
|
args.text_encoder_batch_size,
|
||||||
(clip_l, clip_g, t5xxl),
|
False,
|
||||||
(accelerator.device, accelerator.device, t5xxl_device),
|
train_clip_g or train_clip_l or args.use_t5xxl_cache_only,
|
||||||
None,
|
)
|
||||||
(None, None, None),
|
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
|
||||||
args.cache_text_encoder_outputs_to_disk,
|
|
||||||
accelerator.is_main_process,
|
|
||||||
args.text_encoder_batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO we can delete text encoders after caching
|
clip_l.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
clip_g.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
if t5xxl is not None:
|
||||||
|
t5xxl.to(t5xxl_device, dtype=t5xxl_dtype)
|
||||||
|
|
||||||
|
with accelerator.autocast():
|
||||||
|
train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process)
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
# load MMDIT
|
# load MMDIT
|
||||||
@@ -332,11 +353,11 @@ def train(args):
|
|||||||
# params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs))
|
# params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs))
|
||||||
|
|
||||||
# if train_clip_l:
|
# if train_clip_l:
|
||||||
# training_models.append(text_encoder1)
|
# training_models.append(clip_l)
|
||||||
# params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
|
# params_to_optimize.append({"params": list(clip_l.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
|
||||||
# if train_clip_g:
|
# if train_clip_g:
|
||||||
# training_models.append(text_encoder2)
|
# training_models.append(clip_g)
|
||||||
# params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate})
|
# params_to_optimize.append({"params": list(clip_g.parameters()), "lr": args.learning_rate_te2 or args.learning_rate})
|
||||||
|
|
||||||
# calculate number of trainable parameters
|
# calculate number of trainable parameters
|
||||||
n_params = 0
|
n_params = 0
|
||||||
@@ -344,7 +365,7 @@ def train(args):
|
|||||||
for p in group["params"]:
|
for p in group["params"]:
|
||||||
n_params += p.numel()
|
n_params += p.numel()
|
||||||
|
|
||||||
accelerator.print(f"train mmdit: {train_mmdit}") # , text_encoder1: {train_clip_l}, text_encoder2: {train_clip_g}")
|
accelerator.print(f"train mmdit: {train_mmdit}") # , clip_l: {train_clip_l}, clip_g: {train_clip_g}")
|
||||||
accelerator.print(f"number of models: {len(training_models)}")
|
accelerator.print(f"number of models: {len(training_models)}")
|
||||||
accelerator.print(f"number of trainable parameters: {n_params}")
|
accelerator.print(f"number of trainable parameters: {n_params}")
|
||||||
|
|
||||||
@@ -398,7 +419,11 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
||||||
|
|
||||||
# dataloaderを準備する
|
# prepare dataloader
|
||||||
|
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||||
|
# some strategies can be None
|
||||||
|
train_dataset_group.set_current_strategies()
|
||||||
|
|
||||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
@@ -455,8 +480,8 @@ def train(args):
|
|||||||
# TODO check if this is necessary. SD3 uses pool for clip_l and clip_g
|
# TODO check if this is necessary. SD3 uses pool for clip_l and clip_g
|
||||||
# # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
|
# # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
|
||||||
# if train_clip_l:
|
# if train_clip_l:
|
||||||
# text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
|
# clip_l.text_model.encoder.layers[-1].requires_grad_(False)
|
||||||
# text_encoder1.text_model.final_layer_norm.requires_grad_(False)
|
# clip_l.text_model.final_layer_norm.requires_grad_(False)
|
||||||
|
|
||||||
# TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する
|
# TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
@@ -484,9 +509,8 @@ def train(args):
|
|||||||
ds_model = deepspeed_utils.prepare_deepspeed_model(
|
ds_model = deepspeed_utils.prepare_deepspeed_model(
|
||||||
args,
|
args,
|
||||||
mmdit=mmdit,
|
mmdit=mmdit,
|
||||||
# mmdie=mmdit if train_mmdit else None,
|
clip_l=clip_l if train_clip_l else None,
|
||||||
# text_encoder1=text_encoder1 if train_clip_l else None,
|
clip_g=clip_g if train_clip_g else None,
|
||||||
# text_encoder2=text_encoder2 if train_clip_g else None,
|
|
||||||
)
|
)
|
||||||
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
|
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
|
||||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
@@ -498,10 +522,10 @@ def train(args):
|
|||||||
# acceleratorがなんかよろしくやってくれるらしい
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
if train_mmdit:
|
if train_mmdit:
|
||||||
mmdit = accelerator.prepare(mmdit)
|
mmdit = accelerator.prepare(mmdit)
|
||||||
# if train_clip_l:
|
if train_clip_l:
|
||||||
# text_encoder1 = accelerator.prepare(text_encoder1)
|
clip_l = accelerator.prepare(clip_l)
|
||||||
# if train_clip_g:
|
if train_clip_g:
|
||||||
# text_encoder2 = accelerator.prepare(text_encoder2)
|
clip_g = accelerator.prepare(clip_g)
|
||||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||||
@@ -613,7 +637,7 @@ def train(args):
|
|||||||
|
|
||||||
# # For --sample_at_first
|
# # For --sample_at_first
|
||||||
# sd3_train_utils.sample_images(
|
# sd3_train_utils.sample_images(
|
||||||
# accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], mmdit
|
# accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [clip_l, clip_g], mmdit
|
||||||
# )
|
# )
|
||||||
|
|
||||||
# following function will be moved to sd3_train_utils
|
# following function will be moved to sd3_train_utils
|
||||||
@@ -666,6 +690,7 @@ def train(args):
|
|||||||
return weighting
|
return weighting
|
||||||
|
|
||||||
loss_recorder = train_util.LossRecorder()
|
loss_recorder = train_util.LossRecorder()
|
||||||
|
epoch = 0 # avoid error when max_train_steps is 0
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||||
current_epoch.value = epoch + 1
|
current_epoch.value = epoch + 1
|
||||||
@@ -687,37 +712,45 @@ def train(args):
|
|||||||
# encode images to latents. images are [-1, 1]
|
# encode images to latents. images are [-1, 1]
|
||||||
latents = vae.encode(batch["images"].to(vae_dtype)).to(weight_dtype)
|
latents = vae.encode(batch["images"].to(vae_dtype)).to(weight_dtype)
|
||||||
|
|
||||||
# NaNが含まれていれば警告を表示し0に置き換える
|
# NaNが含まれていれば警告を表示し0に置き換える
|
||||||
if torch.any(torch.isnan(latents)):
|
if torch.any(torch.isnan(latents)):
|
||||||
accelerator.print("NaN found in latents, replacing with zeros")
|
accelerator.print("NaN found in latents, replacing with zeros")
|
||||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||||
|
|
||||||
# latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
# latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||||
latents = sd3_models.SDVAE.process_in(latents)
|
latents = sd3_models.SDVAE.process_in(latents)
|
||||||
|
|
||||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||||
# not cached, get text encoder outputs
|
if text_encoder_outputs_list is not None:
|
||||||
# XXX This does not work yet
|
lg_out, t5_out, lg_pooled = text_encoder_outputs_list
|
||||||
input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl = batch["input_ids"]
|
if args.use_t5xxl_cache_only:
|
||||||
|
lg_out = None
|
||||||
|
lg_pooled = None
|
||||||
|
else:
|
||||||
|
lg_out = None
|
||||||
|
t5_out = None
|
||||||
|
lg_pooled = None
|
||||||
|
|
||||||
|
if lg_out is None or (train_clip_l or train_clip_g):
|
||||||
|
# not cached or training, so get from text encoders
|
||||||
|
input_ids_clip_l, input_ids_clip_g, _ = batch["input_ids_list"]
|
||||||
with torch.set_grad_enabled(args.train_text_encoder):
|
with torch.set_grad_enabled(args.train_text_encoder):
|
||||||
# TODO support weighted captions
|
# TODO support weighted captions
|
||||||
# TODO support length > 75
|
|
||||||
input_ids_clip_l = input_ids_clip_l.to(accelerator.device)
|
input_ids_clip_l = input_ids_clip_l.to(accelerator.device)
|
||||||
input_ids_clip_g = input_ids_clip_g.to(accelerator.device)
|
input_ids_clip_g = input_ids_clip_g.to(accelerator.device)
|
||||||
input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device)
|
lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens(
|
||||||
|
sd3_tokenize_strategy, [clip_l, clip_g, None], [input_ids_clip_l, input_ids_clip_g, None]
|
||||||
# get text encoder outputs: outputs are concatenated
|
|
||||||
context, pool = sd3_utils.get_cond_from_tokens(
|
|
||||||
input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl, clip_l, clip_g, t5xxl
|
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
if t5_out is None:
|
||||||
# encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
|
_, _, input_ids_t5xxl = batch["input_ids_list"]
|
||||||
# pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
|
with torch.no_grad():
|
||||||
# TODO this reuses SDXL keys, it should be fixed
|
input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None
|
||||||
lg_out = batch["text_encoder_outputs1_list"]
|
_, t5_out, _ = text_encoding_strategy.encode_tokens(
|
||||||
t5_out = batch["text_encoder_outputs2_list"]
|
sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl]
|
||||||
pool = batch["text_encoder_pool2_list"]
|
)
|
||||||
context = torch.cat([lg_out, t5_out], dim=-2)
|
|
||||||
|
context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled)
|
||||||
|
|
||||||
# TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
|
# TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
|
||||||
|
|
||||||
@@ -748,13 +781,13 @@ def train(args):
|
|||||||
if torch.any(torch.isnan(context)):
|
if torch.any(torch.isnan(context)):
|
||||||
accelerator.print("NaN found in context, replacing with zeros")
|
accelerator.print("NaN found in context, replacing with zeros")
|
||||||
context = torch.nan_to_num(context, 0, out=context)
|
context = torch.nan_to_num(context, 0, out=context)
|
||||||
if torch.any(torch.isnan(pool)):
|
if torch.any(torch.isnan(lg_pooled)):
|
||||||
accelerator.print("NaN found in pool, replacing with zeros")
|
accelerator.print("NaN found in pool, replacing with zeros")
|
||||||
pool = torch.nan_to_num(pool, 0, out=pool)
|
lg_pooled = torch.nan_to_num(lg_pooled, 0, out=lg_pooled)
|
||||||
|
|
||||||
# call model
|
# call model
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
model_pred = mmdit(noisy_model_input, timesteps, context=context, y=pool)
|
model_pred = mmdit(noisy_model_input, timesteps, context=context, y=lg_pooled)
|
||||||
|
|
||||||
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
|
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
|
||||||
# Preconditioning of the model outputs.
|
# Preconditioning of the model outputs.
|
||||||
@@ -806,7 +839,7 @@ def train(args):
|
|||||||
# accelerator.device,
|
# accelerator.device,
|
||||||
# vae,
|
# vae,
|
||||||
# [tokenizer1, tokenizer2],
|
# [tokenizer1, tokenizer2],
|
||||||
# [text_encoder1, text_encoder2],
|
# [clip_l, clip_g],
|
||||||
# mmdit,
|
# mmdit,
|
||||||
# )
|
# )
|
||||||
|
|
||||||
@@ -875,7 +908,7 @@ def train(args):
|
|||||||
# accelerator.device,
|
# accelerator.device,
|
||||||
# vae,
|
# vae,
|
||||||
# [tokenizer1, tokenizer2],
|
# [tokenizer1, tokenizer2],
|
||||||
# [text_encoder1, text_encoder2],
|
# [clip_l, clip_g],
|
||||||
# mmdit,
|
# mmdit,
|
||||||
# )
|
# )
|
||||||
|
|
||||||
@@ -924,7 +957,19 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
custom_train_functions.add_custom_train_arguments(parser)
|
custom_train_functions.add_custom_train_arguments(parser)
|
||||||
sd3_train_utils.add_sd3_training_arguments(parser)
|
sd3_train_utils.add_sd3_training_arguments(parser)
|
||||||
|
|
||||||
# parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
parser.add_argument(
|
||||||
|
"--train_text_encoder", action="store_true", help="train text encoder (CLIP-L and G) / text encoderも学習する"
|
||||||
|
)
|
||||||
|
# parser.add_argument("--train_t5xxl", action="store_true", help="train T5-XXL / T5-XXLも学習する")
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_t5xxl_cache_only", action="store_true", help="cache T5-XXL outputs only / T5-XXLの出力のみキャッシュする"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--t5xxl_max_token_length",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="maximum token length for T5-XXL. 256 if omitted / T5-XXLの最大トークン数。省略時は256",
|
||||||
|
)
|
||||||
|
|
||||||
# TE training is disabled temporarily
|
# TE training is disabled temporarily
|
||||||
# parser.add_argument(
|
# parser.add_argument(
|
||||||
@@ -962,7 +1007,6 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
|
help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--new_caching", action="store_true", help="use new caching method / 新しいキャッシング方法を使う")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--skip_latents_validity_check",
|
"--skip_latents_validity_check",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
108
sdxl_train.py
108
sdxl_train.py
@@ -17,7 +17,7 @@ init_ipex()
|
|||||||
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
from library import deepspeed_utils, sdxl_model_util
|
from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
|
|
||||||
@@ -124,7 +124,16 @@ def train(args):
|
|||||||
if args.seed is not None:
|
if args.seed is not None:
|
||||||
set_seed(args.seed) # 乱数系列を初期化する
|
set_seed(args.seed) # 乱数系列を初期化する
|
||||||
|
|
||||||
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
|
||||||
|
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
||||||
|
tokenizers = [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] # will be removed in the future
|
||||||
|
|
||||||
|
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
||||||
|
if args.cache_latents:
|
||||||
|
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
||||||
|
False, args.cache_latents_to_disk, args.vae_batch_size, False
|
||||||
|
)
|
||||||
|
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
@@ -166,10 +175,10 @@ def train(args):
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
else:
|
else:
|
||||||
train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2])
|
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
@@ -262,8 +271,9 @@ def train(args):
|
|||||||
vae.to(accelerator.device, dtype=vae_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
|
||||||
|
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
@@ -276,6 +286,9 @@ def train(args):
|
|||||||
train_text_encoder1 = False
|
train_text_encoder1 = False
|
||||||
train_text_encoder2 = False
|
train_text_encoder2 = False
|
||||||
|
|
||||||
|
text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
|
||||||
|
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
||||||
|
|
||||||
if args.train_text_encoder:
|
if args.train_text_encoder:
|
||||||
# TODO each option for two text encoders?
|
# TODO each option for two text encoders?
|
||||||
accelerator.print("enable text encoder training")
|
accelerator.print("enable text encoder training")
|
||||||
@@ -307,16 +320,17 @@ def train(args):
|
|||||||
# TextEncoderの出力をキャッシュする
|
# TextEncoderの出力をキャッシュする
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
# Text Encodes are eval and no grad
|
# Text Encodes are eval and no grad
|
||||||
with torch.no_grad(), accelerator.autocast():
|
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
|
||||||
train_dataset_group.cache_text_encoder_outputs(
|
args.cache_text_encoder_outputs_to_disk, None, False
|
||||||
(tokenizer1, tokenizer2),
|
)
|
||||||
(text_encoder1, text_encoder2),
|
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
|
||||||
accelerator.device,
|
|
||||||
None,
|
text_encoder1.to(accelerator.device)
|
||||||
args.cache_text_encoder_outputs_to_disk,
|
text_encoder2.to(accelerator.device)
|
||||||
accelerator.is_main_process,
|
with accelerator.autocast():
|
||||||
)
|
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process)
|
||||||
accelerator.wait_for_everyone()
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
if not cache_latents:
|
if not cache_latents:
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
@@ -403,7 +417,11 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
||||||
|
|
||||||
# dataloaderを準備する
|
# prepare dataloader
|
||||||
|
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||||
|
# some strategies can be None
|
||||||
|
train_dataset_group.set_current_strategies()
|
||||||
|
|
||||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
@@ -597,7 +615,7 @@ def train(args):
|
|||||||
|
|
||||||
# For --sample_at_first
|
# For --sample_at_first
|
||||||
sdxl_train_util.sample_images(
|
sdxl_train_util.sample_images(
|
||||||
accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
|
accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, [text_encoder1, text_encoder2], unet
|
||||||
)
|
)
|
||||||
|
|
||||||
loss_recorder = train_util.LossRecorder()
|
loss_recorder = train_util.LossRecorder()
|
||||||
@@ -628,9 +646,15 @@ def train(args):
|
|||||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||||
|
|
||||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||||
input_ids1 = batch["input_ids"]
|
if text_encoder_outputs_list is not None:
|
||||||
input_ids2 = batch["input_ids2"]
|
# Text Encoder outputs are cached
|
||||||
|
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list
|
||||||
|
encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
pool2 = pool2.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
else:
|
||||||
|
input_ids1, input_ids2 = batch["input_ids_list"]
|
||||||
with torch.set_grad_enabled(args.train_text_encoder):
|
with torch.set_grad_enabled(args.train_text_encoder):
|
||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
# TODO support weighted captions
|
# TODO support weighted captions
|
||||||
@@ -646,39 +670,13 @@ def train(args):
|
|||||||
# else:
|
# else:
|
||||||
input_ids1 = input_ids1.to(accelerator.device)
|
input_ids1 = input_ids1.to(accelerator.device)
|
||||||
input_ids2 = input_ids2.to(accelerator.device)
|
input_ids2 = input_ids2.to(accelerator.device)
|
||||||
# unwrap_model is fine for models not wrapped by accelerator
|
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
|
||||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2]
|
||||||
args.max_token_length,
|
|
||||||
input_ids1,
|
|
||||||
input_ids2,
|
|
||||||
tokenizer1,
|
|
||||||
tokenizer2,
|
|
||||||
text_encoder1,
|
|
||||||
text_encoder2,
|
|
||||||
None if not args.full_fp16 else weight_dtype,
|
|
||||||
accelerator=accelerator,
|
|
||||||
)
|
)
|
||||||
else:
|
if args.full_fp16:
|
||||||
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
|
||||||
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
|
encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype)
|
||||||
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
|
pool2 = pool2.to(weight_dtype)
|
||||||
|
|
||||||
# # verify that the text encoder outputs are correct
|
|
||||||
# ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
|
|
||||||
# args.max_token_length,
|
|
||||||
# batch["input_ids"].to(text_encoder1.device),
|
|
||||||
# batch["input_ids2"].to(text_encoder1.device),
|
|
||||||
# tokenizer1,
|
|
||||||
# tokenizer2,
|
|
||||||
# text_encoder1,
|
|
||||||
# text_encoder2,
|
|
||||||
# None if not args.full_fp16 else weight_dtype,
|
|
||||||
# )
|
|
||||||
# b_size = encoder_hidden_states1.shape[0]
|
|
||||||
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
|
||||||
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
|
||||||
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
|
||||||
# logger.info("text encoder outputs verified")
|
|
||||||
|
|
||||||
# get size embeddings
|
# get size embeddings
|
||||||
orig_size = batch["original_sizes_hw"]
|
orig_size = batch["original_sizes_hw"]
|
||||||
@@ -765,7 +763,7 @@ def train(args):
|
|||||||
global_step,
|
global_step,
|
||||||
accelerator.device,
|
accelerator.device,
|
||||||
vae,
|
vae,
|
||||||
[tokenizer1, tokenizer2],
|
tokenizers,
|
||||||
[text_encoder1, text_encoder2],
|
[text_encoder1, text_encoder2],
|
||||||
unet,
|
unet,
|
||||||
)
|
)
|
||||||
@@ -847,7 +845,7 @@ def train(args):
|
|||||||
global_step,
|
global_step,
|
||||||
accelerator.device,
|
accelerator.device,
|
||||||
vae,
|
vae,
|
||||||
[tokenizer1, tokenizer2],
|
tokenizers,
|
||||||
[text_encoder1, text_encoder2],
|
[text_encoder1, text_encoder2],
|
||||||
unet,
|
unet,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -23,7 +23,16 @@ from accelerate.utils import set_seed
|
|||||||
import accelerate
|
import accelerate
|
||||||
from diffusers import DDPMScheduler, ControlNetModel
|
from diffusers import DDPMScheduler, ControlNetModel
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
|
from library import (
|
||||||
|
deepspeed_utils,
|
||||||
|
sai_model_spec,
|
||||||
|
sdxl_model_util,
|
||||||
|
sdxl_original_unet,
|
||||||
|
sdxl_train_util,
|
||||||
|
strategy_base,
|
||||||
|
strategy_sd,
|
||||||
|
strategy_sdxl,
|
||||||
|
)
|
||||||
|
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
@@ -79,7 +88,14 @@ def train(args):
|
|||||||
args.seed = random.randint(0, 2**32)
|
args.seed = random.randint(0, 2**32)
|
||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
|
|
||||||
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
|
||||||
|
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
||||||
|
|
||||||
|
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
||||||
|
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
||||||
|
False, args.cache_latents_to_disk, args.vae_batch_size, False
|
||||||
|
)
|
||||||
|
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
||||||
@@ -106,7 +122,7 @@ def train(args):
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
@@ -164,30 +180,30 @@ def train(args):
|
|||||||
vae.to(accelerator.device, dtype=vae_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
|
||||||
train_dataset_group.cache_latents(
|
train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
|
||||||
vae,
|
|
||||||
args.vae_batch_size,
|
|
||||||
args.cache_latents_to_disk,
|
|
||||||
accelerator.is_main_process,
|
|
||||||
)
|
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
|
||||||
|
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
||||||
|
|
||||||
# TextEncoderの出力をキャッシュする
|
# TextEncoderの出力をキャッシュする
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
# Text Encodes are eval and no grad
|
# Text Encodes are eval and no grad
|
||||||
with torch.no_grad():
|
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
|
||||||
train_dataset_group.cache_text_encoder_outputs(
|
args.cache_text_encoder_outputs_to_disk, None, False
|
||||||
(tokenizer1, tokenizer2),
|
)
|
||||||
(text_encoder1, text_encoder2),
|
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
|
||||||
accelerator.device,
|
|
||||||
None,
|
text_encoder1.to(accelerator.device)
|
||||||
args.cache_text_encoder_outputs_to_disk,
|
text_encoder2.to(accelerator.device)
|
||||||
accelerator.is_main_process,
|
with accelerator.autocast():
|
||||||
)
|
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
# prepare ControlNet-LLLite
|
# prepare ControlNet-LLLite
|
||||||
@@ -242,7 +258,11 @@ def train(args):
|
|||||||
|
|
||||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||||
|
|
||||||
# dataloaderを準備する
|
# prepare dataloader
|
||||||
|
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||||
|
# some strategies can be None
|
||||||
|
train_dataset_group.set_current_strategies()
|
||||||
|
|
||||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||||
|
|
||||||
@@ -290,7 +310,7 @@ def train(args):
|
|||||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
|
||||||
if isinstance(unet, DDP):
|
if isinstance(unet, DDP):
|
||||||
unet._set_static_graph() # avoid error for multiple use of the parameter
|
unet._set_static_graph() # avoid error for multiple use of the parameter
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
||||||
@@ -357,7 +377,9 @@ def train(args):
|
|||||||
if args.log_tracker_config is not None:
|
if args.log_tracker_config is not None:
|
||||||
init_kwargs = toml.load(args.log_tracker_config)
|
init_kwargs = toml.load(args.log_tracker_config)
|
||||||
accelerator.init_trackers(
|
accelerator.init_trackers(
|
||||||
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
|
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name,
|
||||||
|
config=train_util.get_sanitized_config_or_none(args),
|
||||||
|
init_kwargs=init_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
loss_recorder = train_util.LossRecorder()
|
loss_recorder = train_util.LossRecorder()
|
||||||
@@ -409,27 +431,26 @@ def train(args):
|
|||||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||||
|
|
||||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
|
||||||
input_ids1 = batch["input_ids"]
|
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||||
input_ids2 = batch["input_ids2"]
|
if text_encoder_outputs_list is not None:
|
||||||
|
# Text Encoder outputs are cached
|
||||||
|
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list
|
||||||
|
encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
pool2 = pool2.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
else:
|
||||||
|
input_ids1, input_ids2 = batch["input_ids_list"]
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# Get the text embedding for conditioning
|
|
||||||
input_ids1 = input_ids1.to(accelerator.device)
|
input_ids1 = input_ids1.to(accelerator.device)
|
||||||
input_ids2 = input_ids2.to(accelerator.device)
|
input_ids2 = input_ids2.to(accelerator.device)
|
||||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
|
||||||
args.max_token_length,
|
tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2]
|
||||||
input_ids1,
|
|
||||||
input_ids2,
|
|
||||||
tokenizer1,
|
|
||||||
tokenizer2,
|
|
||||||
text_encoder1,
|
|
||||||
text_encoder2,
|
|
||||||
None if not args.full_fp16 else weight_dtype,
|
|
||||||
)
|
)
|
||||||
else:
|
if args.full_fp16:
|
||||||
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
|
||||||
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
|
encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype)
|
||||||
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
|
pool2 = pool2.to(weight_dtype)
|
||||||
|
|
||||||
# get size embeddings
|
# get size embeddings
|
||||||
orig_size = batch["original_sizes_hw"]
|
orig_size = batch["original_sizes_hw"]
|
||||||
|
|||||||
@@ -1,16 +1,21 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from accelerate import Accelerator
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
from library import sdxl_model_util, sdxl_train_util, strategy_base, strategy_sd, strategy_sdxl, train_util
|
||||||
import train_network
|
import train_network
|
||||||
from library.utils import setup_logging
|
from library.utils import setup_logging
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -49,15 +54,32 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
|
|
||||||
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
|
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
|
||||||
|
|
||||||
def load_tokenizer(self, args):
|
def get_tokenize_strategy(self, args):
|
||||||
tokenizer = sdxl_train_util.load_tokenizers(args)
|
return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
def is_text_encoder_outputs_cached(self, args):
|
def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy):
|
||||||
return args.cache_text_encoder_outputs
|
return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2]
|
||||||
|
|
||||||
|
def get_latents_caching_strategy(self, args):
|
||||||
|
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
||||||
|
False, args.cache_latents_to_disk, args.vae_batch_size, False
|
||||||
|
)
|
||||||
|
return latents_caching_strategy
|
||||||
|
|
||||||
|
def get_text_encoding_strategy(self, args):
|
||||||
|
return strategy_sdxl.SdxlTextEncodingStrategy()
|
||||||
|
|
||||||
|
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
||||||
|
return text_encoders + [accelerator.unwrap_model(text_encoders[-1])]
|
||||||
|
|
||||||
|
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||||
|
if args.cache_text_encoder_outputs:
|
||||||
|
return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
def cache_text_encoder_outputs_if_needed(
|
def cache_text_encoder_outputs_if_needed(
|
||||||
self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
||||||
):
|
):
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
if not args.lowram:
|
if not args.lowram:
|
||||||
@@ -70,15 +92,13 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
|
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
|
||||||
|
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||||
|
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
dataset.cache_text_encoder_outputs(
|
dataset.new_cache_text_encoder_outputs(
|
||||||
tokenizers,
|
text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator.is_main_process
|
||||||
text_encoders,
|
|
||||||
accelerator.device,
|
|
||||||
weight_dtype,
|
|
||||||
args.cache_text_encoder_outputs_to_disk,
|
|
||||||
accelerator.is_main_process,
|
|
||||||
)
|
)
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
||||||
text_encoders[1].to("cpu", dtype=torch.float32)
|
text_encoders[1].to("cpu", dtype=torch.float32)
|
||||||
|
|||||||
@@ -5,10 +5,10 @@ import regex
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex
|
from library.device_utils import init_ipex
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
from library import sdxl_model_util, sdxl_train_util, strategy_sd, strategy_sdxl, train_util
|
||||||
|
|
||||||
import train_textual_inversion
|
import train_textual_inversion
|
||||||
|
|
||||||
|
|
||||||
@@ -41,28 +41,20 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
|
|||||||
|
|
||||||
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
|
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
|
||||||
|
|
||||||
def load_tokenizer(self, args):
|
def get_tokenize_strategy(self, args):
|
||||||
tokenizer = sdxl_train_util.load_tokenizers(args)
|
return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy):
|
||||||
input_ids1 = batch["input_ids"]
|
return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2]
|
||||||
input_ids2 = batch["input_ids2"]
|
|
||||||
with torch.enable_grad():
|
def get_latents_caching_strategy(self, args):
|
||||||
input_ids1 = input_ids1.to(accelerator.device)
|
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
||||||
input_ids2 = input_ids2.to(accelerator.device)
|
False, args.cache_latents_to_disk, args.vae_batch_size, False
|
||||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
)
|
||||||
args.max_token_length,
|
return latents_caching_strategy
|
||||||
input_ids1,
|
|
||||||
input_ids2,
|
def get_text_encoding_strategy(self, args):
|
||||||
tokenizers[0],
|
return strategy_sdxl.SdxlTextEncodingStrategy()
|
||||||
tokenizers[1],
|
|
||||||
text_encoders[0],
|
|
||||||
text_encoders[1],
|
|
||||||
None if not args.full_fp16 else weight_dtype,
|
|
||||||
accelerator=accelerator,
|
|
||||||
)
|
|
||||||
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
|
||||||
|
|
||||||
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||||
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||||
@@ -81,9 +73,11 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
|
|||||||
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||||
return noise_pred
|
return noise_pred
|
||||||
|
|
||||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
|
def sample_images(
|
||||||
|
self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
|
||||||
|
):
|
||||||
sdxl_train_util.sample_images(
|
sdxl_train_util.sample_images(
|
||||||
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
|
||||||
)
|
)
|
||||||
|
|
||||||
def save_weights(self, file, updated_embs, save_dtype, metadata):
|
def save_weights(self, file, updated_embs, save_dtype, metadata):
|
||||||
@@ -122,8 +116,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
|
|||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
parser = train_textual_inversion.setup_parser()
|
parser = train_textual_inversion.setup_parser()
|
||||||
# don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching
|
sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False)
|
||||||
# sdxl_train_util.add_sdxl_training_arguments(parser)
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
67
train_db.py
67
train_db.py
@@ -11,7 +11,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library import deepspeed_utils
|
from library import deepspeed_utils, strategy_base
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
|
|
||||||
@@ -38,6 +38,7 @@ from library.custom_train_functions import (
|
|||||||
apply_masked_loss,
|
apply_masked_loss,
|
||||||
)
|
)
|
||||||
from library.utils import setup_logging, add_logging_arguments
|
from library.utils import setup_logging, add_logging_arguments
|
||||||
|
import library.strategy_sd as strategy_sd
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
@@ -58,7 +59,14 @@ def train(args):
|
|||||||
if args.seed is not None:
|
if args.seed is not None:
|
||||||
set_seed(args.seed) # 乱数系列を初期化する
|
set_seed(args.seed) # 乱数系列を初期化する
|
||||||
|
|
||||||
tokenizer = train_util.load_tokenizer(args)
|
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
|
||||||
|
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
||||||
|
|
||||||
|
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
||||||
|
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
||||||
|
False, args.cache_latents_to_disk, args.vae_batch_size, False
|
||||||
|
)
|
||||||
|
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
@@ -80,10 +88,10 @@ def train(args):
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
else:
|
else:
|
||||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
@@ -145,13 +153,17 @@ def train(args):
|
|||||||
vae.to(accelerator.device, dtype=vae_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
|
||||||
|
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
|
||||||
|
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
||||||
|
|
||||||
# 学習を準備する:モデルを適切な状態にする
|
# 学習を準備する:モデルを適切な状態にする
|
||||||
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
|
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
|
||||||
unet.requires_grad_(True) # 念のため追加
|
unet.requires_grad_(True) # 念のため追加
|
||||||
@@ -184,8 +196,11 @@ def train(args):
|
|||||||
|
|
||||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||||
|
|
||||||
# dataloaderを準備する
|
# prepare dataloader
|
||||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||||
|
# some strategies can be None
|
||||||
|
train_dataset_group.set_current_strategies()
|
||||||
|
|
||||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
train_dataset_group,
|
train_dataset_group,
|
||||||
@@ -290,10 +305,16 @@ def train(args):
|
|||||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||||
if args.log_tracker_config is not None:
|
if args.log_tracker_config is not None:
|
||||||
init_kwargs = toml.load(args.log_tracker_config)
|
init_kwargs = toml.load(args.log_tracker_config)
|
||||||
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
|
accelerator.init_trackers(
|
||||||
|
"dreambooth" if args.log_tracker_name is None else args.log_tracker_name,
|
||||||
|
config=train_util.get_sanitized_config_or_none(args),
|
||||||
|
init_kwargs=init_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# For --sample_at_first
|
# For --sample_at_first
|
||||||
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
train_util.sample_images(
|
||||||
|
accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
|
||||||
|
)
|
||||||
|
|
||||||
loss_recorder = train_util.LossRecorder()
|
loss_recorder = train_util.LossRecorder()
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
@@ -331,7 +352,7 @@ def train(args):
|
|||||||
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
||||||
if args.weighted_captions:
|
if args.weighted_captions:
|
||||||
encoder_hidden_states = get_weighted_text_embeddings(
|
encoder_hidden_states = get_weighted_text_embeddings(
|
||||||
tokenizer,
|
tokenize_strategy.tokenizer,
|
||||||
text_encoder,
|
text_encoder,
|
||||||
batch["captions"],
|
batch["captions"],
|
||||||
accelerator.device,
|
accelerator.device,
|
||||||
@@ -339,14 +360,18 @@ def train(args):
|
|||||||
clip_skip=args.clip_skip,
|
clip_skip=args.clip_skip,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
input_ids = batch["input_ids"].to(accelerator.device)
|
input_ids = batch["input_ids_list"][0].to(accelerator.device)
|
||||||
encoder_hidden_states = train_util.get_hidden_states(
|
encoder_hidden_states = text_encoding_strategy.encode_tokens(
|
||||||
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
|
tokenize_strategy, [text_encoder], [input_ids]
|
||||||
)
|
)[0]
|
||||||
|
if args.full_fp16:
|
||||||
|
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
|
||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
||||||
|
args, noise_scheduler, latents
|
||||||
|
)
|
||||||
|
|
||||||
# Predict the noise residual
|
# Predict the noise residual
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
@@ -358,7 +383,9 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
loss = train_util.conditional_loss(
|
||||||
|
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||||
|
)
|
||||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||||
loss = apply_masked_loss(loss, batch)
|
loss = apply_masked_loss(loss, batch)
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
@@ -393,7 +420,7 @@ def train(args):
|
|||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
train_util.sample_images(
|
train_util.sample_images(
|
||||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
|
||||||
)
|
)
|
||||||
|
|
||||||
# 指定ステップごとにモデルを保存
|
# 指定ステップごとにモデルを保存
|
||||||
@@ -457,7 +484,9 @@ def train(args):
|
|||||||
vae,
|
vae,
|
||||||
)
|
)
|
||||||
|
|
||||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
train_util.sample_images(
|
||||||
|
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
|
||||||
|
)
|
||||||
|
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
|
|||||||
122
train_network.py
122
train_network.py
@@ -7,6 +7,7 @@ import random
|
|||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
|
from typing import Any, List
|
||||||
import toml
|
import toml
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -18,7 +19,7 @@ init_ipex()
|
|||||||
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
from library import deepspeed_utils, model_util
|
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
from library.train_util import DreamBoothDataset
|
from library.train_util import DreamBoothDataset
|
||||||
@@ -101,19 +102,31 @@ class NetworkTrainer:
|
|||||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||||
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
|
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
|
||||||
|
|
||||||
def load_tokenizer(self, args):
|
def get_tokenize_strategy(self, args):
|
||||||
tokenizer = train_util.load_tokenizer(args)
|
return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
def is_text_encoder_outputs_cached(self, args):
|
def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]:
|
||||||
return False
|
return [tokenize_strategy.tokenizer]
|
||||||
|
|
||||||
|
def get_latents_caching_strategy(self, args):
|
||||||
|
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
||||||
|
True, args.cache_latents_to_disk, args.vae_batch_size, False
|
||||||
|
)
|
||||||
|
return latents_caching_strategy
|
||||||
|
|
||||||
|
def get_text_encoding_strategy(self, args):
|
||||||
|
return strategy_sd.SdTextEncodingStrategy(args.clip_skip)
|
||||||
|
|
||||||
|
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
||||||
|
return text_encoders
|
||||||
|
|
||||||
def is_train_text_encoder(self, args):
|
def is_train_text_encoder(self, args):
|
||||||
return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
|
return not args.network_train_unet_only
|
||||||
|
|
||||||
def cache_text_encoder_outputs_if_needed(
|
def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype):
|
||||||
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
|
|
||||||
):
|
|
||||||
for t_enc in text_encoders:
|
for t_enc in text_encoders:
|
||||||
t_enc.to(accelerator.device, dtype=weight_dtype)
|
t_enc.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
@@ -123,7 +136,7 @@ class NetworkTrainer:
|
|||||||
return encoder_hidden_states
|
return encoder_hidden_states
|
||||||
|
|
||||||
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||||
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
|
noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
|
||||||
return noise_pred
|
return noise_pred
|
||||||
|
|
||||||
def all_reduce_network(self, accelerator, network):
|
def all_reduce_network(self, accelerator, network):
|
||||||
@@ -131,8 +144,8 @@ class NetworkTrainer:
|
|||||||
if param.grad is not None:
|
if param.grad is not None:
|
||||||
param.grad = accelerator.reduce(param.grad, reduction="mean")
|
param.grad = accelerator.reduce(param.grad, reduction="mean")
|
||||||
|
|
||||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet):
|
||||||
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet)
|
||||||
|
|
||||||
def train(self, args):
|
def train(self, args):
|
||||||
session_id = random.randint(0, 2**32)
|
session_id = random.randint(0, 2**32)
|
||||||
@@ -150,9 +163,13 @@ class NetworkTrainer:
|
|||||||
args.seed = random.randint(0, 2**32)
|
args.seed = random.randint(0, 2**32)
|
||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
|
|
||||||
# tokenizerは単体またはリスト、tokenizersは必ずリスト:既存のコードとの互換性のため
|
tokenize_strategy = self.get_tokenize_strategy(args)
|
||||||
tokenizer = self.load_tokenizer(args)
|
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
||||||
tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
|
tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored
|
||||||
|
|
||||||
|
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
||||||
|
latents_caching_strategy = self.get_latents_caching_strategy(args)
|
||||||
|
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
@@ -194,11 +211,11 @@ class NetworkTrainer:
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
else:
|
else:
|
||||||
# use arbitrary dataset class
|
# use arbitrary dataset class
|
||||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
@@ -268,8 +285,9 @@ class NetworkTrainer:
|
|||||||
vae.to(accelerator.device, dtype=vae_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
|
||||||
|
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
@@ -277,9 +295,13 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
||||||
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
|
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
|
||||||
self.cache_text_encoder_outputs_if_needed(
|
text_encoding_strategy = self.get_text_encoding_strategy(args)
|
||||||
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
|
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
||||||
)
|
|
||||||
|
text_encoder_outputs_caching_strategy = self.get_text_encoder_outputs_caching_strategy(args)
|
||||||
|
if text_encoder_outputs_caching_strategy is not None:
|
||||||
|
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy)
|
||||||
|
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype)
|
||||||
|
|
||||||
# prepare network
|
# prepare network
|
||||||
net_kwargs = {}
|
net_kwargs = {}
|
||||||
@@ -366,7 +388,11 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||||
|
|
||||||
# dataloaderを準備する
|
# prepare dataloader
|
||||||
|
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||||
|
# some strategies can be None
|
||||||
|
train_dataset_group.set_current_strategies()
|
||||||
|
|
||||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||||
|
|
||||||
@@ -878,7 +904,7 @@ class NetworkTrainer:
|
|||||||
os.remove(old_ckpt_file)
|
os.remove(old_ckpt_file)
|
||||||
|
|
||||||
# For --sample_at_first
|
# For --sample_at_first
|
||||||
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
||||||
|
|
||||||
# training loop
|
# training loop
|
||||||
if initial_step > 0: # only if skip_until_initial_step is specified
|
if initial_step > 0: # only if skip_until_initial_step is specified
|
||||||
@@ -933,21 +959,31 @@ class NetworkTrainer:
|
|||||||
# print(f"set multiplier: {multipliers}")
|
# print(f"set multiplier: {multipliers}")
|
||||||
accelerator.unwrap_model(network).set_multiplier(multipliers)
|
accelerator.unwrap_model(network).set_multiplier(multipliers)
|
||||||
|
|
||||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||||
# Get the text embedding for conditioning
|
if text_encoder_outputs_list is not None:
|
||||||
if args.weighted_captions:
|
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
||||||
text_encoder_conds = get_weighted_text_embeddings(
|
else:
|
||||||
tokenizer,
|
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||||
text_encoder,
|
# Get the text embedding for conditioning
|
||||||
batch["captions"],
|
if args.weighted_captions:
|
||||||
accelerator.device,
|
# SD only
|
||||||
args.max_token_length // 75 if args.max_token_length else 1,
|
text_encoder_conds = get_weighted_text_embeddings(
|
||||||
clip_skip=args.clip_skip,
|
tokenizers[0],
|
||||||
)
|
text_encoder,
|
||||||
else:
|
batch["captions"],
|
||||||
text_encoder_conds = self.get_text_cond(
|
accelerator.device,
|
||||||
args, accelerator, batch, tokenizers, text_encoders, weight_dtype
|
args.max_token_length // 75 if args.max_token_length else 1,
|
||||||
)
|
clip_skip=args.clip_skip,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
||||||
|
text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||||
|
tokenize_strategy,
|
||||||
|
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||||
|
input_ids,
|
||||||
|
)
|
||||||
|
if args.full_fp16:
|
||||||
|
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
|
||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
@@ -1026,7 +1062,9 @@ class NetworkTrainer:
|
|||||||
progress_bar.update(1)
|
progress_bar.update(1)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
self.sample_images(
|
||||||
|
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
|
||||||
|
)
|
||||||
|
|
||||||
# 指定ステップごとにモデルを保存
|
# 指定ステップごとにモデルを保存
|
||||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||||
@@ -1082,7 +1120,7 @@ class NetworkTrainer:
|
|||||||
if args.save_state:
|
if args.save_state:
|
||||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||||
|
|
||||||
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
||||||
|
|
||||||
# end of epoch
|
# end of epoch
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import argparse
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
|
from typing import Any, List
|
||||||
import toml
|
import toml
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -15,7 +16,7 @@ init_ipex()
|
|||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
from library import deepspeed_utils, model_util
|
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
import library.huggingface_util as huggingface_util
|
import library.huggingface_util as huggingface_util
|
||||||
@@ -103,28 +104,38 @@ class TextualInversionTrainer:
|
|||||||
|
|
||||||
def load_target_model(self, args, weight_dtype, accelerator):
|
def load_target_model(self, args, weight_dtype, accelerator):
|
||||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||||
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
|
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet
|
||||||
|
|
||||||
def load_tokenizer(self, args):
|
def get_tokenize_strategy(self, args):
|
||||||
tokenizer = train_util.load_tokenizer(args)
|
return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
|
||||||
return tokenizer
|
|
||||||
|
def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]:
|
||||||
|
return [tokenize_strategy.tokenizer]
|
||||||
|
|
||||||
|
def get_latents_caching_strategy(self, args):
|
||||||
|
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
||||||
|
True, args.cache_latents_to_disk, args.vae_batch_size, False
|
||||||
|
)
|
||||||
|
return latents_caching_strategy
|
||||||
|
|
||||||
def assert_token_string(self, token_string, tokenizers: CLIPTokenizer):
|
def assert_token_string(self, token_string, tokenizers: CLIPTokenizer):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
def get_text_encoding_strategy(self, args):
|
||||||
with torch.enable_grad():
|
return strategy_sd.SdTextEncodingStrategy(args.clip_skip)
|
||||||
input_ids = batch["input_ids"].to(accelerator.device)
|
|
||||||
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None)
|
def get_models_for_text_encoding(self, args, accelerator, text_encoders) -> List[Any]:
|
||||||
return encoder_hidden_states
|
return text_encoders
|
||||||
|
|
||||||
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||||
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
|
noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
|
||||||
return noise_pred
|
return noise_pred
|
||||||
|
|
||||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
|
def sample_images(
|
||||||
|
self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
|
||||||
|
):
|
||||||
train_util.sample_images(
|
train_util.sample_images(
|
||||||
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoders[0], unet, prompt_replacement
|
||||||
)
|
)
|
||||||
|
|
||||||
def save_weights(self, file, updated_embs, save_dtype, metadata):
|
def save_weights(self, file, updated_embs, save_dtype, metadata):
|
||||||
@@ -182,8 +193,13 @@ class TextualInversionTrainer:
|
|||||||
if args.seed is not None:
|
if args.seed is not None:
|
||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
|
|
||||||
tokenizer_or_list = self.load_tokenizer(args) # list of tokenizer or tokenizer
|
tokenize_strategy = self.get_tokenize_strategy(args)
|
||||||
tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
|
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
||||||
|
tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored
|
||||||
|
|
||||||
|
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
||||||
|
latents_caching_strategy = self.get_latents_caching_strategy(args)
|
||||||
|
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
logger.info("prepare accelerator")
|
logger.info("prepare accelerator")
|
||||||
@@ -194,14 +210,7 @@ class TextualInversionTrainer:
|
|||||||
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
|
model_version, text_encoders, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
|
||||||
text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list
|
|
||||||
|
|
||||||
if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1:
|
|
||||||
accelerator.print(
|
|
||||||
"accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / "
|
|
||||||
+ "accelerateでは複数のモデル(テキストエンコーダー)のgradient_accumulation_stepsはサポートされていないようです"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert the init_word to token_id
|
# Convert the init_word to token_id
|
||||||
init_token_ids_list = []
|
init_token_ids_list = []
|
||||||
@@ -310,10 +319,10 @@ class TextualInversionTrainer:
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list)
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
else:
|
else:
|
||||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list)
|
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||||
|
|
||||||
self.assert_extra_args(args, train_dataset_group)
|
self.assert_extra_args(args, train_dataset_group)
|
||||||
|
|
||||||
@@ -368,11 +377,10 @@ class TextualInversionTrainer:
|
|||||||
vae.to(accelerator.device, dtype=vae_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
|
||||||
vae.to("cpu")
|
|
||||||
clean_memory_on_device(accelerator.device)
|
|
||||||
|
|
||||||
|
train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
|
||||||
|
|
||||||
|
clean_memory_on_device(accelerator.device)
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
@@ -387,7 +395,11 @@ class TextualInversionTrainer:
|
|||||||
trainable_params += text_encoder.get_input_embeddings().parameters()
|
trainable_params += text_encoder.get_input_embeddings().parameters()
|
||||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||||
|
|
||||||
# dataloaderを準備する
|
# prepare dataloader
|
||||||
|
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||||
|
# some strategies can be None
|
||||||
|
train_dataset_group.set_current_strategies()
|
||||||
|
|
||||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
@@ -415,20 +427,8 @@ class TextualInversionTrainer:
|
|||||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
if len(text_encoders) == 1:
|
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||||
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
text_encoders = [accelerator.prepare(text_encoder) for text_encoder in text_encoders]
|
||||||
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
|
|
||||||
)
|
|
||||||
|
|
||||||
elif len(text_encoders) == 2:
|
|
||||||
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
|
||||||
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
|
|
||||||
)
|
|
||||||
|
|
||||||
text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
index_no_updates_list = []
|
index_no_updates_list = []
|
||||||
orig_embeds_params_list = []
|
orig_embeds_params_list = []
|
||||||
@@ -456,6 +456,9 @@ class TextualInversionTrainer:
|
|||||||
else:
|
else:
|
||||||
unet.eval()
|
unet.eval()
|
||||||
|
|
||||||
|
text_encoding_strategy = self.get_text_encoding_strategy(args)
|
||||||
|
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
||||||
|
|
||||||
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
|
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
@@ -510,7 +513,9 @@ class TextualInversionTrainer:
|
|||||||
if args.log_tracker_config is not None:
|
if args.log_tracker_config is not None:
|
||||||
init_kwargs = toml.load(args.log_tracker_config)
|
init_kwargs = toml.load(args.log_tracker_config)
|
||||||
accelerator.init_trackers(
|
accelerator.init_trackers(
|
||||||
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
|
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name,
|
||||||
|
config=train_util.get_sanitized_config_or_none(args),
|
||||||
|
init_kwargs=init_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# function for saving/removing
|
# function for saving/removing
|
||||||
@@ -540,8 +545,8 @@ class TextualInversionTrainer:
|
|||||||
global_step,
|
global_step,
|
||||||
accelerator.device,
|
accelerator.device,
|
||||||
vae,
|
vae,
|
||||||
tokenizer_or_list,
|
tokenizers,
|
||||||
text_encoder_or_list,
|
text_encoders,
|
||||||
unet,
|
unet,
|
||||||
prompt_replacement,
|
prompt_replacement,
|
||||||
)
|
)
|
||||||
@@ -568,7 +573,12 @@ class TextualInversionTrainer:
|
|||||||
latents = latents * self.vae_scale_factor
|
latents = latents * self.vae_scale_factor
|
||||||
|
|
||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype)
|
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
||||||
|
text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||||
|
tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids
|
||||||
|
)
|
||||||
|
if args.full_fp16:
|
||||||
|
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
|
||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
@@ -588,7 +598,9 @@ class TextualInversionTrainer:
|
|||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
loss = train_util.conditional_loss(
|
||||||
|
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||||
|
)
|
||||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||||
loss = apply_masked_loss(loss, batch)
|
loss = apply_masked_loss(loss, batch)
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
@@ -639,8 +651,8 @@ class TextualInversionTrainer:
|
|||||||
global_step,
|
global_step,
|
||||||
accelerator.device,
|
accelerator.device,
|
||||||
vae,
|
vae,
|
||||||
tokenizer_or_list,
|
tokenizers,
|
||||||
text_encoder_or_list,
|
text_encoders,
|
||||||
unet,
|
unet,
|
||||||
prompt_replacement,
|
prompt_replacement,
|
||||||
)
|
)
|
||||||
@@ -722,8 +734,8 @@ class TextualInversionTrainer:
|
|||||||
global_step,
|
global_step,
|
||||||
accelerator.device,
|
accelerator.device,
|
||||||
vae,
|
vae,
|
||||||
tokenizer_or_list,
|
tokenizers,
|
||||||
text_encoder_or_list,
|
text_encoders,
|
||||||
unet,
|
unet,
|
||||||
prompt_replacement,
|
prompt_replacement,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user