split common function from train_network to util

This commit is contained in:
Kohya S
2023-01-03 20:22:25 +09:00
parent e31177adf3
commit 4c35006731
4 changed files with 460 additions and 563 deletions

View File

@@ -1,6 +1,8 @@
# training with captions
# XXX dropped option: fine_tune
import argparse
import gc
import math
import os
import random
@@ -200,21 +202,13 @@ class FineTuningDataset(torch.utils.data.Dataset):
return example
def save_hypernetwork(output_file, hypernetwork):
state_dict = hypernetwork.get_state_dict()
torch.save(state_dict, output_file)
def train(args):
fine_tuning = args.hypernetwork_module is None # fine tuning or hypernetwork training
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
# その他のオプション設定を確認する
if args.v_parameterization and not args.v2:
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
if args.v2 and args.clip_skip is not None:
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
cache_latents = args.cache_latents
# モデル形式のオプション設定を確認する
# verify load/save model formats
load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
if load_stable_diffusion_format:
@@ -231,109 +225,33 @@ def train(args):
save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors'
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
# 乱数系列を初期化する
if args.seed is not None:
set_seed(args.seed)
set_seed(args.seed) # 乱数系列を初期化する
# メタデータを読み込む
if os.path.exists(args.in_json):
print(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding='utf-8') as f:
metadata = json.load(f)
else:
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
return
tokenizer = train_util.load_tokenizer(args)
# tokenizerを読み込む
print("prepare tokenizer")
if args.v2:
tokenizer = CLIPTokenizer.from_pretrained(train_util.V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
else:
tokenizer = CLIPTokenizer.from_pretrained(train_util.TOKENIZER_PATH)
if args.max_token_length is not None:
print(f"update token length: {args.max_token_length}")
# datasetを用意する
print("prepare dataset")
train_dataset = FineTuningDataset(metadata, args.train_data_dir, args.train_batch_size,
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
args.dataset_repeats, args.debug_dataset)
print(f"Total dataset length / データセットの長さ: {len(train_dataset)}")
print(f"Total images / 画像数: {train_dataset.images_count}")
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset)
train_dataset.make_buckets()
if args.debug_dataset:
train_util.debug_dataset(train_dataset)
return
if len(train_dataset) == 0:
print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
return
if args.debug_dataset:
train_dataset.show_buckets()
i = 0
for example in train_dataset:
print(f"image: {example['image_keys']}")
print(f"captions: {example['captions']}")
print(f"latents: {example['latents'].shape}")
print(f"input_ids: {example['input_ids'].shape}")
print(example['input_ids'])
i += 1
if i >= 8:
break
return
# acceleratorを準備する
print("prepare accelerator")
if args.logging_dir is None:
log_with = None
logging_dir = None
else:
log_with = "tensorboard"
log_prefix = "" if args.log_prefix is None else args.log_prefix
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime())
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision, log_with=log_with, logging_dir=logging_dir)
# accelerateの互換性問題を解決する
accelerator_0_15 = True
try:
accelerator.unwrap_model("dummy", True)
print("Using accelerator 0.15.0 or above.")
except TypeError:
accelerator_0_15 = False
def unwrap_model(model):
if accelerator_0_15:
return accelerator.unwrap_model(model, True)
return accelerator.unwrap_model(model)
accelerator, unwrap_model = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
save_dtype = None
if args.save_precision == "fp16":
save_dtype = torch.float16
elif args.save_precision == "bf16":
save_dtype = torch.bfloat16
elif args.save_precision == "float":
save_dtype = torch.float32
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
if load_stable_diffusion_format:
print("load StableDiffusion checkpoint")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
else:
print("load Diffusers pretrained models")
pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
# , torch_dtype=weight_dtype) ここでtorch_dtypeを指定すると学習時にエラーになる
text_encoder = pipe.text_encoder
unet = pipe.unet
vae = pipe.vae
del pipe
vae.to("cpu") # 保存時にしか使わないので、メモリを開けるためCPUに移しておく
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
# Diffusers版のxformers使用フラグを設定する関数
def set_diffusers_xformers_flag(model, valid):
@@ -364,46 +282,38 @@ def train(args):
set_diffusers_xformers_flag(unet, False)
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
if not fine_tuning:
# Hypernetwork
print("import hypernetwork module:", args.hypernetwork_module)
hyp_module = importlib.import_module(args.hypernetwork_module)
hypernetwork = hyp_module.Hypernetwork()
if args.hypernetwork_weights is not None:
print("load hypernetwork weights from:", args.hypernetwork_weights)
hyp_sd = torch.load(args.hypernetwork_weights, map_location='cpu')
success = hypernetwork.load_from_state_dict(hyp_sd)
assert success, "hypernetwork weights loading failed."
print("apply hypernetwork")
hypernetwork.apply_to_diffusers(None, text_encoder, unet)
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset.cache_latents(vae)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# 学習を準備する:モデルを適切な状態にする
training_models = []
if fine_tuning:
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
training_models.append(unet)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
training_models.append(unet)
if args.train_text_encoder:
print("enable text encoder training")
if args.gradient_checkpointing:
text_encoder.gradient_checkpointing_enable()
training_models.append(text_encoder)
else:
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False) # text encoderは学習しない
text_encoder.eval()
if args.train_text_encoder:
print("enable text encoder training")
if args.gradient_checkpointing:
text_encoder.gradient_checkpointing_enable()
training_models.append(text_encoder)
else:
unet.to(accelerator.device) # , dtype=weight_dtype) # dtypeを指定すると学習できない
unet.requires_grad_(False)
unet.eval()
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.requires_grad_(False) # text encoderは学習しない
text_encoder.eval()
training_models.append(hypernetwork)
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
for m in training_models:
m.requires_grad_(True)
@@ -439,29 +349,19 @@ def train(args):
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
# acceleratorがなんかよろしくやってくれるらしい
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
print("enable full fp16 training.")
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
if fine_tuning:
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
else:
if args.full_fp16:
unet.to(weight_dtype)
hypernetwork.to(weight_dtype)
unet, hypernetwork, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, hypernetwork, optimizer, train_dataloader, lr_scheduler)
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
@@ -472,8 +372,6 @@ def train(args):
accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
# TODO accelerateのconfigに指定した型とオプション指定の型とをチェックして異なれば警告を出す
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
@@ -497,17 +395,12 @@ def train(args):
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
# v4で更新clip_sample=Falseに
# Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる
# 既存の1.4/1.5/2.0/2.1はすべてschedulerのconfigはクラス名を除いて同じ
# よくソースを見たら学習時はclip_sampleは関係ないや(;'∀')
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
num_train_timesteps=1000, clip_sample=False)
if accelerator.is_main_process:
accelerator.init_trackers("finetuning" if fine_tuning else "hypernetwork")
accelerator.init_trackers("finetuning")
# 以下 train_dreambooth.py からほぼコピペ
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
for m in training_models:
@@ -524,38 +417,7 @@ def train(args):
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
if args.clip_skip is None:
encoder_hidden_states = text_encoder(input_ids)[0]
else:
enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
# bs*3, 77, 768 or 1024
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
if args.max_token_length is not None:
if args.v2:
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, args.max_token_length, tokenizer.model_max_length):
chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # <BOS> の後から 最後の前まで
if i > 0:
for j in range(len(chunk)):
if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
encoder_hidden_states = torch.cat(states_list, dim=1)
else:
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, args.max_token_length, tokenizer.model_max_length):
states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
@@ -616,23 +478,23 @@ def train(args):
accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None:
def save_func(file):
model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet),
src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors)
train_util.save_on_epoch_end(args, accelerator, epoch, num_train_epochs, save_func)
if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
print("saving checkpoint.")
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1))
if fine_tuning:
if save_stable_diffusion_format:
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet),
src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae)
else:
out_dir = os.path.join(args.output_dir, train_util.EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
os.makedirs(out_dir, exist_ok=True)
model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet),
src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors)
if save_stable_diffusion_format:
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet),
src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae)
else:
save_hypernetwork(ckpt_file, unwrap_model(hypernetwork))
out_dir = os.path.join(args.output_dir, train_util.EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
os.makedirs(out_dir, exist_ok=True)
model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet),
src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors)
if args.save_state:
print("saving state.")
accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1)))
@@ -677,73 +539,16 @@ def train(args):
if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
parser.add_argument("--v_parameterization", action='store_true',
help='enable v-parameterization training / v-parameterization学習を有効にする')
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
parser.add_argument("--in_json", type=str, default=None, help="metadata file to input / 読みこむメタデータファイル")
parser.add_argument("--shuffle_caption", action="store_true",
help="shuffle comma-separated caption when fine tuning / fine tuning時にコンマで区切られたcaptionの各要素をshuffleする")
parser.add_argument("--keep_tokens", type=int, default=None,
help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--dataset_repeats", type=int, default=None, help="num times to repeat dataset / 学習にデータセットを繰り返す回数")
parser.add_argument("--output_dir", type=str, default=None,
help="directory to output trained model, save as same format as input / 学習後のモデル出力先ディレクトリ(入力と同じ形式で保存)")
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存するStableDiffusion形式での保存時のみ有効")
parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"],
help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)")
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True)
train_util.add_training_arguments(parser, False)
parser.add_argument("--use_safetensors", action='store_true',
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存するsave_model_as未指定時")
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
parser.add_argument("--hypernetwork_module", type=str, default=None,
help='train hypernetwork instead of fine tuning, module to use / fine tuningの代わりにHypernetworkの学習をする場合、そのモジュール')
parser.add_argument("--hypernetwork_weights", type=str, default=None,
help='hypernetwork weights to initialize for additional training / Hypernetworkの学習時に読み込む重みHypernetworkの追加学習')
parser.add_argument("--save_every_n_epochs", type=int, default=None,
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
parser.add_argument("--save_state", action="store_true",
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
parser.add_argument("--resume", type=str, default=None,
help="saved state to resume training / 学習再開するモデルのstate")
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長未指定で75、150または225が指定可")
parser.add_argument("--train_batch_size", type=int, default=1,
help="batch size for training / 学習時のバッチサイズ")
parser.add_argument("--use_8bit_adam", action="store_true",
help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使うbitsandbytesのインストールが必要")
parser.add_argument("--mem_eff_attn", action="store_true",
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
parser.add_argument("--xformers", action="store_true",
help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument("--diffusers_xformers", action='store_true',
help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用するHypernetwork利用不可')
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument("--gradient_checkpointing", action="store_true",
help="enable gradient checkpointing / grandient checkpointingを有効にする")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数")
parser.add_argument("--mixed_precision", type=str, default="no",
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
parser.add_argument("--clip_skip", type=int, default=None,
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上")
parser.add_argument("--debug_dataset", action="store_true",
help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
parser.add_argument("--logging_dir", type=str, default=None,
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
parser.add_argument("--lr_scheduler", type=str, default="constant",
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
parser.add_argument("--lr_warmup_steps", type=int, default=0,
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数デフォルト0")
help='use xformers by diffusers / Diffusersでxformersを使用する')
args = parser.parse_args()
train(args)

View File

@@ -1,7 +1,10 @@
# common functions for training
import argparse
import json
import time
from typing import NamedTuple
from accelerate import Accelerator
from torch.autograd.function import Function
import glob
import math
@@ -13,6 +16,7 @@ import torch
from torchvision import transforms
from transformers import CLIPTokenizer
import diffusers
from diffusers import DDPMScheduler, StableDiffusionPipeline
import albumentations as albu
import numpy as np
from PIL import Image
@@ -33,6 +37,9 @@ LAST_STATE_NAME = "last-state"
EPOCH_FILE_NAME = "epoch-{:06d}"
LAST_FILE_NAME = "last"
LAST_DIFFUSERS_DIR_NAME = "last"
EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}"
# region dataset
@@ -63,7 +70,8 @@ class BaseDataset(torch.utils.data.Dataset):
self.max_token_length = max_token_length
self.shuffle_caption = shuffle_caption
self.shuffle_keep_tokens = shuffle_keep_tokens
self.width, self.height = resolution
# width/height is used when enable_bucket==False
self.width, self.height = (None, None) if resolution is None else resolution
self.face_crop_aug_range = face_crop_aug_range
self.flip_aug = flip_aug
self.color_aug = color_aug
@@ -149,35 +157,26 @@ class BaseDataset(torch.utils.data.Dataset):
def register_image(self, info: ImageInfo):
self.image_data[info.image_key] = info
def make_buckets(self, enable_bucket, min_size, max_size):
def make_buckets(self):
'''
bucketingを行わない場合も呼び出し必須ひとつだけbucketを作る
min_size and max_size are ignored when enable_bucket is False
'''
self.enable_bucket = enable_bucket
print("loading image sizes.")
for info in tqdm(self.image_data.values()):
if info.image_size is None:
info.image_size = self.get_image_size(info.absolute_path)
if enable_bucket:
if self.enable_bucket:
print("make buckets")
else:
print("prepare dataset")
# bucketingを用意する
if enable_bucket:
bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions((self.width, self.height), min_size, max_size)
else:
# bucketはひとつだけ、すべての画像は同じ解像度
bucket_resos = [(self.width, self.height)]
bucket_aspect_ratios = [self.width / self.height]
bucket_aspect_ratios = np.array(bucket_aspect_ratios)
bucket_resos = self.bucket_resos
bucket_aspect_ratios = np.array(self.bucket_aspect_ratios)
# bucketを作成する
if enable_bucket:
if self.enable_bucket:
img_ar_errors = []
for image_info in self.image_data.values():
# bucketを決める
@@ -191,9 +190,8 @@ class BaseDataset(torch.utils.data.Dataset):
ar_error = ar_errors[bucket_id]
img_ar_errors.append(ar_error)
else:
reso = (self.width, self.height)
for image_info in self.image_data.values():
image_info.bucket_reso = reso
image_info.bucket_reso = bucket_resos[0] # bucket_resos contains (width, height) only
# 画像をbucketに分割する
self.buckets: list[str] = [[] for _ in range(len(bucket_resos))]
@@ -206,8 +204,8 @@ class BaseDataset(torch.utils.data.Dataset):
for _ in range(image_info.num_repeats):
self.buckets[bucket_index].append(image_info.image_key)
if enable_bucket:
print("number of images (including repeats for DreamBooth) / 各bucketの画像枚数DreamBoothの場合は繰り返し回数を含む)")
if self.enable_bucket:
print("number of images (including repeats) / 各bucketの画像枚数繰り返し回数を含む")
for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)):
print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")
img_ar_errors = np.array(img_ar_errors)
@@ -432,16 +430,27 @@ class BaseDataset(torch.utils.data.Dataset):
class DreamBoothDataset(BaseDataset):
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset)
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
self.batch_size = batch_size
self.size = min(self.width, self.height) # 短いほう
self.prior_loss_weight = prior_loss_weight
self.random_crop = random_crop
self.latents_cache = None
self.enable_bucket = False
self.enable_bucket = enable_bucket
if self.enable_bucket:
assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
(self.width, self.height), min_bucket_reso, max_bucket_reso)
else:
self.bucket_resos = [(self.width, self.height)]
self.bucket_aspect_ratios = [self.width / self.height]
def read_caption(img_path):
# captionの候補ファイル名を作る
@@ -532,9 +541,9 @@ class DreamBoothDataset(BaseDataset):
class FineTuningDataset(BaseDataset):
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug, color_aug, face_crop_aug_range, dataset_repeats, debug_dataset) -> None:
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, dataset_repeats, debug_dataset) -> None:
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset)
resolution, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, debug_dataset)
# メタデータを読み込む
if os.path.exists(json_file_name):
@@ -602,18 +611,35 @@ class FineTuningDataset(BaseDataset):
# check min/max bucket size
sizes = set()
resos = set()
for image_info in self.image_data.values():
if image_info.image_size is None:
sizes = None # not calculated
break
sizes.add(image_info.image_size[0])
sizes.add(image_info.image_size[1])
resos.add(image_info.image_size)
if sizes is None:
self.min_bucket_reso = self.max_bucket_reso = None # set as not calculated
assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
self.enable_bucket = enable_bucket
if self.enable_bucket:
assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
(self.width, self.height), min_bucket_reso, max_bucket_reso)
else:
self.bucket_resos = [(self.width, self.height)]
self.bucket_aspect_ratios = [self.width / self.height]
else:
self.min_bucket_reso = min(sizes)
self.max_bucket_reso = max(sizes)
if not enable_bucket:
print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
print("using bucket info in metadata / メタデータ内のbucket情報を使います")
self.enable_bucket = True
self.bucket_resos = list(resos)
self.bucket_resos.sort()
self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos]
def image_key_to_npz_file(self, image_key):
base_name = os.path.splitext(image_key)[0]
@@ -638,6 +664,28 @@ class FineTuningDataset(BaseDataset):
return npz_file_norm, npz_file_flip
def debug_dataset(train_dataset):
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("Escape for exit. / Escキーで中断、終了します")
k = 0
for example in train_dataset:
if example['latents'] is not None:
print("sample has latents from npz file")
for j, (ik, cap, lw) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'])):
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}')
if example['images'] is not None:
im = example['images'][j]
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
cv2.imshow("img", im)
k = cv2.waitKey()
cv2.destroyAllWindows()
if k == 27:
break
if k == 27 or example['images'] is None:
break
# endregion
@@ -908,3 +956,289 @@ def replace_unet_cross_attn_to_xformers():
diffusers.models.attention.CrossAttention.forward = forward_xformers
# endregion
# region utils
def add_sd_models_arguments(parser: argparse.ArgumentParser):
# for pretrained models
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
parser.add_argument("--v_parameterization", action='store_true',
help='enable v-parameterization training / v-parameterization学習を有効にする')
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
parser.add_argument("--output_dir", type=str, default=None,
help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt")
parser.add_argument("--save_every_n_epochs", type=int, default=None,
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
parser.add_argument("--save_state", action="store_true",
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長未指定で75、150または225が指定可")
parser.add_argument("--use_8bit_adam", action="store_true",
help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使うbitsandbytesのインストールが必要")
parser.add_argument("--mem_eff_attn", action="store_true",
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
parser.add_argument("--xformers", action="store_true",
help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument("--vae", type=str, default=None,
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument("--gradient_checkpointing", action="store_true",
help="enable gradient checkpointing / grandient checkpointingを有効にする")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数")
parser.add_argument("--mixed_precision", type=str, default="no",
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
parser.add_argument("--clip_skip", type=int, default=None,
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上")
parser.add_argument("--logging_dir", type=str, default=None,
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
parser.add_argument("--lr_scheduler", type=str, default="constant",
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
parser.add_argument("--lr_warmup_steps", type=int, default=0,
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数デフォルト0")
if support_dreambooth:
# DreamBooth training
parser.add_argument("--prior_loss_weight", type=float, default=1.0,
help="loss weight for regularization images / 正則化画像のlossの重み")
def verify_training_args(args: argparse.Namespace):
if args.v_parameterization and not args.v2:
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
if args.v2 and args.clip_skip is not None:
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool):
# dataset common
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--shuffle_caption", action="store_true",
help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
parser.add_argument("--keep_tokens", type=int, default=None,
help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
parser.add_argument("--face_crop_aug_range", type=str, default=None,
help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する2.0,4.0")
parser.add_argument("--random_crop", action="store_true",
help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする顔を中心としたaugmentationを行うときに画風の学習用に指定する")
parser.add_argument("--debug_dataset", action="store_true",
help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
parser.add_argument("--resolution", type=str, default=None,
help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)")
parser.add_argument("--cache_latents", action="store_true",
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheするaugmentationは使用不可")
parser.add_argument("--enable_bucket", action="store_true",
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
if support_dreambooth:
# DreamBooth dataset
parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
if support_caption:
# caption dataset
parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル")
parser.add_argument("--dataset_repeats", type=int, default=1,
help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数")
def prepare_dataset_args(args: argparse.Namespace, support_caption: bool):
if args.cache_latents:
assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
# assert args.resolution is not None, f"resolution is required / resolution解像度を指定してください"
if args.resolution is not None:
args.resolution = tuple([int(r) for r in args.resolution.split(',')])
if len(args.resolution) == 1:
args.resolution = (args.resolution[0], args.resolution[0])
assert len(args.resolution) == 2, \
f"resolution must be 'size' or 'width,height' / resolution解像度'サイズ'または'','高さ'で指定してください: {args.resolution}"
if args.face_crop_aug_range is not None:
args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
assert len(args.face_crop_aug_range) == 2, \
f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
else:
args.face_crop_aug_range = None
if support_caption:
if args.in_json is not None and args.color_aug:
print(f"latents in npz is ignored when color_aug is True / color_augを有効にした場合、npzファイルのlatentsは無視されます")
def load_tokenizer(args: argparse.Namespace):
print("prepare tokenizer")
if args.v2:
tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
else:
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
if args.max_token_length is not None:
print(f"update token length: {args.max_token_length}")
return tokenizer
def prepare_accelerator(args: argparse.Namespace):
if args.logging_dir is None:
log_with = None
logging_dir = None
else:
log_with = "tensorboard"
log_prefix = "" if args.log_prefix is None else args.log_prefix
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime())
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision,
log_with=log_with, logging_dir=logging_dir)
# accelerateの互換性問題を解決する
accelerator_0_15 = True
try:
accelerator.unwrap_model("dummy", True)
print("Using accelerator 0.15.0 or above.")
except TypeError:
accelerator_0_15 = False
def unwrap_model(model):
if accelerator_0_15:
return accelerator.unwrap_model(model, True)
return accelerator.unwrap_model(model)
return accelerator, unwrap_model
def prepare_dtype(args: argparse.Namespace):
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
save_dtype = None
if args.save_precision == "fp16":
save_dtype = torch.float16
elif args.save_precision == "bf16":
save_dtype = torch.bfloat16
elif args.save_precision == "float":
save_dtype = torch.float32
return weight_dtype, save_dtype
def load_target_model(args: argparse.Namespace, weight_dtype):
load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers
if load_stable_diffusion_format:
print("load StableDiffusion checkpoint")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
else:
print("load Diffusers pretrained models")
pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
text_encoder = pipe.text_encoder
vae = pipe.vae
unet = pipe.unet
del pipe
# VAEを読み込む
if args.vae is not None:
vae = model_util.load_vae(args.vae, weight_dtype)
print("additional VAE loaded")
return text_encoder, vae, unet, load_stable_diffusion_format
def patch_accelerator_for_fp16_training(accelerator):
org_unscale_grads = accelerator.scaler._unscale_grads_
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None):
b_size = input_ids.size()[0]
input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
if args.clip_skip is None:
encoder_hidden_states = text_encoder(input_ids)[0]
else:
enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
if weight_dtype is not None:
# this is required for additional network training
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
# bs*3, 77, 768 or 1024
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
if args.max_token_length is not None:
if args.v2:
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, args.max_token_length, tokenizer.model_max_length):
chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # <BOS> の後から 最後の前まで
if i > 0:
for j in range(len(chunk)):
if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
encoder_hidden_states = torch.cat(states_list, dim=1)
else:
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, args.max_token_length, tokenizer.model_max_length):
states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)
return encoder_hidden_states
def save_on_epoch_end(args: argparse.Namespace, accelerator, epoch: int, num_train_epochs: int, save_func):
if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
print("saving checkpoint.")
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, EPOCH_FILE_NAME.format(epoch + 1) + '.' + args.save_model_as)
save_func(ckpt_file)
if args.save_state:
print("saving state.")
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1)))
def save_last_state(args, accelerator):
print("saving last state.")
os.makedirs(args.output_dir, exist_ok=True)
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME))
def save_last_model(args, save_func):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, LAST_FILE_NAME + '.' + args.save_model_as)
print(f"save trained model to {ckpt_file}")
save_func(ckpt_file)
print("model saved.")
# endregion

View File

@@ -832,11 +832,6 @@ def train(args):
if os.path.exists(out_dir_old):
shutil.rmtree(out_dir_old)
if args.save_state:
print("saving state.")
accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1)))

View File

@@ -8,7 +8,6 @@ import os
from tqdm import tqdm
import torch
from accelerate import Accelerator
from accelerate.utils import set_seed
from transformers import CLIPTokenizer
import diffusers
@@ -26,165 +25,48 @@ def collate_fn(examples):
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
cache_latents = args.cache_latents
# latentsをキャッシュする場合のオプション設定を確認する
if cache_latents:
assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
# その他のオプション設定を確認する
if args.v_parameterization and not args.v2:
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
if args.v2 and args.clip_skip is not None:
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
use_dreambooth_method = args.in_json is None
# モデル形式のオプション設定を確認する:
load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
# 乱数系列を初期化する
if args.seed is not None:
set_seed(args.seed)
# tokenizerを読み込む
print("prepare tokenizer")
if args.v2:
tokenizer = CLIPTokenizer.from_pretrained(train_util. V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
else:
tokenizer = CLIPTokenizer.from_pretrained(train_util. TOKENIZER_PATH)
if args.max_token_length is not None:
print(f"update token length: {args.max_token_length}")
# 学習データを用意する
assert args.resolution is not None, f"resolution is required / resolution解像度を指定してください"
resolution = tuple([int(r) for r in args.resolution.split(',')])
if len(resolution) == 1:
resolution = (resolution[0], resolution[0])
assert len(resolution) == 2, \
f"resolution must be 'size' or 'width,height' / resolution解像度'サイズ'または'','高さ'で指定してください: {args.resolution}"
if args.face_crop_aug_range is not None:
face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
assert len(
face_crop_aug_range) == 2, f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
else:
face_crop_aug_range = None
tokenizer = train_util.load_tokenizer(args)
# データセットを準備する
if use_dreambooth_method:
print("Use DreamBooth method.")
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
resolution, args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop, args.debug_dataset)
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
else:
print("Train with captions.")
if args.color_aug:
print(f"latents in npz is ignored when color_aug is True / color_augを有効にした場合、npzファイルのlatentsは無視されます")
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
resolution, args.flip_aug, args.color_aug, face_crop_aug_range, args.dataset_repeats, args.debug_dataset)
if train_dataset.min_bucket_reso is not None and (args.enable_bucket or train_dataset.min_bucket_reso != train_dataset.max_bucket_reso):
print(f"using bucket info in metadata / メタデータ内のbucket情報を使います")
args.min_bucket_reso = train_dataset.min_bucket_reso
args.max_bucket_reso = train_dataset.max_bucket_reso
args.enable_bucket = True
print(f"min bucket reso: {args.min_bucket_reso}, max bucket reso: {args.max_bucket_reso}")
if args.enable_bucket:
assert min(resolution) >= args.min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
assert max(resolution) <= args.max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
train_dataset.make_buckets(args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso)
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset)
train_dataset.make_buckets()
if args.debug_dataset:
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("Escape for exit. / Escキーで中断、終了します")
k = 0
for example in train_dataset:
if example['latents'] is not None:
print("sample has latents from npz file")
for j, (ik, cap, lw) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'])):
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}')
if example['images'] is not None:
im = example['images'][j]
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
cv2.imshow("img", im)
k = cv2.waitKey()
cv2.destroyAllWindows()
if k == 27:
break
if k == 27 or example['images'] is None:
break
train_util.debug_dataset(train_dataset)
return
if len(train_dataset) == 0:
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
return
# acceleratorを準備する
print("prepare accelerator")
if args.logging_dir is None:
log_with = None
logging_dir = None
else:
log_with = "tensorboard"
log_prefix = "" if args.log_prefix is None else args.log_prefix
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime())
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision,
log_with=log_with, logging_dir=logging_dir)
# accelerateの互換性問題を解決する
accelerator_0_15 = True
try:
accelerator.unwrap_model("dummy", True)
print("Using accelerator 0.15.0 or above.")
except TypeError:
accelerator_0_15 = False
def unwrap_model(model):
if accelerator_0_15:
return accelerator.unwrap_model(model, True)
return accelerator.unwrap_model(model)
accelerator, unwrap_model = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
save_dtype = None
if args.save_precision == "fp16":
save_dtype = torch.float16
elif args.save_precision == "bf16":
save_dtype = torch.bfloat16
elif args.save_precision == "float":
save_dtype = torch.float32
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
if load_stable_diffusion_format:
print("load StableDiffusion checkpoint")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
else:
print("load Diffusers pretrained models")
pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
text_encoder = pipe.text_encoder
vae = pipe.vae
unet = pipe.unet
del pipe
# VAEを読み込む
if args.vae is not None:
vae = model_util.load_vae(args.vae, weight_dtype)
print("additional VAE loaded")
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -295,12 +177,7 @@ def train(args):
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
org_unscale_grads = accelerator.scaler._unscale_grads_
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
if args.resume is not None:
@@ -353,39 +230,7 @@ def train(args):
with torch.set_grad_enabled(train_text_encoder):
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
if args.clip_skip is None:
encoder_hidden_states = text_encoder(input_ids)[0]
else:
enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # なぜかこれが必要
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
# bs*3, 77, 768 or 1024
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
if args.max_token_length is not None:
if args.v2:
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, args.max_token_length, tokenizer.model_max_length):
chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # <BOS> の後から 最後の前まで
if i > 0:
for j in range(len(chunk)):
if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
encoder_hidden_states = torch.cat(states_list, dim=1)
else:
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, args.max_token_length, tokenizer.model_max_length):
states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
@@ -403,7 +248,6 @@ def train(args):
if args.v_parameterization:
# v-parameterization training
# Diffusers 0.10.0からv_parameterizationの学習に対応したのでそちらを使う
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
@@ -450,15 +294,9 @@ def train(args):
accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None:
if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
print("saving checkpoint.")
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, train_util.EPOCH_FILE_NAME.format(epoch + 1) + '.' + args.save_model_as)
unwrap_model(network).save_weights(ckpt_file, save_dtype)
if args.save_state:
print("saving state.")
accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1)))
def save_func(file):
unwrap_model(network).save_weights(file, save_dtype)
train_util.save_on_epoch_end(args, accelerator, epoch, num_train_epochs, save_func)
is_main_process = accelerator.is_main_process
if is_main_process:
@@ -467,103 +305,28 @@ def train(args):
accelerator.end_training()
if args.save_state:
print("saving last state.")
os.makedirs(args.output_dir, exist_ok=True)
accelerator.save_state(os.path.join(args.output_dir, train_util.LAST_STATE_NAME))
train_util.save_last_state(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, train_util.LAST_FILE_NAME + '.' + args.save_model_as)
print(f"save trained model to {ckpt_file}")
network.save_weights(ckpt_file, save_dtype)
print("model saved.")
def last_save_func(file):
network.save_weights(file, save_dtype)
train_util.save_last_model(args, last_save_func)
if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
parser.add_argument("--v_parameterization", action='store_true',
help='enable v-parameterization training / v-parameterization学習を有効にする')
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
parser.add_argument("--network_weights", type=str, default=None,
help="pretrained weights for network / 学習するネットワークの初期重み")
parser.add_argument("--shuffle_caption", action="store_true",
help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
parser.add_argument("--keep_tokens", type=int, default=None,
help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
parser.add_argument("--dataset_repeats", type=int, default=1,
help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数")
parser.add_argument("--output_dir", type=str, default=None,
help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt")
parser.add_argument("--save_every_n_epochs", type=int, default=None,
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
parser.add_argument("--save_state", action="store_true",
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
parser.add_argument("--face_crop_aug_range", type=str, default=None,
help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する2.0,4.0")
parser.add_argument("--random_crop", action="store_true",
help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする顔を中心としたaugmentationを行うときに画風の学習用に指定する")
parser.add_argument("--debug_dataset", action="store_true",
help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
parser.add_argument("--resolution", type=str, default=None,
help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)")
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長未指定で75、150または225が指定可")
parser.add_argument("--use_8bit_adam", action="store_true",
help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使うbitsandbytesのインストールが必要")
parser.add_argument("--mem_eff_attn", action="store_true",
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
parser.add_argument("--xformers", action="store_true",
help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument("--vae", type=str, default=None,
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
parser.add_argument("--cache_latents", action="store_true",
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheするaugmentationは使用不可")
parser.add_argument("--enable_bucket", action="store_true",
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True)
train_util.add_training_arguments(parser, True)
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
# parser.add_argument("--stop_text_encoder_training", type=int, default=None,
# help="steps to stop text encoder training / Text Encoderの学習を止めるステップ数")
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument("--gradient_checkpointing", action="store_true",
help="enable gradient checkpointing / grandient checkpointingを有効にする")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数")
parser.add_argument("--mixed_precision", type=str, default="no",
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
parser.add_argument("--clip_skip", type=int, default=None,
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上")
parser.add_argument("--logging_dir", type=str, default=None,
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
parser.add_argument("--lr_scheduler", type=str, default="constant",
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
parser.add_argument("--lr_warmup_steps", type=int, default=0,
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数デフォルト0")
parser.add_argument("--network_weights", type=str, default=None,
help="pretrained weights for network / 学習するネットワークの初期重み")
parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール')
parser.add_argument("--network_dim", type=int, default=None,
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')