Refactor caching mechanism for latents and text encoder outputs, etc.

This commit is contained in:
Kohya S
2024-07-27 13:50:05 +09:00
parent 082f13658b
commit 41dee60383
21 changed files with 1786 additions and 733 deletions

View File

@@ -5,10 +5,10 @@ import regex
import torch
from library.device_utils import init_ipex
init_ipex()
from library import sdxl_model_util, sdxl_train_util, train_util
from library import sdxl_model_util, sdxl_train_util, strategy_sd, strategy_sdxl, train_util
import train_textual_inversion
@@ -41,28 +41,20 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
def load_tokenizer(self, args):
tokenizer = sdxl_train_util.load_tokenizers(args)
return tokenizer
def get_tokenize_strategy(self, args):
return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
with torch.enable_grad():
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
input_ids2,
tokenizers[0],
tokenizers[1],
text_encoders[0],
text_encoders[1],
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
return encoder_hidden_states1, encoder_hidden_states2, pool2
def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy):
return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2]
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
)
return latents_caching_strategy
def get_text_encoding_strategy(self, args):
return strategy_sdxl.SdxlTextEncodingStrategy()
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -81,9 +73,11 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
def sample_images(
self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
):
sdxl_train_util.sample_images(
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
)
def save_weights(self, file, updated_embs, save_dtype, metadata):
@@ -122,8 +116,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
def setup_parser() -> argparse.ArgumentParser:
parser = train_textual_inversion.setup_parser()
# don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching
# sdxl_train_util.add_sdxl_training_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False)
return parser