mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
211 lines
9.7 KiB
Python
211 lines
9.7 KiB
Python
import argparse
|
|
|
|
import torch
|
|
from accelerate import Accelerator
|
|
from library.device_utils import init_ipex, clean_memory_on_device
|
|
|
|
init_ipex()
|
|
|
|
from library import sdxl_model_util, sdxl_train_util, strategy_base, strategy_sd, strategy_sdxl, train_util
|
|
import train_network
|
|
from library.utils import setup_logging
|
|
|
|
setup_logging()
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SdxlNetworkTrainer(train_network.NetworkTrainer):
    """SDXL variant of ``train_network.NetworkTrainer``.

    Supplies the SDXL-specific pieces used during network training: argument
    checks, model loading, tokenize / text-encoding / caching strategies,
    text conditioning and the U-Net call.
    """

    def __init__(self):
        """Initialise the base trainer, then record the SDXL-specific settings."""
        super().__init__()
        self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
        self.is_sdxl = True

    def assert_extra_args(self, args, train_dataset_group):
        """Run the base-class and SDXL argument checks for this training run.

        Raises:
            AssertionError: if Text Encoder output caching is requested while
                the dataset reports it is not cacheable, or while the Text
                Encoder network is also being trained.
        """
        not_cacheable_message = "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
        te_training_message = "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"

        super().assert_extra_args(args, train_dataset_group)
        sdxl_train_util.verify_sdxl_training_args(args)

        if args.cache_text_encoder_outputs:
            assert train_dataset_group.is_text_encoder_output_cacheable(), not_cacheable_message

        assert args.network_train_unet_only or not args.cache_text_encoder_outputs, te_training_message

        # NOTE(review): presumably requires bucket resolution steps compatible with 32;
        # confirm against DatasetGroup.verify_bucket_reso_steps.
        train_dataset_group.verify_bucket_reso_steps(32)

    def load_target_model(self, args, weight_dtype, accelerator):
        """Load the SDXL base v1.0 model and install the attention replacements.

        Stores ``load_stable_diffusion_format``, ``logit_scale`` and
        ``ckpt_info`` on the instance.

        Returns:
            A tuple ``(model_version, [text_encoder1, text_encoder2], vae, unet)``.
        """
        model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
        (
            self.load_stable_diffusion_format,
            text_encoder1,
            text_encoder2,
            vae,
            unet,
            self.logit_scale,
            self.ckpt_info,
        ) = sdxl_train_util.load_target_model(args, accelerator, model_version, weight_dtype)

        # Incorporate xformers / memory efficient attention into the model.
        train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)

        # With an xformers build that supports PyTorch 2.0.0 or later, the VAE can use it too.
        if torch.__version__ >= "2.0.0":
            vae.set_use_memory_efficient_attention_xformers(args.xformers)

        return model_version, [text_encoder1, text_encoder2], vae, unet

    def get_tokenize_strategy(self, args):
        """Create the SDXL tokenize strategy from the token length and tokenizer cache dir."""
        max_token_length = args.max_token_length
        tokenizer_cache_dir = args.tokenizer_cache_dir
        return strategy_sdxl.SdxlTokenizeStrategy(max_token_length, tokenizer_cache_dir)

    def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy):
        """Return the strategy's two tokenizers as ``[tokenizer1, tokenizer2]``."""
        first_tokenizer = tokenize_strategy.tokenizer1
        second_tokenizer = tokenize_strategy.tokenizer2
        return [first_tokenizer, second_tokenizer]

    def get_latents_caching_strategy(self, args):
        """Create the latents caching strategy shared by SD and SDXL."""
        # The meaning of the two literal False flags is defined by
        # SdSdxlLatentsCachingStrategy, which is not visible from this file.
        return strategy_sd.SdSdxlLatentsCachingStrategy(
            False, args.cache_latents_to_disk, args.vae_batch_size, False
        )

    def get_text_encoding_strategy(self, args):
        """Create the SDXL text encoding strategy (``args`` is not used here)."""
        return strategy_sdxl.SdxlTextEncodingStrategy()

    def get_models_for_text_encoding(self, args, accelerator, text_encoders):
        """Return the text encoders followed by the unwrapped last text encoder."""
        unwrapped_last_encoder = accelerator.unwrap_model(text_encoders[-1])
        return text_encoders + [unwrapped_last_encoder]

    def get_text_encoder_outputs_caching_strategy(self, args):
        """Return the Text Encoder output caching strategy, or ``None`` when caching is off."""
        if not args.cache_text_encoder_outputs:
            return None
        return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False)

    def cache_text_encoder_outputs_if_needed(
        self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
    ):
        """Cache Text Encoder outputs for ``dataset`` when requested by ``args``.

        Without caching, both text encoders are only moved to the accelerator
        device in ``weight_dtype``. With caching, the outputs are computed under
        autocast and the encoders are then moved to the CPU as float32. Unless
        ``args.lowram`` is set, the VAE and U-Net are parked on the CPU while
        caching runs and restored to their previous devices afterwards.
        """
        if not args.cache_text_encoder_outputs:
            # Outputs are taken from the Text Encoders on every step, so keep them on the GPU.
            for encoder_index in (0, 1):
                text_encoders[encoder_index].to(accelerator.device, dtype=weight_dtype)
            return

        park_models_on_cpu = not args.lowram
        if park_models_on_cpu:
            # Reduce memory consumption while the Text Encoders run.
            logger.info("move vae and unet to cpu to save memory")
            org_vae_device, org_unet_device = vae.device, unet.device
            vae.to("cpu")
            unet.to("cpu")
            clean_memory_on_device(accelerator.device)

        # When the TE is not being trained it is not prepared, so explicit autocast is needed.
        for encoder_index in (0, 1):
            text_encoders[encoder_index].to(accelerator.device, dtype=weight_dtype)
        with accelerator.autocast():
            encoding_models = text_encoders + [accelerator.unwrap_model(text_encoders[-1])]
            dataset.new_cache_text_encoder_outputs(encoding_models, accelerator.is_main_process)
        accelerator.wait_for_everyone()

        # Text Encoder doesn't work with fp16 on CPU, hence float32.
        for encoder_index in (0, 1):
            text_encoders[encoder_index].to("cpu", dtype=torch.float32)
        clean_memory_on_device(accelerator.device)

        if park_models_on_cpu:
            logger.info("move vae and unet back to original device")
            vae.to(org_vae_device)
            unet.to(org_unet_device)

    def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
        """Produce the text conditioning for one batch.

        Uses the cached Text Encoder outputs stored in ``batch`` when present
        and not ``None``; otherwise encodes ``batch["input_ids"]`` and
        ``batch["input_ids2"]`` with the two text encoders.

        Returns:
            A tuple ``(encoder_hidden_states1, encoder_hidden_states2, pool2)``.
        """
        cached_outputs_available = (
            "text_encoder_outputs1_list" in batch and batch["text_encoder_outputs1_list"] is not None
        )

        if cached_outputs_available:
            encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
            encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
            pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
        else:
            token_ids1 = batch["input_ids"]
            token_ids2 = batch["input_ids2"]
            with torch.enable_grad():
                # Get the text embedding for conditioning
                # TODO support weighted captions
                # if args.weighted_captions:
                #     encoder_hidden_states = get_weighted_text_embeddings(
                #         tokenizer,
                #         text_encoder,
                #         batch["captions"],
                #         accelerator.device,
                #         args.max_token_length // 75 if args.max_token_length else 1,
                #         clip_skip=args.clip_skip,
                #     )
                # else:
                token_ids1 = token_ids1.to(accelerator.device)
                token_ids2 = token_ids2.to(accelerator.device)
                encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
                    args.max_token_length,
                    token_ids1,
                    token_ids2,
                    tokenizers[0],
                    tokenizers[1],
                    text_encoders[0],
                    text_encoders[1],
                    weight_dtype if args.full_fp16 else None,
                    accelerator=accelerator,
                )

        # # verify that the text encoder outputs are correct
        # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
        #     args.max_token_length,
        #     batch["input_ids"].to(text_encoders[0].device),
        #     batch["input_ids2"].to(text_encoders[0].device),
        #     tokenizers[0],
        #     tokenizers[1],
        #     text_encoders[0],
        #     text_encoders[1],
        #     None if not args.full_fp16 else weight_dtype,
        # )
        # b_size = encoder_hidden_states1.shape[0]
        # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
        # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
        # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
        # logger.info("text encoder outputs verified")

        return encoder_hidden_states1, encoder_hidden_states2, pool2

    def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
        """Run the U-Net on the noisy latents with SDXL text and size conditioning.

        ``text_conds`` is the ``(encoder_hidden_states1, encoder_hidden_states2,
        pool2)`` triple. The pooled output is concatenated with the size
        embeddings along dim 1, and the two hidden states along dim 2.

        Returns:
            Whatever ``unet`` returns for these inputs (the noise prediction).
        """
        # TODO check why noisy_latents is not weight_dtype
        latents = noisy_latents.to(weight_dtype)

        # Size embeddings built from the original size, crop offset and target size of the batch.
        size_embeddings = sdxl_train_util.get_size_embeddings(
            batch["original_sizes_hw"], batch["crop_top_lefts"], batch["target_sizes_hw"], accelerator.device
        ).to(weight_dtype)

        # Concatenate the embeddings into the two conditioning inputs of the U-Net.
        hidden_states1, hidden_states2, pooled = text_conds
        vector_embedding = torch.cat([pooled, size_embeddings], dim=1).to(weight_dtype)
        text_embedding = torch.cat([hidden_states1, hidden_states2], dim=2).to(weight_dtype)

        return unet(latents, timesteps, text_embedding, vector_embedding)

    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
        """Delegate sample image generation to ``sdxl_train_util.sample_images``."""
        sdxl_train_util.sample_images(
            accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet
        )
|
|
|
|
|
|
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for SDXL network training.

    Starts from the parser returned by ``train_network.setup_parser()`` and
    passes it to ``sdxl_train_util.add_sdxl_training_arguments``.

    Returns:
        The resulting argument parser.
    """
    sdxl_parser = train_network.setup_parser()
    sdxl_train_util.add_sdxl_training_arguments(sdxl_parser)
    return sdxl_parser
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse and check the command line first.
    parser = setup_parser()
    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)

    # read_config_from_file returns the args object that is actually used for training.
    args = train_util.read_config_from_file(args, parser)

    SdxlNetworkTrainer().train(args)
|