From f56988b2529a21febe3d784b1fccf13bd0e7df27 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 5 Jan 2023 08:10:22 +0900 Subject: [PATCH] unify dataset and save functions --- fine_tune.py | 300 ++-------------- library/model_util.py | 8 - library/train_util.py | 153 ++++++-- train_db.py | 788 +++++------------------------------------- train_network.py | 54 ++- 5 files changed, 287 insertions(+), 1016 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 5da37b68..8b06abda 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -1,27 +1,17 @@ # training with captions -# XXX dropped option: fine_tune +# XXX dropped option: hypernetwork training import argparse import gc import math import os -import random -import json -import importlib -import time from tqdm import tqdm import torch -from accelerate import Accelerator from accelerate.utils import set_seed -from transformers import CLIPTokenizer import diffusers -from diffusers import DDPMScheduler, StableDiffusionPipeline -import numpy as np -from einops import rearrange -from torch import einsum +from diffusers import DDPMScheduler -import library.model_util as model_util import library.train_util as train_util @@ -29,211 +19,21 @@ def collate_fn(examples): return examples[0] -class FineTuningDataset(torch.utils.data.Dataset): - def __init__(self, metadata, train_data_dir, batch_size, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, dataset_repeats, debug) -> None: - super().__init__() - - self.metadata = metadata - self.train_data_dir = train_data_dir - self.batch_size = batch_size - self.tokenizer: CLIPTokenizer = tokenizer - self.max_token_length = max_token_length - self.shuffle_caption = shuffle_caption - self.shuffle_keep_tokens = shuffle_keep_tokens - self.debug = debug - - self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 - - print("make buckets") - - # 最初に数を数える - self.bucket_resos = set() - for img_md in metadata.values(): - if 'train_resolution' in img_md: - self.bucket_resos.add(tuple(img_md['train_resolution'])) - self.bucket_resos = list(self.bucket_resos) - self.bucket_resos.sort() - print(f"number of buckets: {len(self.bucket_resos)}") - - reso_to_index = {} - for i, reso in enumerate(self.bucket_resos): - reso_to_index[reso] = i - - # bucketに割り当てていく - self.buckets = [[] for _ in range(len(self.bucket_resos))] - n = 1 if dataset_repeats is None else dataset_repeats - images_count = 0 - for image_key, img_md in metadata.items(): - if 'train_resolution' not in img_md: - continue - if not os.path.exists(self.image_key_to_npz_file(image_key)): - continue - - reso = tuple(img_md['train_resolution']) - for _ in range(n): - self.buckets[reso_to_index[reso]].append(image_key) - images_count += n - - # 参照用indexを作る - self.buckets_indices = [] - for bucket_index, bucket in enumerate(self.buckets): - batch_count = int(math.ceil(len(bucket) / self.batch_size)) - for batch_index in range(batch_count): - self.buckets_indices.append((bucket_index, batch_index)) - - self.shuffle_buckets() - self._length = len(self.buckets_indices) - self.images_count = images_count - - def show_buckets(self): - for i, (reso, bucket) in enumerate(zip(self.bucket_resos, self.buckets)): - print(f"bucket {i}: resolution {reso}, count: {len(bucket)}") - - def shuffle_buckets(self): - random.shuffle(self.buckets_indices) - for bucket in self.buckets: - random.shuffle(bucket) - - def image_key_to_npz_file(self, image_key): - npz_file_norm = os.path.splitext(image_key)[0] + '.npz' - if os.path.exists(npz_file_norm): - if random.random() < .5: - npz_file_flip = os.path.splitext(image_key)[0] + '_flip.npz' - if os.path.exists(npz_file_flip): - return npz_file_flip - return npz_file_norm - - npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz') - if random.random() < .5: - npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz') - if os.path.exists(npz_file_flip): - return npz_file_flip - return npz_file_norm - - def load_latent(self, image_key): - return np.load(self.image_key_to_npz_file(image_key))['arr_0'] - - def __len__(self): - return self._length - - def __getitem__(self, index): - if index == 0: - self.shuffle_buckets() - - bucket = self.buckets[self.buckets_indices[index][0]] - image_index = self.buckets_indices[index][1] * self.batch_size - - input_ids_list = [] - latents_list = [] - captions = [] - for image_key in bucket[image_index:image_index + self.batch_size]: - img_md = self.metadata[image_key] - caption = img_md.get('caption') - tags = img_md.get('tags') - - if caption is None: - caption = tags - elif tags is not None and len(tags) > 0: - caption = caption + ', ' + tags - assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{image_key}" - - latents = self.load_latent(image_key) - - if self.shuffle_caption: - tokens = caption.strip().split(",") - if self.shuffle_keep_tokens is None: - random.shuffle(tokens) - else: - if len(tokens) > self.shuffle_keep_tokens: - keep_tokens = tokens[:self.shuffle_keep_tokens] - tokens = tokens[self.shuffle_keep_tokens:] - random.shuffle(tokens) - tokens = keep_tokens + tokens - caption = ",".join(tokens).strip() - - captions.append(caption) - - input_ids = self.tokenizer(caption, padding="max_length", truncation=True, - max_length=self.tokenizer_max_length, return_tensors="pt").input_ids - - if self.tokenizer_max_length > self.tokenizer.model_max_length: - input_ids = input_ids.squeeze(0) - iids_list = [] - if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: - # v1 - # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する - # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に - for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75) - ids_chunk = (input_ids[0].unsqueeze(0), - input_ids[i:i + self.tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0)) - ids_chunk = torch.cat(ids_chunk) - iids_list.append(ids_chunk) - else: - # v2 - # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する - for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): - ids_chunk = (input_ids[0].unsqueeze(0), # BOS - input_ids[i:i + self.tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0)) # PAD or EOS - ids_chunk = torch.cat(ids_chunk) - - # 末尾が または の場合は、何もしなくてよい - # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) - if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: - ids_chunk[-1] = self.tokenizer.eos_token_id - # 先頭が ... の場合は ... に変える - if ids_chunk[1] == self.tokenizer.pad_token_id: - ids_chunk[1] = self.tokenizer.eos_token_id - - iids_list.append(ids_chunk) - - input_ids = torch.stack(iids_list) # 3,77 - - input_ids_list.append(input_ids) - latents_list.append(torch.FloatTensor(latents)) - - example = {} - example['input_ids'] = torch.stack(input_ids_list) - example['latents'] = torch.stack(latents_list) - if self.debug: - example['image_keys'] = bucket[image_index:image_index + self.batch_size] - example['captions'] = captions - return example - - def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) cache_latents = args.cache_latents - # verify load/save model formats - load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) - - if load_stable_diffusion_format: - src_stable_diffusion_ckpt = args.pretrained_model_name_or_path - src_diffusers_model_path = None - else: - src_stable_diffusion_ckpt = None - src_diffusers_model_path = args.pretrained_model_name_or_path - - if args.save_model_as is None: - save_stable_diffusion_format = load_stable_diffusion_format - use_safetensors = args.use_safetensors - else: - save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors' - use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する tokenizer = train_util.load_tokenizer(args) - train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, - tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, - args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset) + train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, + tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, + args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, + args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset) train_dataset.make_buckets() if args.debug_dataset: @@ -253,6 +53,21 @@ def train(args): # モデルを読み込む text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) + # verify load/save model formats + if load_stable_diffusion_format: + src_stable_diffusion_ckpt = args.pretrained_model_name_or_path + src_diffusers_model_path = None + else: + src_stable_diffusion_ckpt = None + src_diffusers_model_path = args.pretrained_model_name_or_path + + if args.save_model_as is None: + save_stable_diffusion_format = load_stable_diffusion_format + use_safetensors = args.use_safetensors + else: + save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors' + use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) + # Diffusers版のxformers使用フラグを設定する関数 def set_diffusers_xformers_flag(model, valid): # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう @@ -308,7 +123,11 @@ def train(args): else: text_encoder.to(accelerator.device, dtype=weight_dtype) text_encoder.requires_grad_(False) # text encoderは学習しない - text_encoder.eval() + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + text_encoder.train() # required for gradient_checkpointing + else: + text_encoder.eval() if not cache_latents: vae.requires_grad_(False) @@ -365,12 +184,7 @@ def train(args): # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: - org_unscale_grads = accelerator.scaler._unscale_grads_ - - def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): - return org_unscale_grads(optimizer, inv_scale, found_inf, True) - - accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + train_util.patch_accelerator_for_fp16_training(accelerator) # resumeする if args.resume is not None: @@ -413,7 +227,6 @@ def train(args): latents = latents * 0.18215 b_size = latents.shape[0] - # with torch.no_grad(): with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning input_ids = batch["input_ids"].to(accelerator.device) @@ -435,7 +248,6 @@ def train(args): if args.v_parameterization: # v-parameterization training - # Diffusers 0.10.0からv_parameterizationの学習に対応したのでそちらを使う target = noise_scheduler.get_velocity(latents, noise, timesteps) else: target = noise @@ -478,63 +290,26 @@ def train(args): accelerator.wait_for_everyone() if args.save_every_n_epochs is not None: - def save_func(file): - model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet), - src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors) - train_util.save_on_epoch_end(args, accelerator, epoch, num_train_epochs, save_func) - if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: - print("saving checkpoint.") - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1)) - - if save_stable_diffusion_format: - model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), - src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae) - else: - out_dir = os.path.join(args.output_dir, train_util.EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) - os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet), - src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors) - if args.save_state: - print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1))) + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors, + save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae) is_main_process = accelerator.is_main_process if is_main_process: - if fine_tuning: - unet = unwrap_model(unet) - text_encoder = unwrap_model(text_encoder) - else: - hypernetwork = unwrap_model(hypernetwork) + unet = unwrap_model(unet) + text_encoder = unwrap_model(text_encoder) accelerator.end_training() if args.save_state: - print("saving last state.") - accelerator.save_state(os.path.join(args.output_dir, train_util.LAST_STATE_NAME)) + train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す if is_main_process: - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, model_util.get_last_ckpt_name(use_safetensors)) - - if fine_tuning: - if save_stable_diffusion_format: - print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") - model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, - src_stable_diffusion_ckpt, epoch, global_step, save_dtype, vae) - else: - # Create the pipeline using using the trained modules and save it. - print(f"save trained model as Diffusers to {args.output_dir}") - out_dir = os.path.join(args.output_dir, train_util.LAST_DIFFUSERS_DIR_NAME) - os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, - src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors) - else: - print(f"save trained model to {ckpt_file}") - save_hypernetwork(ckpt_file, hypernetwork) - + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors, + save_dtype, epoch, global_step, text_encoder, unet, vae) print("model saved.") @@ -544,9 +319,8 @@ if __name__ == '__main__': train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True) train_util.add_training_arguments(parser, False) + train_util.add_sd_saving_arguments(parser) - parser.add_argument("--use_safetensors", action='store_true', - help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)") parser.add_argument("--diffusers_xformers", action='store_true', help='use xformers by diffusers / Diffusersでxformersを使用する') diff --git a/library/model_util.py b/library/model_util.py index 398b6404..bc824a12 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -1133,14 +1133,6 @@ def load_vae(vae_id, dtype): return vae -def get_epoch_ckpt_name(use_safetensors, epoch): - return f"epoch-{epoch:06d}" + (".safetensors" if use_safetensors else ".ckpt") - - -def get_last_ckpt_name(use_safetensors): - return f"last" + (".safetensors" if use_safetensors else ".ckpt") - - # endregion diff --git a/library/train_util.py b/library/train_util.py index 8eedf48c..5033a55b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1,7 +1,9 @@ # common functions for training +# TODO test no_token_padding option import argparse import json +import shutil import time from typing import NamedTuple from accelerate import Accelerator @@ -31,18 +33,16 @@ TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ # checkpointファイル名 -EPOCH_STATE_NAME = "epoch-{:06d}-state" -LAST_STATE_NAME = "last-state" - -EPOCH_FILE_NAME = "epoch-{:06d}" -LAST_FILE_NAME = "last" - -LAST_DIFFUSERS_DIR_NAME = "last" -EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}" - +EPOCH_STATE_NAME = "{}-{:06d}-state" +EPOCH_FILE_NAME = "{}-{:06d}" +EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}" +LAST_STATE_NAME = "{}-state" +DEFAULT_EPOCH_NAME = "epoch" +DEFAULT_LAST_OUTPUT_NAME = "last" # region dataset + class ImageInfo(): def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: self.image_key: str = image_key @@ -76,6 +76,7 @@ class BaseDataset(torch.utils.data.Dataset): self.flip_aug = flip_aug self.color_aug = color_aug self.debug_dataset = debug_dataset + self.padding_disabled = False self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 @@ -101,6 +102,9 @@ class BaseDataset(torch.utils.data.Dataset): self.image_data: dict[str, ImageInfo] = {} + def disable_padding(self): + self.padding_disabled = True + def process_caption(self, caption): if self.shuffle_caption: tokens = caption.strip().split(",") @@ -408,11 +412,18 @@ class BaseDataset(torch.utils.data.Dataset): caption = self.process_caption(image_info.caption) captions.append(caption) - input_ids_list.append(self.get_input_ids(caption)) + if not self.padding_disabled: # this option might be omitted in future + input_ids_list.append(self.get_input_ids(caption)) example = {} example['loss_weights'] = torch.FloatTensor(loss_weights) - example['input_ids'] = torch.stack(input_ids_list) + + if self.padding_disabled: + # padding=True means pad in the batch + example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids + else: + # batch processing seems to be good + example['input_ids'] = torch.stack(input_ids_list) if images[0] is not None: images = torch.stack(images) @@ -664,6 +675,7 @@ class FineTuningDataset(BaseDataset): return npz_file_norm, npz_file_flip + def debug_dataset(train_dataset): print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") print("Escape for exit. / Escキーで中断、終了します") @@ -973,12 +985,13 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ") + parser.add_argument("--output_name", type=str, default=None, + help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名") parser.add_argument("--save_precision", type=str, default=None, choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する") - parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)") parser.add_argument("--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") + parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") parser.add_argument("--save_state", action="store_true", help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") @@ -1034,6 +1047,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b parser.add_argument("--shuffle_caption", action="store_true", help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする") parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子") + parser.add_argument("--caption_extention", type=str, default=None, + help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)") parser.add_argument("--keep_tokens", type=int, default=None, help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す") parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") @@ -1064,7 +1079,19 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数") -def prepare_dataset_args(args: argparse.Namespace, support_caption: bool): +def add_sd_saving_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"], + help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)") + parser.add_argument("--use_safetensors", action='store_true', + help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)") + + +def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): + # backward compatibility + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + args.caption_extention = None + if args.cache_latents: assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません" @@ -1083,7 +1110,7 @@ def prepare_dataset_args(args: argparse.Namespace, support_caption: bool): else: args.face_crop_aug_range = None - if support_caption: + if support_metadata: if args.in_json is not None and args.color_aug: print(f"latents in npz is ignored when color_aug is True / color_augを有効にした場合、npzファイルのlatentsは無視されます") @@ -1216,29 +1243,95 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod return encoder_hidden_states -def save_on_epoch_end(args: argparse.Namespace, accelerator, epoch: int, num_train_epochs: int, save_func): - if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: +def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch): + model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt") + return model_name, ckpt_name + + +def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int): + saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs + remove_epoch_no = None + if saving: print("saving checkpoint.") os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, EPOCH_FILE_NAME.format(epoch + 1) + '.' + args.save_model_as) - save_func(ckpt_file) + save_func() - if args.save_state: - print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) + if args.save_last_n_epochs is not None: + remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs + remove_old_func(remove_epoch_no) + return saving, remove_epoch_no -def save_last_state(args, accelerator): +def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae): + epoch_no = epoch + 1 + model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no) + + if save_stable_diffusion_format: + def save_sd(): + ckpt_file = os.path.join(args.output_dir, ckpt_name) + model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, + src_path, epoch_no, global_step, save_dtype, vae) + + def remove_sd(old_epoch_no): + _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no) + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + os.remove(old_ckpt_file) + + save_func = save_sd + remove_old_func = remove_sd + else: + def save_du(): + out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no)) + os.makedirs(out_dir, exist_ok=True) + model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, + src_path, vae=vae, use_safetensors=use_safetensors) + + def remove_du(old_epoch_no): + out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no)) + if os.path.exists(out_dir_old): + shutil.rmtree(out_dir_old) + + save_func = save_du + remove_old_func = remove_du + + saving, remove_epoch_no = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) + if saving and args.save_state: + print("saving state.") + accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))) + if remove_epoch_no is not None: + state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) + if os.path.exists(state_dir_old): + shutil.rmtree(state_dir_old) + + +def save_state_on_train_end(args: argparse.Namespace, accelerator): print("saving last state.") os.makedirs(args.output_dir, exist_ok=True) - accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME)) + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))) -def save_last_model(args, save_func): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, LAST_FILE_NAME + '.' + args.save_model_as) - print(f"save trained model to {ckpt_file}") - save_func(ckpt_file) - print("model saved.") +def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae): + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + + if save_stable_diffusion_format: + os.makedirs(args.output_dir, exist_ok=True) + + ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt") + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") + model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, + src_path, epoch, global_step, save_dtype, vae) + else: + print(f"save trained model as Diffusers to {args.output_dir}") + + out_dir = os.path.join(args.output_dir, model_name) + os.makedirs(out_dir, exist_ok=True) + + model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, + src_path, vae=vae, use_safetensors=use_safetensors) # endregion diff --git a/train_db.py b/train_db.py index 4a30d2d5..a03ab563 100644 --- a/train_db.py +++ b/train_db.py @@ -1,4 +1,5 @@ # DreamBooth training +# XXX dropped option: fine_tune import gc import time @@ -31,364 +32,49 @@ import library.train_util as train_util from library.train_util import DreamBoothDataset, FineTuningDataset -# region dataset - -class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset): - def __init__(self, batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, shuffle_caption, disable_padding, debug_dataset) -> None: - super().__init__() - - self.batch_size = batch_size - self.fine_tuning = fine_tuning - self.train_img_path_captions = train_img_path_captions - self.reg_img_path_captions = reg_img_path_captions - self.tokenizer = tokenizer - self.width, self.height = resolution - self.size = min(self.width, self.height) # 短いほう - self.prior_loss_weight = prior_loss_weight - self.face_crop_aug_range = face_crop_aug_range - self.random_crop = random_crop - self.debug_dataset = debug_dataset - self.shuffle_caption = shuffle_caption - self.disable_padding = disable_padding - self.latents_cache = None - self.enable_bucket = False - - # augmentation - flip_p = 0.5 if flip_aug else 0.0 - if color_aug: - # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hue/saturationあたりを触る - self.aug = albu.Compose([ - albu.OneOf([ - # albu.RandomBrightnessContrast(0.05, 0.05, p=.2), - albu.HueSaturationValue(5, 8, 0, p=.2), - # albu.RGBShift(5, 5, 5, p=.1), - albu.RandomGamma((95, 105), p=.5), - ], p=.33), - albu.HorizontalFlip(p=flip_p) - ], p=1.) - elif flip_aug: - self.aug = albu.Compose([ - albu.HorizontalFlip(p=flip_p) - ], p=1.) - else: - self.aug = None - - self.num_train_images = len(self.train_img_path_captions) - self.num_reg_images = len(self.reg_img_path_captions) - - self.enable_reg_images = self.num_reg_images > 0 - - if self.enable_reg_images and self.num_train_images < self.num_reg_images: - print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") - - self.image_transforms = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) - - # bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) - def make_buckets_with_caching(self, enable_bucket, vae, min_size, max_size): - self.enable_bucket = enable_bucket - - cache_latents = vae is not None - if cache_latents: - if enable_bucket: - print("cache latents with bucketing") - else: - print("cache latents") - else: - if enable_bucket: - print("make buckets") - else: - print("prepare dataset") - - # bucketingを用意する - if enable_bucket: - bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions((self.width, self.height), min_size, max_size) - else: - # bucketはひとつだけ、すべての画像は同じ解像度 - bucket_resos = [(self.width, self.height)] - bucket_aspect_ratios = [self.width / self.height] - bucket_aspect_ratios = np.array(bucket_aspect_ratios) - - # 画像の解像度、latentをあらかじめ取得する - img_ar_errors = [] - self.size_lat_cache = {} - for image_path, _ in tqdm(self.train_img_path_captions + self.reg_img_path_captions): - if image_path in self.size_lat_cache: - continue - - image = self.load_image(image_path)[0] - image_height, image_width = image.shape[0:2] - - if not enable_bucket: - # assert image_width == self.width and image_height == self.height, \ - # f"all images must have specific resolution when bucketing is disabled / bucketを使わない場合、すべての画像のサイズを統一してください: {image_path}" - reso = (self.width, self.height) - else: - # bucketを決める - aspect_ratio = image_width / image_height - ar_errors = bucket_aspect_ratios - aspect_ratio - bucket_id = np.abs(ar_errors).argmin() - reso = bucket_resos[bucket_id] - ar_error = ar_errors[bucket_id] - img_ar_errors.append(ar_error) - - if cache_latents: - image = self.resize_and_trim(image, reso) - - # latentを取得する - if cache_latents: - img_tensor = self.image_transforms(image) - img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) - latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") - else: - latents = None - - self.size_lat_cache[image_path] = (reso, latents) - - # 画像をbucketに分割する - self.buckets = [[] for _ in range(len(bucket_resos))] - reso_to_index = {} - for i, reso in enumerate(bucket_resos): - reso_to_index[reso] = i - - def split_to_buckets(is_reg, img_path_captions): - for image_path, caption in img_path_captions: - reso, _ = self.size_lat_cache[image_path] - bucket_index = reso_to_index[reso] - self.buckets[bucket_index].append((is_reg, image_path, caption)) - - split_to_buckets(False, self.train_img_path_captions) - - if self.enable_reg_images: - l = [] - while len(l) < len(self.train_img_path_captions): - l += self.reg_img_path_captions - l = l[:len(self.train_img_path_captions)] - split_to_buckets(True, l) - - if enable_bucket: - print("number of images with repeats / 繰り返し回数込みの各bucketの画像枚数") - for i, (reso, imgs) in enumerate(zip(bucket_resos, self.buckets)): - print(f"bucket {i}: resolution {reso}, count: {len(imgs)}") - img_ar_errors = np.array(img_ar_errors) - print(f"mean ar error: {np.mean(np.abs(img_ar_errors))}") - - # 参照用indexを作る - self.buckets_indices = [] - for bucket_index, bucket in enumerate(self.buckets): - batch_count = int(math.ceil(len(bucket) / self.batch_size)) - for batch_index in range(batch_count): - self.buckets_indices.append((bucket_index, batch_index)) - - self.shuffle_buckets() - self._length = len(self.buckets_indices) - - # どのサイズにリサイズするか→トリミングする方向で - def resize_and_trim(self, image, reso): - image_height, image_width = image.shape[0:2] - ar_img = image_width / image_height - ar_reso = reso[0] / reso[1] - if ar_img > ar_reso: # 横が長い→縦を合わせる - scale = reso[1] / image_height - else: - scale = reso[0] / image_width - resized_size = (int(image_width * scale + .5), int(image_height * scale + .5)) - - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ - if resized_size[0] > reso[0]: - trim_size = resized_size[0] - reso[0] - image = image[:, trim_size//2:trim_size//2 + reso[0]] - elif resized_size[1] > reso[1]: - trim_size = resized_size[1] - reso[1] - image = image[trim_size//2:trim_size//2 + reso[1]] - assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \ - f"internal error, illegal trimmed size: {image.shape}, {reso}" - return image - - def shuffle_buckets(self): - random.shuffle(self.buckets_indices) - for bucket in self.buckets: - random.shuffle(bucket) - - def load_image(self, image_path): - image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") - img = np.array(image, np.uint8) - - face_cx = face_cy = face_w = face_h = 0 - if self.face_crop_aug_range is not None: - tokens = os.path.splitext(os.path.basename(image_path))[0].split('_') - if len(tokens) >= 5: - face_cx = int(tokens[-4]) - face_cy = int(tokens[-3]) - face_w = int(tokens[-2]) - face_h = int(tokens[-1]) - - return img, face_cx, face_cy, face_w, face_h - - # いい感じに切り出す - def crop_target(self, image, face_cx, face_cy, face_w, face_h): - height, width = image.shape[0:2] - if height == self.height and width == self.width: - return image - - # 画像サイズはsizeより大きいのでリサイズする - face_size = max(face_w, face_h) - min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) - min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ - max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ - if min_scale >= max_scale: # range指定がmin==max - scale = min_scale - else: - scale = random.uniform(min_scale, max_scale) - - nh = int(height * scale + .5) - nw = int(width * scale + .5) - assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" - image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) - face_cx = int(face_cx * scale + .5) - face_cy = int(face_cy * scale + .5) - height, width = nh, nw - - # 顔を中心として448*640とかへを切り出す - for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): - p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 - - if self.random_crop: - # 背景も含めるために顔を中心に置く確率を高めつつずらす - range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう - p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数 - else: - # range指定があるときのみ、すこしだけランダムに(わりと適当) - if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]: - if face_size > self.size // 10 and face_size >= 40: - p1 = p1 + random.randint(-face_size // 20, +face_size // 20) - - p1 = max(0, min(p1, length - target_size)) - - if axis == 0: - image = image[p1:p1 + target_size, :] - else: - image = image[:, p1:p1 + target_size] - - return image - - def __len__(self): - return self._length - - def __getitem__(self, index): - if index == 0: - self.shuffle_buckets() - - bucket = self.buckets[self.buckets_indices[index][0]] - image_index = self.buckets_indices[index][1] * self.batch_size - - latents_list = [] - images = [] - captions = [] - loss_weights = [] - - for is_reg, image_path, caption in bucket[image_index:image_index + self.batch_size]: - loss_weights.append(self.prior_loss_weight if is_reg else 1.0) - - # image/latentsを処理する - reso, latents = self.size_lat_cache[image_path] - - if latents is None: - # 画像を読み込み必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image(image_path) - im_h, im_w = img.shape[0:2] - - if self.enable_bucket: - img = self.resize_and_trim(img, reso) - else: - if face_cx > 0: # 顔位置情報あり - img = self.crop_target(img, face_cx, face_cy, face_w, face_h) - elif im_h > self.height or im_w > self.width: - assert self.random_crop, f"image too large, and face_crop_aug_range and random_crop are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_cropを有効にしてください" - if im_h > self.height: - p = random.randint(0, im_h - self.height) - img = img[p:p + self.height] - if im_w > self.width: - p = random.randint(0, im_w - self.width) - img = img[:, p:p + self.width] - - im_h, im_w = img.shape[0:2] - assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_path}" - - # augmentation - if self.aug is not None: - img = self.aug(image=img)['image'] - - image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる - else: - image = None - - images.append(image) - latents_list.append(latents) - - # captionを処理する - if self.shuffle_caption: # captionのshuffleをする - tokens = caption.strip().split(",") - random.shuffle(tokens) - caption = ",".join(tokens).strip() - captions.append(caption) - - # input_idsをpadしてTensor変換 - if self.disable_padding: - # paddingしない:padding==Trueはバッチの中の最大長に合わせるだけ(やはりバグでは……?) - input_ids = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids - else: - # paddingする - input_ids = self.tokenizer(captions, padding='max_length', truncation=True, return_tensors='pt').input_ids - - example = {} - example['loss_weights'] = torch.FloatTensor(loss_weights) - example['input_ids'] = input_ids - if images[0] is not None: - images = torch.stack(images) - images = images.to(memory_format=torch.contiguous_format).float() - else: - images = None - example['images'] = images - example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None - if self.debug_dataset: - example['image_paths'] = [image_path for _, image_path, _ in bucket[image_index:image_index + self.batch_size]] - example['captions'] = captions - return example -# endregion - - def collate_fn(examples): return examples[0] def train(args): - if args.caption_extention is not None: - args.caption_extension = args.caption_extention - args.caption_extention = None + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, False) - fine_tuning = args.fine_tuning cache_latents = args.cache_latents - # latentsをキャッシュする場合のオプション設定を確認する - if cache_latents: - assert not args.flip_aug and not args.color_aug, "when caching latents, augmentation cannot be used / latentをキャッシュするときはaugmentationは使えません" + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する - # その他のオプション設定を確認する - if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") - if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + tokenizer = train_util.load_tokenizer(args) - # モデル形式のオプション設定を確認する: - load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) + train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir, + tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens, + args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight, + args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) + if args.no_token_padding: + train_dataset.disable_padding() + train_dataset.make_buckets() + if args.debug_dataset: + train_util.debug_dataset(train_dataset) + + # acceleratorを準備する + print("prepare accelerator") + + if args.gradient_accumulation_steps > 1: + print(f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong") + print( + f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です") + + accelerator, unwrap_model = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) + + # verify load/save model formats if load_stable_diffusion_format: src_stable_diffusion_ckpt = args.pretrained_model_name_or_path src_diffusers_model_path = None @@ -403,202 +89,6 @@ def train(args): save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors' use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - # 乱数系列を初期化する - if args.seed is not None: - set_seed(args.seed) - - # 学習データを用意する - def read_caption(img_path): - # captionの候補ファイル名を作る - base_name = os.path.splitext(img_path)[0] - base_name_face_det = base_name - tokens = base_name.split("_") - if len(tokens) >= 5: - base_name_face_det = "_".join(tokens[:-4]) - cap_paths = [base_name + args.caption_extension, base_name_face_det + args.caption_extension] - - caption = None - for cap_path in cap_paths: - if os.path.isfile(cap_path): - with open(cap_path, "rt", encoding='utf-8') as f: - lines = f.readlines() - assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" - caption = lines[0].strip() - break - return caption - - def load_dreambooth_dir(dir): - tokens = os.path.basename(dir).split('_') - try: - n_repeats = int(tokens[0]) - except ValueError as e: - return 0, [] - - caption_by_folder = '_'.join(tokens[1:]) - - print(f"found directory {n_repeats}_{caption_by_folder}") - - img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \ - glob.glob(os.path.join(dir, "*.webp")) - - # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う(v11から仕様変更した) - captions = [] - for img_path in img_paths: - cap_for_img = read_caption(img_path) - captions.append(caption_by_folder if cap_for_img is None else cap_for_img) - - return n_repeats, list(zip(img_paths, captions)) - - print("prepare train images.") - train_img_path_captions = [] - - if fine_tuning: - img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + \ - glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) - for img_path in tqdm(img_paths): - caption = read_caption(img_path) - assert caption is not None and len( - caption) > 0, f"no caption for image. check caption_extension option / キャプションファイルが見つからないかcaptionが空です。caption_extensionオプションを確認してください: {img_path}" - - train_img_path_captions.append((img_path, caption)) - - if args.dataset_repeats is not None: - l = [] - for _ in range(args.dataset_repeats): - l.extend(train_img_path_captions) - train_img_path_captions = l - else: - train_dirs = os.listdir(args.train_data_dir) - for dir in train_dirs: - n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.train_data_dir, dir)) - for _ in range(n_repeats): - train_img_path_captions.extend(img_caps) - print(f"{len(train_img_path_captions)} train images with repeating.") - - reg_img_path_captions = [] - if args.reg_data_dir: - print("prepare reg images.") - reg_dirs = os.listdir(args.reg_data_dir) - for dir in reg_dirs: - n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.reg_data_dir, dir)) - for _ in range(n_repeats): - reg_img_path_captions.extend(img_caps) - print(f"{len(reg_img_path_captions)} reg images.") - - # データセットを準備する - resolution = tuple([int(r) for r in args.resolution.split(',')]) - if len(resolution) == 1: - resolution = (resolution[0], resolution[0]) - assert len(resolution) == 2, \ - f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'幅','高さ'で指定してください: {args.resolution}" - - if args.enable_bucket: - assert min(resolution) >= args.min_bucket_reso, f"min_bucket_reso must be equal or greater than resolution / min_bucket_resoは解像度の数値以上で指定してください" - assert max(resolution) <= args.max_bucket_reso, f"max_bucket_reso must be equal or less than resolution / max_bucket_resoは解像度の数値以下で指定してください" - - if args.face_crop_aug_range is not None: - face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')]) - assert len( - face_crop_aug_range) == 2, f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}" - else: - face_crop_aug_range = None - - # tokenizerを読み込む - print("prepare tokenizer") - if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(train_util.V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") - else: - tokenizer = CLIPTokenizer.from_pretrained(train_util.TOKENIZER_PATH) - - print("prepare dataset") - train_dataset = DreamBoothOrFineTuningDataset(args.train_batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, - args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop, - args.shuffle_caption, args.no_token_padding, args.debug_dataset) - - if args.debug_dataset: - train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso, - args.max_bucket_reso) # デバッグ用にcacheなしで作る - print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - print("Escape for exit. / Escキーで中断、終了します") - for example in train_dataset: - for im, cap, lw in zip(example['images'], example['captions'], example['loss_weights']): - im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) - im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c - im = im[:, :, ::-1] # RGB -> BGR (OpenCV) - print(f'size: {im.shape[1]}*{im.shape[0]}, caption: "{cap}", loss weight: {lw}') - cv2.imshow("img", im) - k = cv2.waitKey() - cv2.destroyAllWindows() - if k == 27: - break - if k == 27: - break - return - - # acceleratorを準備する - # gradient accumulationは複数モデルを学習する場合には対応していないとのことなので、1固定にする - print("prepare accelerator") - if args.logging_dir is None: - log_with = None - logging_dir = None - else: - log_with = "tensorboard" - log_prefix = "" if args.log_prefix is None else args.log_prefix - logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime()) - accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision, - log_with=log_with, logging_dir=logging_dir) - - # accelerateの互換性問題を解決する - accelerator_0_15 = True - try: - accelerator.unwrap_model("dummy", True) - print("Using accelerator 0.15.0 or above.") - except TypeError: - accelerator_0_15 = False - - def unwrap_model(model): - if accelerator_0_15: - return accelerator.unwrap_model(model, True) - return accelerator.unwrap_model(model) - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - save_dtype = None - if args.save_precision == "fp16": - save_dtype = torch.float16 - elif args.save_precision == "bf16": - save_dtype = torch.bfloat16 - elif args.save_precision == "float": - save_dtype = torch.float32 - - # モデルを読み込む - if load_stable_diffusion_format: - print("load StableDiffusion checkpoint") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path) - else: - print("load Diffusers pretrained models") - pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None) - # , torch_dtype=weight_dtype) ここでtorch_dtypeを指定すると学習時にエラーになる - text_encoder = pipe.text_encoder - vae = pipe.vae - unet = pipe.unet - del pipe - - # # 置換するCLIPを読み込む - # if args.replace_clip_l14_336: - # text_encoder = load_clip_l14_336(weight_dtype) - # print(f"large clip {CLIP_ID_L14_336} is loaded") - - # VAEを読み込む - if args.vae is not None: - vae = model_util.load_vae(args.vae, weight_dtype) - print("additional VAE loaded") - # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) @@ -608,23 +98,29 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset.make_buckets_with_caching(args.enable_bucket, vae, args.min_bucket_reso, args.max_bucket_reso) + train_dataset.cache_latents(vae) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() - else: - train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso, args.max_bucket_reso) - vae.requires_grad_(False) - vae.eval() + # 学習を準備する:モデルを適切な状態にする + if args.stop_text_encoder_training is None: + args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end + + train_text_encoder = args.stop_text_encoder_training >= 0 unet.requires_grad_(True) # 念のため追加 - text_encoder.requires_grad_(True) + text_encoder.requires_grad_(train_text_encoder) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable() + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + # 学習に必要なクラスを準備する print("prepare optimizer, data loader etc.") @@ -639,7 +135,10 @@ def train(args): else: optimizer_class = torch.optim.AdamW - trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters())) + if train_text_encoder: + trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters())) + else: + trainable_params = unet.parameters() # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 optimizer = optimizer_class(trainable_params, lr=args.learning_rate) @@ -662,20 +161,15 @@ def train(args): text_encoder.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler) - - if not cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) + if train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: - org_unscale_grads = accelerator.scaler._unscale_grads_ - - def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): - return org_unscale_grads(optimizer, inv_scale, found_inf, True) - - accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + train_util.patch_accelerator_for_fp16_training(accelerator) # resumeする if args.resume is not None: @@ -683,7 +177,8 @@ def train(args): accelerator.load_state(args.resume) # epoch数を計算する - num_train_epochs = math.ceil(args.max_train_steps / len(train_dataloader)) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) # 学習する total_batch_size = args.train_batch_size # * accelerator.num_processes @@ -700,33 +195,28 @@ def train(args): progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 - # v12で更新:clip_sample=Falseに - # Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる - # 既存の1.4/1.5/2.0/2.1はすべてschedulerのconfigは(クラス名を除いて)同じ - # よくソースを見たら学習時はclip_sampleは関係ないや(;'∀')  noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False) if accelerator.is_main_process: accelerator.init_trackers("dreambooth") - # 以下 train_dreambooth.py からほぼコピペ for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") # 指定したステップ数までText Encoderを学習する:epoch最初の状態 - train_text_encoder = args.stop_text_encoder_training is None or global_step < args.stop_text_encoder_training unet.train() - if train_text_encoder: + # train==True is required to enable gradient_checkpointing + if args.gradient_checkpointing or global_step < args.stop_text_encoder_training: text_encoder.train() loss_total = 0 for step, batch in enumerate(train_dataloader): # 指定したステップ数でText Encoderの学習を止める - stop_text_encoder_training = args.stop_text_encoder_training is not None and global_step == args.stop_text_encoder_training - if stop_text_encoder_training: + if global_step == args.stop_text_encoder_training: print(f"stop text encoder training at step {global_step}") - text_encoder.train(False) + if not args.gradient_checkpointing: + text_encoder.train(False) text_encoder.requires_grad_(False) with accelerator.accumulate(unet): @@ -742,6 +232,11 @@ def train(args): noise = torch.randn_like(latents, device=latents.device) b_size = latents.shape[0] + # Get the text embedding for conditioning + with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder) + # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) timesteps = timesteps.long() @@ -750,20 +245,11 @@ def train(args): # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - # Get the text embedding for conditioning - if args.clip_skip is None: - encoder_hidden_states = text_encoder(batch["input_ids"])[0] - else: - enc_out = text_encoder(batch["input_ids"], output_hidden_states=True, return_dict=True) - encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) - # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if args.v_parameterization: # v-parameterization training - # Diffusers 0.10.0からv_parameterizationの学習に対応したのでそちらを使う target = noise_scheduler.get_velocity(latents, noise, timesteps) else: target = noise @@ -778,7 +264,10 @@ def train(args): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters())) + if train_text_encoder: + params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters())) + else: + params_to_clip = unet.parameters() accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm) optimizer.step() @@ -810,35 +299,9 @@ def train(args): accelerator.wait_for_everyone() if args.save_every_n_epochs is not None: - if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: - print("saving checkpoint.") - if save_stable_diffusion_format: - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1)) - model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), - src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae) - if args.save_last_n_epochs is not None: - old_ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1 - args.save_every_n_epochs * args.save_last_n_epochs)) - if os.path.exists(old_ckpt_file): - os.remove(old_ckpt_file) - else: - out_dir = os.path.join(args.output_dir, train_util.EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) - os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), - unwrap_model(unet), src_diffusers_model_path, - use_safetensors=use_safetensors) - if args.save_last_n_epochs is not None: - out_dir_old = os.path.join(args.output_dir, train_util.EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1 - args.save_every_n_epochs * args.save_last_n_epochs)) - if os.path.exists(out_dir_old): - shutil.rmtree(out_dir_old) - - if args.save_state: - print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1))) - if args.save_last_n_epochs is not None: - state_dir_old = os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1 - args.save_every_n_epochs * args.save_last_n_epochs)) - if os.path.exists(state_dir_old): - shutil.rmtree(state_dir_old) + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors, + save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae) is_main_process = accelerator.is_main_process if is_main_process: @@ -854,107 +317,24 @@ def train(args): del accelerator # この後メモリを使うのでこれは消す if is_main_process: - os.makedirs(args.output_dir, exist_ok=True) - if save_stable_diffusion_format: - ckpt_file = os.path.join(args.output_dir, model_util.get_last_ckpt_name(use_safetensors)) - print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") - model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, - src_stable_diffusion_ckpt, epoch, global_step, save_dtype, vae) - else: - print(f"save trained model as Diffusers to {args.output_dir}") - out_dir = os.path.join(args.output_dir, train_util.LAST_DIFFUSERS_DIR_NAME) - os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, src_diffusers_model_path, - use_safetensors=use_safetensors) + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors, + save_dtype, epoch, global_step, text_encoder, unet, vae) print("model saved.") if __name__ == '__main__': - # torch.cuda.set_per_process_memory_fraction(0.48) parser = argparse.ArgumentParser() - parser.add_argument("--v2", action='store_true', - help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') - parser.add_argument("--v_parameterization", action='store_true', - help='enable v-parameterization training / v-parameterization学習を有効にする') - parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, - help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル") - # parser.add_argument("--replace_clip_l14_336", action='store_true', - # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") - parser.add_argument("--fine_tuning", action="store_true", - help="fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする") - parser.add_argument("--shuffle_caption", action="store_true", - help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする") - parser.add_argument("--caption_extention", type=str, default=None, - help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)") - parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子") - parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") - parser.add_argument("--dataset_repeats", type=int, default=None, - help="repeat dataset in fine tuning / fine tuning時にデータセットを繰り返す回数") - parser.add_argument("--output_dir", type=str, default=None, - help="directory to output trained model / 学習後のモデル出力先ディレクトリ") - parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)") - parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"], - help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)") - parser.add_argument("--use_safetensors", action='store_true', - help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)") - parser.add_argument("--save_every_n_epochs", type=int, default=None, - help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") - parser.add_argument("--save_last_n_epochs", type=int, default=None, - help="save last N checkpoints / 最大Nエポック保存する") - parser.add_argument("--save_state", action="store_true", - help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") - parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") - parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み") + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, False) + train_util.add_training_arguments(parser, True) + train_util.add_sd_saving_arguments(parser) + parser.add_argument("--no_token_padding", action="store_true", help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)") parser.add_argument("--stop_text_encoder_training", type=int, default=None, - help="steps to stop text encoder training / Text Encoderの学習を止めるステップ数") - parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") - parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") - parser.add_argument("--face_crop_aug_range", type=str, default=None, - help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)") - parser.add_argument("--random_crop", action="store_true", - help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)") - parser.add_argument("--debug_dataset", action="store_true", - help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)") - parser.add_argument("--resolution", type=str, default=None, - help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)") - parser.add_argument("--train_batch_size", type=int, default=1, - help="batch size for training (1 means one train or reg data, not train/reg pair) / 学習時のバッチサイズ(1でtrain/regをそれぞれ1件ずつ学習)") - parser.add_argument("--use_8bit_adam", action="store_true", - help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") - parser.add_argument("--mem_eff_attn", action="store_true", - help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") - parser.add_argument("--xformers", action="store_true", - help="use xformers for CrossAttention / CrossAttentionにxformersを使う") - parser.add_argument("--vae", type=str, default=None, - help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ") - parser.add_argument("--cache_latents", action="store_true", - help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)") - parser.add_argument("--enable_bucket", action="store_true", - help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする") - parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") - parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度") - parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") - parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") - parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") - parser.add_argument("--gradient_checkpointing", action="store_true", - help="enable gradient checkpointing / grandient checkpointingを有効にする") - parser.add_argument("--mixed_precision", type=str, default="no", - choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") - parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") - parser.add_argument("--clip_skip", type=int, default=None, - help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") - parser.add_argument("--logging_dir", type=str, default=None, - help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") - parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") - parser.add_argument("--lr_scheduler", type=str, default="constant", - help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") - parser.add_argument("--lr_warmup_steps", type=int, default=0, - help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") + help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない") args = parser.parse_args() train(args) - diff --git a/train_network.py b/train_network.py index 35f50567..bfb2d860 100644 --- a/train_network.py +++ b/train_network.py @@ -1,6 +1,7 @@ import gc import importlib import json +import shutil import time import argparse import math @@ -143,8 +144,6 @@ def train(args): if args.full_fp16: assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" print("enable full fp16 training.") - # unet.to(weight_dtype) - # text_encoder.to(weight_dtype) network.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい @@ -163,10 +162,14 @@ def train(args): unet.requires_grad_(False) unet.to(accelerator.device, dtype=weight_dtype) - unet.eval() text_encoder.requires_grad_(False) text_encoder.to(accelerator.device, dtype=weight_dtype) - text_encoder.eval() + if args.gradient_checkpointing: # according to TI example in Diffusers, train is required + unet.train() + text_encoder.train() + else: + unet.eval() + text_encoder.eval() network.prepare_grad_etc(text_encoder, unet) @@ -294,9 +297,29 @@ def train(args): accelerator.wait_for_everyone() if args.save_every_n_epochs is not None: - def save_func(file): - unwrap_model(network).save_weights(file, save_dtype) - train_util.save_on_epoch_end(args, accelerator, epoch, num_train_epochs, save_func) + model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + + def save_func(): + ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + unwrap_model(network).save_weights(ckpt_file, save_dtype) + + def remove_old_func(old_epoch_no): + old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + os.remove(old_ckpt_file) + + saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) + if saving and args.save_state: + print("saving state.") + accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(model_name, epoch + 1))) + if remove_epoch_no is not None: + state_dir_old = os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) + if os.path.exists(state_dir_old): + shutil.rmtree(state_dir_old) + + # end of epoch is_main_process = accelerator.is_main_process if is_main_process: @@ -305,14 +328,20 @@ def train(args): accelerator.end_training() if args.save_state: - train_util.save_last_state(args, accelerator) + train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す if is_main_process: - def last_save_func(file): - network.save_weights(file, save_dtype) - train_util.save_last_model(args, last_save_func) + os.makedirs(args.output_dir, exist_ok=True) + + model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + ckpt_name = model_name + '.' + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"save trained model to {ckpt_file}") + network.save_weights(ckpt_file, save_dtype) + print("model saved.") if __name__ == '__main__': @@ -322,6 +351,9 @@ if __name__ == '__main__': train_util.add_dataset_arguments(parser, True, True) train_util.add_training_arguments(parser, True) + parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)") + parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")