Refactor caching mechanism for latents and text encoder outputs, etc.

This commit is contained in:
Kohya S
2024-07-27 13:50:05 +09:00
parent 082f13658b
commit 41dee60383
21 changed files with 1786 additions and 733 deletions

View File

@@ -10,7 +10,7 @@ import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library import deepspeed_utils, strategy_base
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
@@ -39,6 +39,7 @@ from library.custom_train_functions import (
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
import library.strategy_sd as strategy_sd
def train(args):
@@ -52,7 +53,15 @@ def train(args):
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する
tokenizer = train_util.load_tokenizer(args)
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
if cache_latents:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# データセットを準備する
if args.dataset_class is None:
@@ -81,10 +90,10 @@ def train(args):
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
train_dataset_group = train_util.load_arbitrary_dataset(args)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -165,8 +174,9 @@ def train(args):
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -192,6 +202,9 @@ def train(args):
else:
text_encoder.eval()
text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
@@ -214,7 +227,11 @@ def train(args):
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
# dataloaderを準備する
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
@@ -317,7 +334,9 @@ def train(args):
)
# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -342,8 +361,9 @@ def train(args):
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
if args.weighted_captions:
# TODO move to strategy_sd.py
encoder_hidden_states = get_weighted_text_embeddings(
tokenizer,
tokenize_strategy.tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
@@ -351,10 +371,12 @@ def train(args):
clip_skip=args.clip_skip,
)
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
input_ids = batch["input_ids_list"][0].to(accelerator.device)
encoder_hidden_states = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder], [input_ids]
)[0]
if args.full_fp16:
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
@@ -409,7 +431,7 @@ def train(args):
global_step += 1
train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
# 指定ステップごとにモデルを保存
@@ -472,7 +494,9 @@ def train(args):
vae,
)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
train_util.sample_images(
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
is_main_process = accelerator.is_main_process
if is_main_process: