Refactor caching mechanism for latents and text encoder outputs, etc.

Kohya S
2024-07-27 13:50:05 +09:00
parent 082f13658b
commit 41dee60383
21 changed files with 1786 additions and 733 deletions
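
For orientation before the per-file hunks: the refactor replaces the ad-hoc tokenizer / text-encoder / latents-caching plumbing in each trainer with strategy objects that are registered globally and looked up by the dataset and the training loop. The sketch below condenses that wiring for the SD v1/v2 textual-inversion trainer. It is a simplified illustration distilled from the hunks in this commit, not code taken from it; the helper name register_sd_strategies is invented here, and it assumes the sd-scripts modules library.strategy_base and library.strategy_sd are importable and that args carries the usual trainer options.

# Simplified sketch of the strategy registration order introduced by this commit.
# Tokenize and latents-caching strategies must be registered before the dataset is
# built (the dataset may consult them during initialization); the text-encoding
# strategy is registered before the training loop starts.
from library import strategy_base, strategy_sd


def register_sd_strategies(args):
    # tokenization now lives behind a strategy object instead of a raw CLIPTokenizer
    tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
    strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)

    # latents caching strategy (shared by SD and SDXL) replaces the explicit cache_latents() arguments
    latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
        True, args.cache_latents_to_disk, args.vae_batch_size, False
    )
    strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

    # text encoding strategy carries clip_skip and is used by the loop via encode_tokens()
    text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
    strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
    return tokenize_strategy, text_encoding_strategy

With the strategies registered, latents caching goes through train_dataset_group.new_cache_latents(vae, ...) instead of the old cache_latents(vae, batch_size, to_disk, ...) call, and the dataset copies the current strategies into itself via set_current_strategies() before the DataLoader workers are spawned.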


@@ -2,6 +2,7 @@ import argparse
 import math
 import os
 from multiprocessing import Value
+from typing import Any, List
 import toml
 from tqdm import tqdm
@@ -15,7 +16,7 @@ init_ipex()
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler
 from transformers import CLIPTokenizer
-from library import deepspeed_utils, model_util
+from library import deepspeed_utils, model_util, strategy_base, strategy_sd
 import library.train_util as train_util
 import library.huggingface_util as huggingface_util
@@ -103,28 +104,38 @@ class TextualInversionTrainer:
     def load_target_model(self, args, weight_dtype, accelerator):
         text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
-        return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
+        return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet

-    def load_tokenizer(self, args):
-        tokenizer = train_util.load_tokenizer(args)
-        return tokenizer
+    def get_tokenize_strategy(self, args):
+        return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
+
+    def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]:
+        return [tokenize_strategy.tokenizer]
+
+    def get_latents_caching_strategy(self, args):
+        latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
+            True, args.cache_latents_to_disk, args.vae_batch_size, False
+        )
+        return latents_caching_strategy

     def assert_token_string(self, token_string, tokenizers: CLIPTokenizer):
         pass

-    def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
-        with torch.enable_grad():
-            input_ids = batch["input_ids"].to(accelerator.device)
-            encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None)
-        return encoder_hidden_states
+    def get_text_encoding_strategy(self, args):
+        return strategy_sd.SdTextEncodingStrategy(args.clip_skip)
+
+    def get_models_for_text_encoding(self, args, accelerator, text_encoders) -> List[Any]:
+        return text_encoders

     def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
-        noise_pred = unet(noisy_latents, timesteps, text_conds).sample
+        noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
         return noise_pred

-    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
+    def sample_images(
+        self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
+    ):
         train_util.sample_images(
-            accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
+            accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoders[0], unet, prompt_replacement
         )

     def save_weights(self, file, updated_embs, save_dtype, metadata):
@@ -182,8 +193,13 @@ class TextualInversionTrainer:
         if args.seed is not None:
             set_seed(args.seed)

-        tokenizer_or_list = self.load_tokenizer(args)  # a list of tokenizers, or a single tokenizer
-        tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
+        tokenize_strategy = self.get_tokenize_strategy(args)
+        strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
+        tokenizers = self.get_tokenizers(tokenize_strategy)  # will be removed after sample_images is refactored
+
+        # prepare the caching strategy: this must be set before preparing the dataset, because the dataset may use it during initialization
+        latents_caching_strategy = self.get_latents_caching_strategy(args)
+        strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

         # prepare accelerator
         logger.info("prepare accelerator")
@@ -194,14 +210,7 @@
         vae_dtype = torch.float32 if args.no_half_vae else weight_dtype

         # load the model
-        model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
-        text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list
-
-        if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1:
-            accelerator.print(
-                "accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / "
-                + "accelerateでは複数のモデル（テキストエンコーダー）のgradient_accumulation_stepsはサポートされていないようです"
-            )
+        model_version, text_encoders, vae, unet = self.load_target_model(args, weight_dtype, accelerator)

         # Convert the init_word to token_id
         init_token_ids_list = []
@@ -310,10 +319,10 @@
                     ]
                 }

-            blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list)
+            blueprint = blueprint_generator.generate(user_config, args)
             train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
         else:
-            train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list)
+            train_dataset_group = train_util.load_arbitrary_dataset(args)

         self.assert_extra_args(args, train_dataset_group)
@@ -368,11 +377,10 @@
             vae.to(accelerator.device, dtype=vae_dtype)
             vae.requires_grad_(False)
             vae.eval()
-            with torch.no_grad():
-                train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
-            vae.to("cpu")
-            clean_memory_on_device(accelerator.device)
+            train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
+
+            clean_memory_on_device(accelerator.device)

             accelerator.wait_for_everyone()

         if args.gradient_checkpointing:
@@ -387,7 +395,11 @@
             trainable_params += text_encoder.get_input_embeddings().parameters()
         _, _, optimizer = train_util.get_optimizer(args, trainable_params)

-        # dataloaderを準備する
+        # prepare dataloader
+        # strategies are set here because they cannot be referenced from another process; they are copied along with the dataset
+        # some strategies can be None
+        train_dataset_group.set_current_strategies()
+
         # note: persistent_workers cannot be used when the number of DataLoader processes is 0
         n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
         train_dataloader = torch.utils.data.DataLoader(
@@ -415,20 +427,8 @@
         lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

         # accelerator will somehow take care of things for us, apparently
-        if len(text_encoders) == 1:
-            text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
-            )
-        elif len(text_encoders) == 2:
-            text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
-            )
-            text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
-        else:
-            raise NotImplementedError()
+        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
+        text_encoders = [accelerator.prepare(text_encoder) for text_encoder in text_encoders]

         index_no_updates_list = []
         orig_embeds_params_list = []
@@ -456,6 +456,9 @@
         else:
             unet.eval()

+        text_encoding_strategy = self.get_text_encoding_strategy(args)
+        strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
+
         if not cache_latents:  # prepare the VAE, since it is used when latents are not cached
             vae.requires_grad_(False)
             vae.eval()
@@ -510,7 +513,9 @@
             if args.log_tracker_config is not None:
                 init_kwargs = toml.load(args.log_tracker_config)
             accelerator.init_trackers(
-                "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
+                "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name,
+                config=train_util.get_sanitized_config_or_none(args),
+                init_kwargs=init_kwargs,
             )

         # function for saving/removing
@@ -540,8 +545,8 @@
             global_step,
             accelerator.device,
             vae,
-            tokenizer_or_list,
-            text_encoder_or_list,
+            tokenizers,
+            text_encoders,
             unet,
             prompt_replacement,
         )
@@ -568,7 +573,12 @@
                         latents = latents * self.vae_scale_factor

                     # Get the text embedding for conditioning
-                    text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype)
+                    input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
+                    text_encoder_conds = text_encoding_strategy.encode_tokens(
+                        tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids
+                    )
+                    if args.full_fp16:
+                        text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]

                     # Sample noise, sample a random timestep for each image, and add noise to the latents,
                     # with noise offset and/or multires noise if specified
@@ -588,7 +598,9 @@
                     else:
                         target = noise

-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    )
                     if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                         loss = apply_masked_loss(loss, batch)
                     loss = loss.mean([1, 2, 3])
@@ -639,8 +651,8 @@
                         global_step,
                         accelerator.device,
                         vae,
-                        tokenizer_or_list,
-                        text_encoder_or_list,
+                        tokenizers,
+                        text_encoders,
                         unet,
                         prompt_replacement,
                     )
@@ -722,8 +734,8 @@
                 global_step,
                 accelerator.device,
                 vae,
-                tokenizer_or_list,
-                text_encoder_or_list,
+                tokenizers,
+                text_encoders,
                 unet,
                 prompt_replacement,
             )
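
Taken together with the sketch near the top, the per-step text conditioning changes as the loop hunk above shows: instead of each trainer's get_text_cond helper reading batch["input_ids"], the loop reads a list of id tensors from batch["input_ids_list"] and hands them to the registered text-encoding strategy. The snippet below paraphrases those lines in a hypothetical helper for readability; the commit itself inlines them in train(), and for the SD v1/v2 trainer get_models_for_text_encoding simply returns the text encoders unchanged, so they are passed directly here.

# Hypothetical helper mirroring the new per-step conditioning path from the loop hunk above.
def encode_text_conds(args, accelerator, batch, tokenize_strategy, text_encoding_strategy, text_encoders, weight_dtype):
    # token ids are now stored per tokenizer as a list of tensors under "input_ids_list"
    input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
    # the strategy returns one conditioning tensor per text encoder
    text_encoder_conds = text_encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, input_ids)
    if args.full_fp16:
        # keep the conditioning dtype consistent with the fp16 training graph
        text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
    return text_encoder_conds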