diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index bf9b2b39..a89aed0f 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -1,7 +1,7 @@ # txt2img with Diffusers: supports SD checkpoints, EulerScheduler, clip-skip, 225 tokens, Hypernetwork etc... # v2: CLIP guided Stable Diffusion, Image guided Stable Diffusion, highres. fix -# v3: Add dpmsolver/dpmsolver++, add VAE loading, add upscale, add 'bf16', fix the issue hypernetwork_mul is not working +# v3: Add dpmsolver/dpmsolver++, add VAE loading, add upscale, add 'bf16', fix the issue network_mul is not working # v4: SD2.0 support (new U-Net/text encoder/tokenizer), simplify by DiffUsers 0.9.0, no_preview in interactive mode # v5: fix clip_sample=True for scheduler, add VGG guidance # v6: refactor to use model util, load VAE without vae folder, support safe tensors @@ -333,7 +333,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio def replace_unet_cross_attn_to_memory_efficient(): - print("Replace CrossAttention.forward to use Hypernetwork and FlashAttention") + print("Replace CrossAttention.forward to use NAI style Hypernetwork and FlashAttention") flash_func = FlashAttentionFunction def forward_flash_attn(self, x, context=None, mask=None): @@ -373,7 +373,7 @@ def replace_unet_cross_attn_to_memory_efficient(): def replace_unet_cross_attn_to_xformers(): - print("Replace CrossAttention.forward to use Hypernetwork and xformers") + print("Replace CrossAttention.forward to use NAI style Hypernetwork and xformers") try: import xformers.ops except ImportError: @@ -1867,25 +1867,6 @@ def main(args): if not args.diffusers_xformers: replace_unet_modules(unet, not args.xformers, args.xformers) - # hypernetworkを組み込む - if args.hypernetwork_module is not None: - assert not args.diffusers_xformers, "cannot use hypernetwork with diffusers_xformers / diffusers_xformers指定時はHypernetworkは利用できません" - - print("import hypernetwork module:", args.hypernetwork_module) - hyp_module = importlib.import_module(args.hypernetwork_module) - - hypernetwork = hyp_module.Hypernetwork(args.hypernetwork_mul) - - print("load hypernetwork weights from:", args.hypernetwork_weights) - hyp_sd = torch.load(args.hypernetwork_weights, map_location='cpu') - success = hypernetwork.load_from_state_dict(hyp_sd) - assert success, "hypernetwork weights loading failed." - - if args.opt_channels_last: - hypernetwork.to(memory_format=torch.channels_last) - else: - hypernetwork = None - # tokenizerを読み込む print("loading tokenizer") if use_stable_diffusion_format: @@ -2000,10 +1981,27 @@ def main(args): if vgg16_model is not None: vgg16_model.to(dtype).to(device) - if hypernetwork is not None: - hypernetwork.to(dtype).to(device) - print("apply hypernetwork") - hypernetwork.apply_to_diffusers(vae, text_encoder, unet) + # networkを組み込む + if args.network_module is not None: + # assert not args.diffusers_xformers, "cannot use network with diffusers_xformers / diffusers_xformers指定時はnetworkは利用できません" + + print("import network module:", args.network_module) + network_module = importlib.import_module(args.network_module) + + network = network_module.create_network(args.network_mul, args.network_dim, vae,text_encoder, unet) # , **net_kwargs) + if network is None: + return + + print("load network weights from:", args.network_weights) + network.load_weights(args.network_weights) + + network.apply_to(text_encoder, unet) + + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + else: + network = None if args.opt_channels_last: print(f"set optimizing: channels last") @@ -2012,8 +2010,8 @@ def main(args): unet.to(memory_format=torch.channels_last) if clip_model is not None: clip_model.to(memory_format=torch.channels_last) - if hypernetwork is not None: - hypernetwork.to(memory_format=torch.channels_last) + if network is not None: + network.to(memory_format=torch.channels_last) if vgg16_model is not None: vgg16_model.to(memory_format=torch.channels_last) @@ -2491,9 +2489,11 @@ if __name__ == '__main__': help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)') parser.add_argument("--opt_channels_last", action='store_true', help='set channels last option to model / モデルにchannles lastを指定し最適化する') - parser.add_argument("--hypernetwork_module", type=str, default=None, help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名') - parser.add_argument("--hypernetwork_weights", type=str, default=None, help='Hypernetwork weights to load / Hypernetworkの重み') - parser.add_argument("--hypernetwork_mul", type=float, default=1.0, help='Hypernetwork multiplier / Hypernetworkの効果の倍率') + parser.add_argument("--network_module", type=str, default=None, help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名') + parser.add_argument("--network_weights", type=str, default=None, help='Hypernetwork weights to load / Hypernetworkの重み') + parser.add_argument("--network_mul", type=float, default=1.0, help='Hypernetwork multiplier / Hypernetworkの効果の倍率') + parser.add_argument("--network_dim", type=int, default=None, + help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う') parser.add_argument("--max_embeddings_multiples", type=int, default=None, help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる') diff --git a/networks/lora.py b/networks/lora.py new file mode 100644 index 00000000..730a6376 --- /dev/null +++ b/networks/lora.py @@ -0,0 +1,190 @@ +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +import torch + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4): + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == 'Conv2d': + in_dim = org_module.in_channels + out_dim = org_module.out_channels + self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False) + self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False) + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False) + self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False) + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier + + +def create_network(multiplier, network_dim, vae, text_encoder, unet, **kwargs): + if network_dim is None: + network_dim = 4 # default + network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim) + return network + + +class LoRANetwork(torch.nn.Module): + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_UNET = 'lora_unet' + LORA_PREFIX_TEXT_ENCODER = 'lora_te' + + def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None: + super().__init__() + self.multiplier = multiplier + self.lora_dim = lora_dim + + # create module instances + def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]: + loras = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): + lora_name = prefix + '.' + name + '.' + child_name + lora_name = lora_name.replace('.', '_') + lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim) + loras.append(lora) + return loras + + self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER, + text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE) + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + self.weights_sd = None + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def load_weights(self, file): + if os.path.splitext(file)[1] == '.safetensors': + from safetensors.torch import load_file + self.weights_sd = load_file(file) + else: + self.weights_sd = torch.load(file, map_location='cpu') + + def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None): + if self.weights_sd: + weights_has_text_encoder = weights_has_unet = False + for key in self.weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): + weights_has_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): + weights_has_unet = True + + if apply_text_encoder is None: + apply_text_encoder = weights_has_text_encoder + else: + assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています" + + if apply_unet is None: + apply_unet = weights_has_unet + else: + assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています" + else: + assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set" + + if apply_text_encoder: + print("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + print("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + if self.weights_sd: + # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros) + info = self.load_state_dict(self.weights_sd, False) + print(f"weights are loaded: {info}") + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_optimizer_params(self, text_encoder_lr, unet_lr): + def enumerate_params(loras): + params = [] + for lora in loras: + params.extend(lora.parameters()) + return params + + self.requires_grad_(True) + params = [] + + if self.text_encoder_loras: + param_data = {'params': enumerate_params(self.text_encoder_loras)} + if text_encoder_lr is not None: + param_data['lr'] = text_encoder_lr + params.append(param_data) + + if self.unet_loras: + param_data = {'params': enumerate_params(self.unet_loras)} + if unet_lr is not None: + param_data['lr'] = unet_lr + params.append(param_data) + + return params + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype): + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == '.safetensors': + from safetensors.torch import save_file + save_file(state_dict, file) + else: + torch.save(state_dict, file) diff --git a/networks/merge_lora.py b/networks/merge_lora.py new file mode 100644 index 00000000..d873a8ef --- /dev/null +++ b/networks/merge_lora.py @@ -0,0 +1,159 @@ + + +import argparse +import os +import torch +from safetensors.torch import load_file, save_file +import library.model_util as model_util +import lora + + +def load_state_dict(file_name, dtype): + if os.path.splitext(file_name)[1] == '.safetensors': + sd = load_file(file_name) + else: + sd = torch.load(file_name, map_location='cpu') + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + return sd + + +def save_to_file(file_name, model, state_dict, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if os.path.splitext(file_name)[1] == '.safetensors': + save_file(model, file_name) + else: + torch.save(model, file_name) + + +def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): + text_encoder.to(merge_dtype) + unet.to(merge_dtype) + + # create module map + name_to_module = {} + for i, root_module in enumerate([text_encoder, unet]): + if i == 0: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER + target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + else: + prefix = lora.LoRANetwork.LORA_PREFIX_UNET + target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): + lora_name = prefix + '.' + name + '.' + child_name + lora_name = lora_name.replace('.', '_') + name_to_module[lora_name] = child_module + + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd = load_state_dict(model, merge_dtype) + + print(f"merging...") + for key in lora_sd.keys(): + if "lora_down" in key: + up_key = key.replace("lora_down", "lora_up") + + # find original module for this lora + module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" + if module_name not in name_to_module: + print(f"no module found for LoRA weight: {key}") + continue + module = name_to_module[module_name] + # print(f"apply {key} to {module}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + # W <- W + U * D + weight = module.weight + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) + else: + # conv2d + weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + + module.weight = torch.nn.Parameter(weight) + + +def merge_lora_models(models, ratios, merge_dtype): + merged_sd = {} + + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd = load_state_dict(model, merge_dtype) + + print(f"merging...") + for key in lora_sd.keys(): + if key in merged_sd: + assert merged_sd[key].size() == lora_sd[key].size( + ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" + merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio + else: + merged_sd[key] = lora_sd[key] * ratio + + return merged_sd + + +def merge(args): + assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + + def str_to_dtype(p): + if p == 'float': + return torch.float + if p == 'fp16': + return torch.float16 + if p == 'bf16': + return torch.bfloat16 + return None + + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + if args.sd_model is not None: + print(f"loading SD model: {args.sd_model}") + + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) + + merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) + + print(f"saving SD model to: {args.save_to}") + model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, + args.sd_model, 0, 0, save_dtype, vae) + else: + state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) + + print(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, state_dict, save_dtype) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") + parser.add_argument("--precision", type=str, default="float", + choices=["float", "fp16", "bf16"], help="precision in merging / マージの計算時の精度") + parser.add_argument("--sd_model", type=str, default=None, + help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする") + parser.add_argument("--save_to", type=str, default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") + parser.add_argument("--models", type=str, nargs='*', + help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors") + parser.add_argument("--ratios", type=float, nargs='*', + help="ratios for each model / それぞれのLoRAモデルの比率") + + args = parser.parse_args() + merge(args) diff --git a/train_network.py b/train_network.py new file mode 100644 index 00000000..9a5942d1 --- /dev/null +++ b/train_network.py @@ -0,0 +1,1452 @@ +import gc +import importlib +import json +import time +from typing import NamedTuple +from torch.autograd.function import Function +import argparse +import glob +import math +import os +import random + +from tqdm import tqdm +import torch +from torchvision import transforms +from accelerate import Accelerator +from accelerate.utils import set_seed +from transformers import CLIPTokenizer +import diffusers +from diffusers import DDPMScheduler, StableDiffusionPipeline +import albumentations as albu +import numpy as np +from PIL import Image +import cv2 +from einops import rearrange +from torch import einsum + +import library.model_util as model_util + +# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う +TOKENIZER_PATH = "openai/clip-vit-large-patch14" +V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ + +# checkpointファイル名 +EPOCH_STATE_NAME = "epoch-{:06d}-state" +LAST_STATE_NAME = "last-state" + +EPOCH_FILE_NAME = "epoch-{:06d}" +LAST_FILE_NAME = "last" + + +# region dataset + +class ImageInfo(): + def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: + self.image_key: str = image_key + self.num_repeats: int = num_repeats + self.caption: str = caption + self.is_reg: bool = is_reg + self.absolute_path: str = absolute_path + self.image_size: tuple[int, int] = None + self.bucket_reso: tuple[int, int] = None + self.latents: torch.Tensor = None + self.latents_flipped: torch.Tensor = None + self.latents_npz: str = None + self.latents_npz_flipped: str = None + + +class BucketBatchIndex(NamedTuple): + bucket_index: int + batch_index: int + + +class BaseDataset(torch.utils.data.Dataset): + def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, debug_dataset: bool) -> None: + super().__init__() + self.tokenizer: CLIPTokenizer = tokenizer + self.max_token_length = max_token_length + self.shuffle_caption = shuffle_caption + self.shuffle_keep_tokens = shuffle_keep_tokens + self.width, self.height = resolution + self.face_crop_aug_range = face_crop_aug_range + self.flip_aug = flip_aug + self.color_aug = color_aug + self.debug_dataset = debug_dataset + + self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 + + # augmentation + flip_p = 0.5 if flip_aug else 0.0 + if color_aug: + # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る + self.aug = albu.Compose([ + albu.OneOf([ + albu.HueSaturationValue(8, 0, 0, p=.5), + albu.RandomGamma((95, 105), p=.5), + ], p=.33), + albu.HorizontalFlip(p=flip_p) + ], p=1.) + elif flip_aug: + self.aug = albu.Compose([ + albu.HorizontalFlip(p=flip_p) + ], p=1.) + else: + self.aug = None + + self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ]) + + self.image_data: dict[str, ImageInfo] = {} + + def process_caption(self, caption): + if self.shuffle_caption: + tokens = caption.strip().split(",") + if self.shuffle_keep_tokens is None: + random.shuffle(tokens) + else: + if len(tokens) > self.shuffle_keep_tokens: + keep_tokens = tokens[:self.shuffle_keep_tokens] + tokens = tokens[self.shuffle_keep_tokens:] + random.shuffle(tokens) + tokens = keep_tokens + tokens + caption = ",".join(tokens).strip() + return caption + + def get_input_ids(self, caption): + input_ids = self.tokenizer(caption, padding="max_length", truncation=True, + max_length=self.tokenizer_max_length, return_tensors="pt").input_ids + + if self.tokenizer_max_length > self.tokenizer.model_max_length: + input_ids = input_ids.squeeze(0) + iids_list = [] + if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + # v1 + # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する + # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に + for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75) + ids_chunk = (input_ids[0].unsqueeze(0), + input_ids[i:i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0)) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 + # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する + for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): + ids_chunk = (input_ids[0].unsqueeze(0), # BOS + input_ids[i:i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0)) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) + + # 末尾が または の場合は、何もしなくてよい + # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) + if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: + ids_chunk[-1] = self.tokenizer.eos_token_id + # 先頭が ... の場合は ... に変える + if ids_chunk[1] == self.tokenizer.pad_token_id: + ids_chunk[1] = self.tokenizer.eos_token_id + + iids_list.append(ids_chunk) + + input_ids = torch.stack(iids_list) # 3,77 + return input_ids + + def register_image(self, info: ImageInfo): + self.image_data[info.image_key] = info + + def make_buckets(self, enable_bucket, min_size, max_size): + ''' + bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) + min_size and max_size are ignored when enable_bucket is False + ''' + + self.enable_bucket = enable_bucket + + print("loading image sizes.") + for info in tqdm(self.image_data.values()): + if info.image_size is None: + info.image_size = self.get_image_size(info.absolute_path) + + if enable_bucket: + print("make buckets") + else: + print("prepare dataset") + + # bucketingを用意する + if enable_bucket: + bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions((self.width, self.height), min_size, max_size) + else: + # bucketはひとつだけ、すべての画像は同じ解像度 + bucket_resos = [(self.width, self.height)] + bucket_aspect_ratios = [self.width / self.height] + bucket_aspect_ratios = np.array(bucket_aspect_ratios) + + # bucketを作成する + if enable_bucket: + img_ar_errors = [] + for image_info in self.image_data.values(): + # bucketを決める + image_width, image_height = image_info.image_size + aspect_ratio = image_width / image_height + ar_errors = bucket_aspect_ratios - aspect_ratio + + bucket_id = np.abs(ar_errors).argmin() + image_info.bucket_reso = bucket_resos[bucket_id] + + ar_error = ar_errors[bucket_id] + img_ar_errors.append(ar_error) + else: + reso = (self.width, self.height) + for image_info in self.image_data.values(): + image_info.bucket_reso = reso + + # 画像をbucketに分割する + self.buckets: list[str] = [[] for _ in range(len(bucket_resos))] + reso_to_index = {} + for i, reso in enumerate(bucket_resos): + reso_to_index[reso] = i + + for image_info in self.image_data.values(): + bucket_index = reso_to_index[image_info.bucket_reso] + for _ in range(image_info.num_repeats): + self.buckets[bucket_index].append(image_info.image_key) + + if enable_bucket: + print("number of images (including repeats for DreamBooth) / 各bucketの画像枚数(DreamBoothの場合は繰り返し回数を含む)") + for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)): + print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}") + img_ar_errors = np.array(img_ar_errors) + print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}") + + # 参照用indexを作る + self.buckets_indices: list(BucketBatchIndex) = [] + for bucket_index, bucket in enumerate(self.buckets): + batch_count = int(math.ceil(len(bucket) / self.batch_size)) + for batch_index in range(batch_count): + self.buckets_indices.append(BucketBatchIndex(bucket_index, batch_index)) + + self.shuffle_buckets() + self._length = len(self.buckets_indices) + + def shuffle_buckets(self): + random.shuffle(self.buckets_indices) + for bucket in self.buckets: + random.shuffle(bucket) + + def load_image(self, image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) + return img + + def resize_and_trim(self, image, reso): + image_height, image_width = image.shape[0:2] + ar_img = image_width / image_height + ar_reso = reso[0] / reso[1] + if ar_img > ar_reso: # 横が長い→縦を合わせる + scale = reso[1] / image_height + else: + scale = reso[0] / image_width + resized_size = (int(image_width * scale + .5), int(image_height * scale + .5)) + + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + if resized_size[0] > reso[0]: + trim_size = resized_size[0] - reso[0] + image = image[:, trim_size//2:trim_size//2 + reso[0]] + elif resized_size[1] > reso[1]: + trim_size = resized_size[1] - reso[1] + image = image[trim_size//2:trim_size//2 + reso[1]] + assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \ + f"internal error, illegal trimmed size: {image.shape}, {reso}" + return image + + def cache_latents(self, vae): + print("caching latents.") + for info in tqdm(self.image_data.values()): + if info.latents_npz is not None: + info.latents = self.load_latents_from_npz(info, False) + info.latents = torch.FloatTensor(info.latents) + info.latents_flipped = self.load_latents_from_npz(info, True) + info.latents_flipped = torch.FloatTensor(info.latents_flipped) + continue + + image = self.load_image(info.absolute_path) + image = self.resize_and_trim(image, info.bucket_reso) + + img_tensor = self.image_transforms(image) + img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) + info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") + + if self.flip_aug: + image = image[:, ::-1].copy() # cannot convert to Tensor without copy + img_tensor = self.image_transforms(image) + img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) + info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") + + def get_image_size(self, image_path): + image = Image.open(image_path) + return image.size + + def load_image_with_face_info(self, image_path: str): + img = self.load_image(image_path) + + face_cx = face_cy = face_w = face_h = 0 + if self.face_crop_aug_range is not None: + tokens = os.path.splitext(os.path.basename(image_path))[0].split('_') + if len(tokens) >= 5: + face_cx = int(tokens[-4]) + face_cy = int(tokens[-3]) + face_w = int(tokens[-2]) + face_h = int(tokens[-1]) + + return img, face_cx, face_cy, face_w, face_h + + # いい感じに切り出す + def crop_target(self, image, face_cx, face_cy, face_w, face_h): + height, width = image.shape[0:2] + if height == self.height and width == self.width: + return image + + # 画像サイズはsizeより大きいのでリサイズする + face_size = max(face_w, face_h) + min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) + min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ + max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ + if min_scale >= max_scale: # range指定がmin==max + scale = min_scale + else: + scale = random.uniform(min_scale, max_scale) + + nh = int(height * scale + .5) + nw = int(width * scale + .5) + assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" + image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) + face_cx = int(face_cx * scale + .5) + face_cy = int(face_cy * scale + .5) + height, width = nh, nw + + # 顔を中心として448*640とかへ切り出す + for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): + p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 + + if self.random_crop: + # 背景も含めるために顔を中心に置く確率を高めつつずらす + range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう + p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数 + else: + # range指定があるときのみ、すこしだけランダムに(わりと適当) + if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]: + if face_size > self.size // 10 and face_size >= 40: + p1 = p1 + random.randint(-face_size // 20, +face_size // 20) + + p1 = max(0, min(p1, length - target_size)) + + if axis == 0: + image = image[p1:p1 + target_size, :] + else: + image = image[:, p1:p1 + target_size] + + return image + + def load_latents_from_npz(self, image_info: ImageInfo, flipped): + npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz + return np.load(npz_file)['arr_0'] + + def __len__(self): + return self._length + + def __getitem__(self, index): + if index == 0: + self.shuffle_buckets() + + bucket = self.buckets[self.buckets_indices[index].bucket_index] + image_index = self.buckets_indices[index].batch_index * self.batch_size + + loss_weights = [] + captions = [] + input_ids_list = [] + latents_list = [] + images = [] + + for image_key in bucket[image_index:image_index + self.batch_size]: + image_info = self.image_data[image_key] + loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + + # image/latentsを処理する + if image_info.latents is not None: + latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped + image = None + elif image_info.latents_npz is not None: + latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5) + latents = torch.FloatTensor(latents) + image = None + else: + # 画像を読み込み、必要ならcropする + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path) + im_h, im_w = img.shape[0:2] + + if self.enable_bucket: + img = self.resize_and_trim(img, image_info.bucket_reso) + else: + if face_cx > 0: # 顔位置情報あり + img = self.crop_target(img, face_cx, face_cy, face_w, face_h) + elif im_h > self.height or im_w > self.width: + assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" + if im_h > self.height: + p = random.randint(0, im_h - self.height) + img = img[p:p + self.height] + if im_w > self.width: + p = random.randint(0, im_w - self.width) + img = img[:, p:p + self.width] + + im_h, im_w = img.shape[0:2] + assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + + # augmentation + if self.aug is not None: + img = self.aug(image=img)['image'] + + latents = None + image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + + images.append(image) + latents_list.append(latents) + + caption = self.process_caption(image_info.caption) + captions.append(caption) + input_ids_list.append(self.get_input_ids(caption)) + + example = {} + example['loss_weights'] = torch.FloatTensor(loss_weights) + example['input_ids'] = torch.stack(input_ids_list) + + if images[0] is not None: + images = torch.stack(images) + images = images.to(memory_format=torch.contiguous_format).float() + else: + images = None + example['images'] = images + + example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None + + if self.debug_dataset: + example['image_keys'] = bucket[image_index:image_index + self.batch_size] + example['captions'] = captions + return example + + +class DreamBoothDataset(BaseDataset): + def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None: + super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, + resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset) + + self.batch_size = batch_size + self.size = min(self.width, self.height) # 短いほう + self.prior_loss_weight = prior_loss_weight + self.random_crop = random_crop + self.latents_cache = None + self.enable_bucket = False + + def read_caption(img_path): + # captionの候補ファイル名を作る + base_name = os.path.splitext(img_path)[0] + base_name_face_det = base_name + tokens = base_name.split("_") + if len(tokens) >= 5: + base_name_face_det = "_".join(tokens[:-4]) + cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] + + caption = None + for cap_path in cap_paths: + if os.path.isfile(cap_path): + with open(cap_path, "rt", encoding='utf-8') as f: + lines = f.readlines() + assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" + caption = lines[0].strip() + break + return caption + + def load_dreambooth_dir(dir): + if not os.path.isdir(dir): + # print(f"ignore file: {dir}") + return 0, [], [] + + tokens = os.path.basename(dir).split('_') + try: + n_repeats = int(tokens[0]) + except ValueError as e: + print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}") + return 0, [], [] + + caption_by_folder = '_'.join(tokens[1:]) + img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \ + glob.glob(os.path.join(dir, "*.webp")) + print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files") + + # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う + captions = [] + for img_path in img_paths: + cap_for_img = read_caption(img_path) + captions.append(caption_by_folder if cap_for_img is None else cap_for_img) + + return n_repeats, img_paths, captions + + print("prepare train images.") + train_dirs = os.listdir(train_data_dir) + num_train_images = 0 + for dir in train_dirs: + n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir)) + num_train_images += n_repeats * len(img_paths) + for img_path, caption in zip(img_paths, captions): + info = ImageInfo(img_path, n_repeats, caption, False, img_path) + self.register_image(info) + print(f"{num_train_images} train images with repeating.") + self.num_train_images = num_train_images + + # reg imageは数を数えて学習画像と同じ枚数にする + num_reg_images = 0 + if reg_data_dir: + print("prepare reg images.") + reg_infos: list[ImageInfo] = [] + + reg_dirs = os.listdir(reg_data_dir) + for dir in reg_dirs: + n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir)) + num_reg_images += n_repeats * len(img_paths) + for img_path, caption in zip(img_paths, captions): + info = ImageInfo(img_path, n_repeats, caption, True, img_path) + reg_infos.append(info) + + print(f"{num_reg_images} reg images.") + if num_train_images < num_reg_images: + print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") + + if num_reg_images == 0: + print("no regularization images / 正則化画像が見つかりませんでした") + else: + n = 0 + while n < num_train_images: + for info in reg_infos: + self.register_image(info) + n += info.num_repeats + if n >= num_train_images: # reg画像にnum_repeats>1のときはまずありえないので考慮しない + break + + self.num_reg_images = num_reg_images + + +class FineTuningDataset(BaseDataset): + def __init__(self, metadata, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug, color_aug, face_crop_aug_range, dataset_repeats, debug_dataset) -> None: + super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, + resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset) + + self.metadata = metadata + self.train_data_dir = train_data_dir + self.batch_size = batch_size + + for image_key, img_md in metadata.items(): + # path情報を作る + if os.path.exists(image_key): + abs_path = image_key + else: + # わりといい加減だがいい方法が思いつかん + abs_path = (glob.glob(os.path.join(train_data_dir, f"{image_key}.png")) + glob.glob(os.path.join(train_data_dir, f"{image_key}.jpg")) + + glob.glob(os.path.join(train_data_dir, f"{image_key}.webp"))) + assert len(abs_path) >= 1, f"no image / 画像がありません: {abs_path}" + abs_path = abs_path[0] + + caption = img_md.get('caption') + tags = img_md.get('tags') + if caption is None: + caption = tags + elif tags is not None and len(tags) > 0: + caption = caption + ', ' + tags + assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}" + + image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path) + image_info.image_size = img_md.get('train_resolution') + + if not self.color_aug: + # if npz exists, use them + image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key) + + self.register_image(image_info) + self.num_train_images = len(metadata) * dataset_repeats + self.num_reg_images = 0 + + # check existence of all npz files + if not self.color_aug: + npz_any = False + npz_all = True + for image_info in self.image_data.values(): + has_npz = image_info.latents_npz is not None + npz_any = npz_any or has_npz + + if self.flip_aug: + has_npz = has_npz and image_info.latents_npz_flipped is not None + npz_all = npz_all and has_npz + + if npz_any and not npz_all: + break + + if not npz_any: + print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します") + elif not npz_all: + print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") + for image_info in self.image_data.values(): + image_info.latents_npz = image_info.latents_npz_flipped = None + + # check min/max bucket size + sizes = set() + for image_info in self.image_data.values(): + if image_info.image_size is None: + sizes = None # not calculated + break + sizes.add(image_info.image_size[0]) + sizes.add(image_info.image_size[1]) + + if sizes is None: + self.min_bucket_reso = self.max_bucket_reso = None # set as not calculated + else: + self.min_bucket_reso = min(sizes) + self.max_bucket_reso = max(sizes) + + def image_key_to_npz_file(self, image_key): + base_name = os.path.splitext(image_key)[0] + npz_file_norm = base_name + '.npz' + + if os.path.exists(npz_file_norm): + # image_key is full path + npz_file_flip = base_name + '_flip.npz' + if not os.path.exists(npz_file_flip): + npz_file_flip = None + return npz_file_norm, npz_file_flip + + # image_key is relative path + npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz') + npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz') + + if not os.path.exists(npz_file_norm): + npz_file_norm = None + npz_file_flip = None + elif not os.path.exists(npz_file_flip): + npz_file_flip = None + + return npz_file_norm, npz_file_flip + +# endregion + + +# region モジュール入れ替え部 +""" +高速化のためのモジュール入れ替え +""" + +# FlashAttentionを使うCrossAttention +# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py +# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE + +# constants + +EPSILON = 1e-6 + +# helper functions + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + +# flash attention forwards and backwards + +# https://arxiv.org/abs/2205.14135 + + +class FlashAttentionFunction(Function): + @ staticmethod + @ torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """ Algorithm 2 in the paper """ + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) + + scale = (q.shape[-1] ** -0.5) + + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, 'b n -> b 1 1 n') + mask = mask.split(q_bucket_size, dim=-1) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale + + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, + device=device).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @ staticmethod + @ torch.no_grad() + def backward(ctx, do): + """ Algorithm 4 in the paper """ + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2) + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, + device=device).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.) + + p = exp_attn_weights / lc + + dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) + dp = einsum('... i d, ... j d -> ... i j', doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) + dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): + if mem_eff_attn: + replace_unet_cross_attn_to_memory_efficient() + elif xformers: + replace_unet_cross_attn_to_xformers() + + +def replace_unet_cross_attn_to_memory_efficient(): + print("Replace CrossAttention.forward to use FlashAttention") + flash_func = FlashAttentionFunction + + def forward_flash_attn(self, x, context=None, mask=None): + q_bucket_size = 512 + k_bucket_size = 1024 + + h = self.heads + q = self.to_q(x) + + context = context if context is not None else x + context = context.to(x.dtype) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, 'b h n d -> b n (h d)') + + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_flash_attn + + +def replace_unet_cross_attn_to_xformers(): + print("Replace CrossAttention.forward to use xformers") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") + + def forward_xformers(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) + + context = default(context, x) + context = context.to(x.dtype) + + k_in = self.to_k(context) + v_in = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) # new format + # q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) # legacy format + del q_in, k_in, v_in + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + + out = rearrange(out, 'b n h d -> b n (h d)', h=h) + # out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_xformers +# endregion + + +def collate_fn(examples): + return examples[0] + + +def train(args): + cache_latents = args.cache_latents + + # latentsをキャッシュする場合のオプション設定を確認する + if cache_latents: + assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません" + + # その他のオプション設定を確認する + if args.v_parameterization and not args.v2: + print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + if args.v2 and args.clip_skip is not None: + print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + + use_dreambooth_method = args.in_json is None + + # モデル形式のオプション設定を確認する: + load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) + + # 乱数系列を初期化する + if args.seed is not None: + set_seed(args.seed) + + # tokenizerを読み込む + print("prepare tokenizer") + if args.v2: + tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") + else: + tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) + + if args.max_token_length is not None: + print(f"update token length: {args.max_token_length}") + + # 学習データを用意する + resolution = tuple([int(r) for r in args.resolution.split(',')]) + if len(resolution) == 1: + resolution = (resolution[0], resolution[0]) + assert len(resolution) == 2, \ + f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'幅','高さ'で指定してください: {args.resolution}" + + if args.face_crop_aug_range is not None: + face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')]) + assert len( + face_crop_aug_range) == 2, f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}" + else: + face_crop_aug_range = None + + # データセットを準備する + if use_dreambooth_method: + print("Use DreamBooth method.") + train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir, + tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens, + resolution, args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop, args.debug_dataset) + else: + print("Train with captions.") + + # メタデータを読み込む + if os.path.exists(args.in_json): + print(f"loading existing metadata: {args.in_json}") + with open(args.in_json, "rt", encoding='utf-8') as f: + metadata = json.load(f) + else: + print(f"no metadata / メタデータファイルがありません: {args.in_json}") + return + + if args.color_aug: + print(f"latents in npz is ignored when color_aug is True / color_augを有効にした場合、npzファイルのlatentsは無視されます") + + train_dataset = FineTuningDataset(metadata, args.train_batch_size, args.train_data_dir, + tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, + resolution, args.flip_aug, args.color_aug, face_crop_aug_range, args.dataset_repeats, args.debug_dataset) + + if train_dataset.min_bucket_reso is not None and (args.enable_bucket or train_dataset.min_bucket_reso != train_dataset.max_bucket_reso): + print(f"using bucket info in metadata / メタデータ内のbucket情報を使います") + args.min_bucket_reso = train_dataset.min_bucket_reso + args.max_bucket_reso = train_dataset.max_bucket_reso + args.enable_bucket = True + print(f"min bucket reso: {args.min_bucket_reso}, max bucket reso: {args.max_bucket_reso}") + + if args.enable_bucket: + assert min(resolution) >= args.min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" + assert max(resolution) <= args.max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + + train_dataset.make_buckets(args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso) + + if args.debug_dataset: + print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") + print("Escape for exit. / Escキーで中断、終了します") + k = 0 + for example in train_dataset: + if example['latents'] is not None: + print("sample has latents from npz file") + for j, (ik, cap, lw) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'])): + print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}') + if example['images'] is not None: + im = example['images'][j] + im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) + im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c + im = im[:, :, ::-1] # RGB -> BGR (OpenCV) + cv2.imshow("img", im) + k = cv2.waitKey() + cv2.destroyAllWindows() + if k == 27: + break + if k == 27 or example['images'] is None: + break + return + + if len(train_dataset) == 0: + print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") + return + + # acceleratorを準備する + print("prepare accelerator") + if args.logging_dir is None: + log_with = None + logging_dir = None + else: + log_with = "tensorboard" + log_prefix = "" if args.log_prefix is None else args.log_prefix + logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime()) + + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, + log_with=log_with, logging_dir=logging_dir) + + # accelerateの互換性問題を解決する + accelerator_0_15 = True + try: + accelerator.unwrap_model("dummy", True) + print("Using accelerator 0.15.0 or above.") + except TypeError: + accelerator_0_15 = False + + def unwrap_model(model): + if accelerator_0_15: + return accelerator.unwrap_model(model, True) + return accelerator.unwrap_model(model) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + save_dtype = None + if args.save_precision == "fp16": + save_dtype = torch.float16 + elif args.save_precision == "bf16": + save_dtype = torch.bfloat16 + elif args.save_precision == "float": + save_dtype = torch.float32 + + # モデルを読み込む + if load_stable_diffusion_format: + print("load StableDiffusion checkpoint") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path) + else: + print("load Diffusers pretrained models") + pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None) + text_encoder = pipe.text_encoder + vae = pipe.vae + unet = pipe.unet + del pipe + + # VAEを読み込む + if args.vae is not None: + vae = model_util.load_vae(args.vae, weight_dtype) + print("additional VAE loaded") + + # モデルに xformers とか memory efficient attention を組み込む + replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset.cache_latents(vae) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + # prepare network + print("import network module:", args.network_module) + network_module = importlib.import_module(args.network_module) + + net_kwargs = {} + if args.network_args is not None: + for net_arg in args.network_args: + key, value = net_arg.split('=') + net_kwargs[key] = value + + network = network_module.create_network(1.0, args.network_dim, vae, text_encoder, unet, **net_kwargs) + if network is None: + return + + if args.network_weights is not None: + print("load network weights from:", args.network_weights) + network.load_weights(args.network_weights) + + train_unet = not args.network_train_text_encoder_only + train_text_encoder = not args.network_train_unet_only + network.apply_to(text_encoder, unet, train_text_encoder, train_unet) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + network.enable_gradient_checkpointing() # may have no effect + + # 学習に必要なクラスを準備する + print("prepare optimizer, data loader etc.") + + # 8-bit Adamを使う + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") + print("use 8-bit Adam optimizer") + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + + # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 + optimizer = optimizer_class(trainable_params, lr=args.learning_rate) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8 + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers) + + # lr schedulerを用意する + lr_scheduler = diffusers.optimization.get_scheduler( + args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps) + + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + print("enable full fp16 training.") + # unet.to(weight_dtype) + # text_encoder.to(weight_dtype) + network.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + if train_unet and train_text_encoder: + unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler) + elif train_unet: + unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, network, optimizer, train_dataloader, lr_scheduler) + elif train_text_encoder: + text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, network, optimizer, train_dataloader, lr_scheduler) + else: + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + network, optimizer, train_dataloader, lr_scheduler) + + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + unet.eval() + text_encoder.requires_grad_(False) + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.eval() + + network.prepare_grad_etc(text_encoder, unet) + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + org_unscale_grads = accelerator.scaler._unscale_grads_ + + def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + return org_unscale_grads(optimizer, inv_scale, found_inf, True) + + accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + print("running training / 学習開始") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", + num_train_timesteps=1000, clip_sample=False) + + if accelerator.is_main_process: + accelerator.init_trackers("network_train") + + for epoch in range(num_train_epochs): + print(f"epoch {epoch+1}/{num_train_epochs}") + + # 指定したステップ数までText Encoderを学習する:epoch最初の状態 + network.on_epoch_start(text_encoder, unet) + + loss_total = 0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(network): + with torch.no_grad(): + # latentに変換 + if batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + with torch.set_grad_enabled(train_text_encoder): + # Get the text embedding for conditioning + input_ids = batch["input_ids"].to(accelerator.device) + input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 + + if args.clip_skip is None: + encoder_hidden_states = text_encoder(input_ids)[0] + else: + enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # なぜかこれが必要 + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + + # bs*3, 77, 768 or 1024 + encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) + + if args.max_token_length is not None: + if args.v2: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: ... の三連を ... へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + # Diffusers 0.10.0からv_parameterizationの学習に対応したのでそちらを使う + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = network.get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]} + accelerator.log(logs, step=global_step) + + loss_total += current_loss + avr_loss = loss_total / (step+1) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"epoch_loss": loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch+1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: + print("saving checkpoint.") + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, EPOCH_FILE_NAME.format(epoch + 1) + '.' + args.save_model_as) + unwrap_model(network).save_weights(ckpt_file, save_dtype) + + if args.save_state: + print("saving state.") + accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) + + is_main_process = accelerator.is_main_process + if is_main_process: + network = unwrap_model(network) + + accelerator.end_training() + + if args.save_state: + print("saving last state.") + os.makedirs(args.output_dir, exist_ok=True) + accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME)) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, LAST_FILE_NAME + '.' + args.save_model_as) + print(f"save trained model to {ckpt_file}") + network.save_weights(ckpt_file, save_dtype) + print("model saved.") + + +if __name__ == '__main__': + # torch.cuda.set_per_process_memory_fraction(0.48) + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') + parser.add_argument("--v_parameterization", action='store_true', + help='enable v-parameterization training / v-parameterization学習を有効にする') + parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, + help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル") + parser.add_argument("--network_weights", type=str, default=None, + help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument("--shuffle_caption", action="store_true", + help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする") + parser.add_argument("--keep_tokens", type=int, default=None, + help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す") + parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") + parser.add_argument("--in_json", type=str, default=None, help="json meatadata for dataset / データセットのmetadataのjsonファイル") + parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子") + parser.add_argument("--dataset_repeats", type=int, default=None, + help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数") + parser.add_argument("--output_dir", type=str, default=None, + help="directory to output trained model / 学習後のモデル出力先ディレクトリ") + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する") + parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)") + parser.add_argument("--save_every_n_epochs", type=int, default=None, + help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") + parser.add_argument("--save_state", action="store_true", + help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") + parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") + parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") + parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") + parser.add_argument("--face_crop_aug_range", type=str, default=None, + help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)") + parser.add_argument("--random_crop", action="store_true", + help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)") + parser.add_argument("--debug_dataset", action="store_true", + help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)") + parser.add_argument("--resolution", type=str, default=None, + help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)") + parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ") + parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225], + help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)") + parser.add_argument("--use_8bit_adam", action="store_true", + help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") + parser.add_argument("--mem_eff_attn", action="store_true", + help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") + parser.add_argument("--xformers", action="store_true", + help="use xformers for CrossAttention / CrossAttentionにxformersを使う") + parser.add_argument("--vae", type=str, default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ") + parser.add_argument("--cache_latents", action="store_true", + help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)") + parser.add_argument("--enable_bucket", action="store_true", + help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする") + parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") + parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") + parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") + parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") + parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") + parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み") + # parser.add_argument("--stop_text_encoder_training", type=int, default=None, + # help="steps to stop text encoder training / Text Encoderの学習を止めるステップ数") + parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") + parser.add_argument("--gradient_checkpointing", action="store_true", + help="enable gradient checkpointing / grandient checkpointingを有効にする") + parser.add_argument("--gradient_accumulation_steps", type=int, default=1, + help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数") + parser.add_argument("--mixed_precision", type=str, default="no", + choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") + parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") + parser.add_argument("--clip_skip", type=int, default=None, + help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") + parser.add_argument("--logging_dir", type=str, default=None, + help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") + parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") + parser.add_argument("--lr_scheduler", type=str, default="constant", + help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") + parser.add_argument("--lr_warmup_steps", type=int, default=0, + help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") + parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール') + parser.add_argument("--network_dim", type=int, default=None, + help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') + parser.add_argument("--network_args", type=str, default=None, nargs='*', + help='additional argmuments for network (key=value) / ネットワークへの追加の引数') + parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する") + parser.add_argument("--network_train_text_encoder_only", action="store_true", + help="only training Text Encoder part / Text Encoder関連部分のみ学習する") + + args = parser.parse_args() + train(args)