Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-06 13:47:06 +00:00)
Refactor caching mechanism for latents and text encoder outputs, etc.
fine_tune.py (54 changed lines)

--- a/fine_tune.py
+++ b/fine_tune.py
@@ -10,7 +10,7 @@ import toml
 from tqdm import tqdm

 import torch
-from library import deepspeed_utils
+from library import deepspeed_utils, strategy_base
 from library.device_utils import init_ipex, clean_memory_on_device

 init_ipex()
@@ -39,6 +39,7 @@ from library.custom_train_functions import (
     scale_v_prediction_loss_like_noise_prediction,
     apply_debiased_estimation,
 )
+import library.strategy_sd as strategy_sd


 def train(args):
@@ -52,7 +53,15 @@ def train(args):
     if args.seed is not None:
         set_seed(args.seed)  # initialize the random seed

-    tokenizer = train_util.load_tokenizer(args)
+    tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
+    strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
+
+    # prepare caching strategy: this must be set before preparing the dataset, because the dataset may use it during initialization
+    if cache_latents:
+        latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
+            False, args.cache_latents_to_disk, args.vae_batch_size, False
+        )
+        strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

     # prepare the dataset
     if args.dataset_class is None:
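A minimal sketch of the set_strategy/get_strategy registration pattern this hunk relies on, assuming strategy_base keeps each strategy as a class-level singleton. The class names mirror the diff; the bodies are illustrative, not the sd-scripts implementation.

from typing import Optional


class TokenizeStrategy:
    _strategy: Optional["TokenizeStrategy"] = None

    @classmethod
    def set_strategy(cls, strategy: "TokenizeStrategy") -> None:
        # registered once at startup, before the dataset is built
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["TokenizeStrategy"]:
        # datasets and training code look the strategy up here instead of
        # receiving a tokenizer as an argument
        return cls._strategy


class SdTokenizeStrategy(TokenizeStrategy):
    def __init__(self, v2: bool, max_token_length: Optional[int], cache_dir: Optional[str]) -> None:
        self.v2 = v2
        self.max_token_length = max_token_length
        self.cache_dir = cache_dir


# registration mirrors the diff above; argument values are placeholders
TokenizeStrategy.set_strategy(SdTokenizeStrategy(False, 75, None))
assert isinstance(TokenizeStrategy.get_strategy(), SdTokenizeStrategy)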
@@ -81,10 +90,10 @@ def train(args):
             ]
         }

-        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+        blueprint = blueprint_generator.generate(user_config, args)
         train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
     else:
-        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
+        train_dataset_group = train_util.load_arbitrary_dataset(args)

     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
@@ -165,8 +174,9 @@ def train(args):
         vae.to(accelerator.device, dtype=vae_dtype)
         vae.requires_grad_(False)
         vae.eval()
-        with torch.no_grad():
-            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
+
+        train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
+
         vae.to("cpu")
         clean_memory_on_device(accelerator.device)

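The old call site passed the VAE batch size and disk flag explicitly and wrapped the call in torch.no_grad(); new_cache_latents takes only (vae, is_main_process), so those details presumably come from the caching strategy registered earlier. A hedged sketch of that shape with hypothetical internals, not the actual sd-scripts code:

import torch


class LatentsCachingStrategy:
    _strategy = None

    def __init__(self, cache_to_disk: bool, batch_size: int) -> None:
        self.cache_to_disk = cache_to_disk
        self.batch_size = batch_size

    @classmethod
    def set_strategy(cls, strategy: "LatentsCachingStrategy") -> None:
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> "LatentsCachingStrategy":
        return cls._strategy


def new_cache_latents(vae: torch.nn.Module, is_main_process: bool) -> None:
    strategy = LatentsCachingStrategy.get_strategy()
    with torch.no_grad():  # the no_grad guard moves here, off the caller
        # encode images in batches of strategy.batch_size and keep the latents
        # in memory or on disk per strategy.cache_to_disk (omitted)
        pass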
@@ -192,6 +202,9 @@ def train(args):
     else:
         text_encoder.eval()

+    text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
+    strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
+
     if not cache_latents:
         vae.requires_grad_(False)
         vae.eval()
@@ -214,7 +227,11 @@ def train(args):
     accelerator.print("prepare optimizer, data loader etc.")
     _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)

-    # dataloaderを準備する
+    # prepare dataloader
+    # strategies are set here because they cannot be referenced from another process; copy them along with the dataset
+    # some strategies can be None
+    train_dataset_group.set_current_strategies()
+
     # number of DataLoader processes: note that 0 cannot use persistent_workers
     n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
     train_dataloader = torch.utils.data.DataLoader(
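set_current_strategies() runs just before the DataLoader is built, and the new comment explains why: a process-global strategy set in the main process is not visible from a DataLoader worker (e.g. when workers are started with the spawn method), whereas attributes of the dataset travel with it to the workers. A sketch of that idea with a stand-in dataset and a stand-in global:

from torch.utils.data import Dataset


class _GlobalStrategy:
    # stand-in for the strategy_base singletons
    current = None


class CaptionDataset(Dataset):
    def __init__(self, captions):
        self.captions = captions
        self.tokenize_strategy = None  # filled in by set_current_strategies()

    def set_current_strategies(self):
        # copy the process-global strategy onto the instance so it is
        # pickled to DataLoader workers along with the dataset
        self.tokenize_strategy = _GlobalStrategy.current

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        # workers read self.tokenize_strategy, never the global
        return self.captions[idx]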
@@ -317,7 +334,9 @@ def train(args):
     )

     # For --sample_at_first
-    train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+    train_util.sample_images(
+        accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
+    )

     loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
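Every sample_images call site now reaches the tokenizer through tokenize_strategy.tokenizer rather than a standalone tokenizer variable. A tiny sketch of a strategy owning its tokenizer; the model id is the SD v1 tokenizer sd-scripts loads, everything else (v2 and cache-dir handling) is omitted as illustrative:

from transformers import CLIPTokenizer


class SdTokenizeStrategy:
    def __init__(self) -> None:
        # SD v1 CLIP tokenizer
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")


tokenize_strategy = SdTokenizeStrategy()
print(tokenize_strategy.tokenizer("a photo of a cat").input_ids)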
@@ -342,8 +361,9 @@ def train(args):
                 with torch.set_grad_enabled(args.train_text_encoder):
                     # Get the text embedding for conditioning
                     if args.weighted_captions:
+                        # TODO move to strategy_sd.py
                         encoder_hidden_states = get_weighted_text_embeddings(
-                            tokenizer,
+                            tokenize_strategy.tokenizer,
                             text_encoder,
                             batch["captions"],
                             accelerator.device,
@@ -351,10 +371,12 @@ def train(args):
                             clip_skip=args.clip_skip,
                         )
                     else:
-                        input_ids = batch["input_ids"].to(accelerator.device)
-                        encoder_hidden_states = train_util.get_hidden_states(
-                            args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
-                        )
+                        input_ids = batch["input_ids_list"][0].to(accelerator.device)
+                        encoder_hidden_states = text_encoding_strategy.encode_tokens(
+                            tokenize_strategy, [text_encoder], [input_ids]
+                        )[0]
+                        if args.full_fp16:
+                            encoder_hidden_states = encoder_hidden_states.to(weight_dtype)

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
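encode_tokens takes a list of text encoders and a list of token batches and returns a list of hidden states, which is why SD v1/v2 passes single-element lists and indexes [0]; the list shape lets one interface also serve models with several text encoders (e.g. SDXL's two). A hedged sketch of the call shape only, with a dummy module standing in for the CLIP text model:

import torch


class SdTextEncodingStrategy:
    def __init__(self, clip_skip=None):
        self.clip_skip = clip_skip  # unused in this sketch

    def encode_tokens(self, tokenize_strategy, models, tokens_list):
        # one hidden-state tensor per text encoder, in matching order
        return [model(tokens) for model, tokens in zip(models, tokens_list)]


dummy_encoder = torch.nn.Embedding(49408, 8)  # stand-in for CLIP (49406/49407 are its BOS/EOS ids)
input_ids = torch.tensor([[49406, 320, 1125, 49407]])
hidden = SdTextEncodingStrategy().encode_tokens(None, [dummy_encoder], [input_ids])[0]
print(hidden.shape)  # torch.Size([1, 4, 8])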
@@ -409,7 +431,7 @@ def train(args):
                 global_step += 1

                 train_util.sample_images(
-                    accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
+                    accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
                 )

                 # save the model at the specified steps
@@ -472,7 +494,9 @@ def train(args):
                 vae,
             )

-        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+        train_util.sample_images(
+            accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
+        )

     is_main_process = accelerator.is_main_process
     if is_main_process: