mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Refactor caching mechanism for latents and text encoder outputs, etc.
This commit is contained in:
@@ -2,6 +2,7 @@ import argparse
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
from typing import Any, List
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
@@ -15,7 +16,7 @@ init_ipex()
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from transformers import CLIPTokenizer
|
||||
from library import deepspeed_utils, model_util
|
||||
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
|
||||
|
||||
import library.train_util as train_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
@@ -103,28 +104,38 @@ class TextualInversionTrainer:
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
|
||||
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet
|
||||
|
||||
def load_tokenizer(self, args):
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
return tokenizer
|
||||
def get_tokenize_strategy(self, args):
|
||||
return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
|
||||
|
||||
def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]:
|
||||
return [tokenize_strategy.tokenizer]
|
||||
|
||||
def get_latents_caching_strategy(self, args):
|
||||
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
||||
True, args.cache_latents_to_disk, args.vae_batch_size, False
|
||||
)
|
||||
return latents_caching_strategy
|
||||
|
||||
def assert_token_string(self, token_string, tokenizers: CLIPTokenizer):
|
||||
pass
|
||||
|
||||
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
||||
with torch.enable_grad():
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None)
|
||||
return encoder_hidden_states
|
||||
def get_text_encoding_strategy(self, args):
|
||||
return strategy_sd.SdTextEncodingStrategy(args.clip_skip)
|
||||
|
||||
def get_models_for_text_encoding(self, args, accelerator, text_encoders) -> List[Any]:
|
||||
return text_encoders
|
||||
|
||||
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
|
||||
noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
|
||||
return noise_pred
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
|
||||
def sample_images(
|
||||
self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
|
||||
):
|
||||
train_util.sample_images(
|
||||
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||
accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoders[0], unet, prompt_replacement
|
||||
)
|
||||
|
||||
def save_weights(self, file, updated_embs, save_dtype, metadata):
|
||||
@@ -182,8 +193,13 @@ class TextualInversionTrainer:
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
tokenizer_or_list = self.load_tokenizer(args) # list of tokenizer or tokenizer
|
||||
tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
|
||||
tokenize_strategy = self.get_tokenize_strategy(args)
|
||||
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
||||
tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored
|
||||
|
||||
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
||||
latents_caching_strategy = self.get_latents_caching_strategy(args)
|
||||
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
@@ -194,14 +210,7 @@ class TextualInversionTrainer:
|
||||
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||
|
||||
# モデルを読み込む
|
||||
model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
|
||||
text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list
|
||||
|
||||
if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1:
|
||||
accelerator.print(
|
||||
"accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / "
|
||||
+ "accelerateでは複数のモデル(テキストエンコーダー)のgradient_accumulation_stepsはサポートされていないようです"
|
||||
)
|
||||
model_version, text_encoders, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
|
||||
|
||||
# Convert the init_word to token_id
|
||||
init_token_ids_list = []
|
||||
@@ -310,10 +319,10 @@ class TextualInversionTrainer:
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list)
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list)
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
|
||||
self.assert_extra_args(args, train_dataset_group)
|
||||
|
||||
@@ -368,11 +377,10 @@ class TextualInversionTrainer:
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
@@ -387,7 +395,11 @@ class TextualInversionTrainer:
|
||||
trainable_params += text_encoder.get_input_embeddings().parameters()
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
# prepare dataloader
|
||||
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||
# some strategies can be None
|
||||
train_dataset_group.set_current_strategies()
|
||||
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
@@ -415,20 +427,8 @@ class TextualInversionTrainer:
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if len(text_encoders) == 1:
|
||||
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
elif len(text_encoders) == 2:
|
||||
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
|
||||
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||
text_encoders = [accelerator.prepare(text_encoder) for text_encoder in text_encoders]
|
||||
|
||||
index_no_updates_list = []
|
||||
orig_embeds_params_list = []
|
||||
@@ -456,6 +456,9 @@ class TextualInversionTrainer:
|
||||
else:
|
||||
unet.eval()
|
||||
|
||||
text_encoding_strategy = self.get_text_encoding_strategy(args)
|
||||
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
||||
|
||||
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
@@ -510,7 +513,9 @@ class TextualInversionTrainer:
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
|
||||
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name,
|
||||
config=train_util.get_sanitized_config_or_none(args),
|
||||
init_kwargs=init_kwargs,
|
||||
)
|
||||
|
||||
# function for saving/removing
|
||||
@@ -540,8 +545,8 @@ class TextualInversionTrainer:
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
tokenizer_or_list,
|
||||
text_encoder_or_list,
|
||||
tokenizers,
|
||||
text_encoders,
|
||||
unet,
|
||||
prompt_replacement,
|
||||
)
|
||||
@@ -568,7 +573,12 @@ class TextualInversionTrainer:
|
||||
latents = latents * self.vae_scale_factor
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype)
|
||||
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
||||
text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids
|
||||
)
|
||||
if args.full_fp16:
|
||||
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
|
||||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
@@ -588,7 +598,9 @@ class TextualInversionTrainer:
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
@@ -639,8 +651,8 @@ class TextualInversionTrainer:
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
tokenizer_or_list,
|
||||
text_encoder_or_list,
|
||||
tokenizers,
|
||||
text_encoders,
|
||||
unet,
|
||||
prompt_replacement,
|
||||
)
|
||||
@@ -722,8 +734,8 @@ class TextualInversionTrainer:
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
tokenizer_or_list,
|
||||
text_encoder_or_list,
|
||||
tokenizers,
|
||||
text_encoders,
|
||||
unet,
|
||||
prompt_replacement,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user