From 47d61e2c021d52fa9f2f11de395585f3be240ab1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 17 Apr 2023 22:00:26 +0900 Subject: [PATCH] format by black --- finetune/make_captions.py | 259 +++++++------ finetune/make_captions_by_git.py | 224 ++++++----- finetune/prepare_buckets_latents.py | 471 +++++++++++++---------- finetune/tag_images_by_wd14_tagger.py | 133 +++++-- tools/convert_diffusers20_original_sd.py | 174 +++++---- 5 files changed, 719 insertions(+), 542 deletions(-) diff --git a/finetune/make_captions.py b/finetune/make_captions.py index 3c78fb48..9e51037f 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -14,158 +14,185 @@ from torchvision.transforms.functional import InterpolationMode from blip.blip import blip_decoder import library.train_util as train_util -DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") IMAGE_SIZE = 384 # 正方形でいいのか? という気がするがソースがそうなので -IMAGE_TRANSFORM = transforms.Compose([ - transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC), - transforms.ToTensor(), - transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) -]) +IMAGE_TRANSFORM = transforms.Compose( + [ + transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ] +) + # 共通化したいが微妙に処理が異なる…… class ImageLoadingTransformDataset(torch.utils.data.Dataset): - def __init__(self, image_paths): - self.images = image_paths + def __init__(self, image_paths): + self.images = image_paths - def __len__(self): - return len(self.images) + def __len__(self): + return len(self.images) - def __getitem__(self, idx): - img_path = self.images[idx] + def __getitem__(self, idx): + img_path = self.images[idx] - try: - image = Image.open(img_path).convert("RGB") - # convert to tensor temporarily so dataloader will accept it - tensor = IMAGE_TRANSFORM(image) - except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") - return None + try: + image = Image.open(img_path).convert("RGB") + # convert to tensor temporarily so dataloader will accept it + tensor = IMAGE_TRANSFORM(image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + return None - return (tensor, img_path) + return (tensor, img_path) def collate_fn_remove_corrupted(batch): - """Collate function that allows to remove corrupted examples in the - dataloader. It expects that the dataloader returns 'None' when that occurs. - The 'None's in the batch are removed. - """ - # Filter out all the Nones (corrupted examples) - batch = list(filter(lambda x: x is not None, batch)) - return batch + """Collate function that allows to remove corrupted examples in the + dataloader. It expects that the dataloader returns 'None' when that occurs. + The 'None's in the batch are removed. + """ + # Filter out all the Nones (corrupted examples) + batch = list(filter(lambda x: x is not None, batch)) + return batch def main(args): - # fix the seed for reproducibility - seed = args.seed # + utils.get_rank() - torch.manual_seed(seed) - np.random.seed(seed) - random.seed(seed) + # fix the seed for reproducibility + seed = args.seed # + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) - if not os.path.exists("blip"): - args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path + if not os.path.exists("blip"): + args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path - cwd = os.getcwd() - print('Current Working Directory is: ', cwd) - os.chdir('finetune') + cwd = os.getcwd() + print("Current Working Directory is: ", cwd) + os.chdir("finetune") - print(f"load images from {args.train_data_dir}") - train_data_dir_path = Path(args.train_data_dir) - image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + print(f"load images from {args.train_data_dir}") + train_data_dir_path = Path(args.train_data_dir) + image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) + print(f"found {len(image_paths)} images.") - print(f"loading BLIP caption: {args.caption_weights}") - model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json") - model.eval() - model = model.to(DEVICE) - print("BLIP loaded") + print(f"loading BLIP caption: {args.caption_weights}") + model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json") + model.eval() + model = model.to(DEVICE) + print("BLIP loaded") - # captioningする - def run_batch(path_imgs): - imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE) + # captioningする + def run_batch(path_imgs): + imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE) - with torch.no_grad(): - if args.beam_search: - captions = model.generate(imgs, sample=False, num_beams=args.num_beams, - max_length=args.max_length, min_length=args.min_length) - else: - captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length) + with torch.no_grad(): + if args.beam_search: + captions = model.generate( + imgs, sample=False, num_beams=args.num_beams, max_length=args.max_length, min_length=args.min_length + ) + else: + captions = model.generate( + imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length + ) - for (image_path, _), caption in zip(path_imgs, captions): - with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: - f.write(caption + "\n") - if args.debug: - print(image_path, caption) + for (image_path, _), caption in zip(path_imgs, captions): + with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: + f.write(caption + "\n") + if args.debug: + print(image_path, caption) - # 読み込みの高速化のためにDataLoaderを使うオプション - if args.max_data_loader_n_workers is not None: - dataset = ImageLoadingTransformDataset(image_paths) - data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, - num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) - else: - data = [[(None, ip)] for ip in image_paths] + # 読み込みの高速化のためにDataLoaderを使うオプション + if args.max_data_loader_n_workers is not None: + dataset = ImageLoadingTransformDataset(image_paths) + data = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.max_data_loader_n_workers, + collate_fn=collate_fn_remove_corrupted, + drop_last=False, + ) + else: + data = [[(None, ip)] for ip in image_paths] - b_imgs = [] - for data_entry in tqdm(data, smoothing=0.0): - for data in data_entry: - if data is None: - continue + b_imgs = [] + for data_entry in tqdm(data, smoothing=0.0): + for data in data_entry: + if data is None: + continue - img_tensor, image_path = data - if img_tensor is None: - try: - raw_image = Image.open(image_path) - if raw_image.mode != 'RGB': - raw_image = raw_image.convert("RGB") - img_tensor = IMAGE_TRANSFORM(raw_image) - except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") - continue + img_tensor, image_path = data + if img_tensor is None: + try: + raw_image = Image.open(image_path) + if raw_image.mode != "RGB": + raw_image = raw_image.convert("RGB") + img_tensor = IMAGE_TRANSFORM(raw_image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + continue - b_imgs.append((image_path, img_tensor)) - if len(b_imgs) >= args.batch_size: + b_imgs.append((image_path, img_tensor)) + if len(b_imgs) >= args.batch_size: + run_batch(b_imgs) + b_imgs.clear() + if len(b_imgs) > 0: run_batch(b_imgs) - b_imgs.clear() - if len(b_imgs) > 0: - run_batch(b_imgs) - print("done!") + print("done!") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth", - help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)") - parser.add_argument("--caption_extention", type=str, default=None, - help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") - parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") - parser.add_argument("--beam_search", action="store_true", - help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)") - parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") - parser.add_argument("--max_data_loader_n_workers", type=int, default=None, - help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") - parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)") - parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p") - parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長") - parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長") - parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed') - parser.add_argument("--debug", action="store_true", help="debug mode") - parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") - - return parser + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument( + "--caption_weights", + type=str, + default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth", + help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)", + ) + parser.add_argument( + "--caption_extention", + type=str, + default=None, + help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)", + ) + parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") + parser.add_argument( + "--beam_search", + action="store_true", + help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)", + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") + parser.add_argument( + "--max_data_loader_n_workers", + type=int, + default=None, + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", + ) + parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)") + parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p") + parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長") + parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長") + parser.add_argument("--seed", default=42, type=int, help="seed for reproducibility / 再現性を確保するための乱数seed") + parser.add_argument("--debug", action="store_true", help="debug mode") + parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") + + return parser -if __name__ == '__main__': - parser = setup_parser() +if __name__ == "__main__": + parser = setup_parser() - args = parser.parse_args() + args = parser.parse_args() - # スペルミスしていたオプションを復元する - if args.caption_extention is not None: - args.caption_extension = args.caption_extention + # スペルミスしていたオプションを復元する + if args.caption_extention is not None: + args.caption_extension = args.caption_extention - main(args) + main(args) diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index 41af23dc..ce6e6695 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -12,143 +12,161 @@ from transformers.generation.utils import GenerationMixin import library.train_util as train_util -DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") PATTERN_REPLACE = [ re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'), re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'), re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"), - re.compile(r'with the number \d+ on (it|\w+ \w+)'), + re.compile(r"with the number \d+ on (it|\w+ \w+)"), re.compile(r'with the words "'), - re.compile(r'word \w+ on it'), - re.compile(r'that says the word \w+ on it'), - re.compile('that says\'the word "( on it)?'), + re.compile(r"word \w+ on it"), + re.compile(r"that says the word \w+ on it"), + re.compile("that says'the word \"( on it)?"), ] # 誤検知しまくりの with the word xxxx を消す def remove_words(captions, debug): - removed_caps = [] - for caption in captions: - cap = caption - for pat in PATTERN_REPLACE: - cap = pat.sub("", cap) - if debug and cap != caption: - print(caption) - print(cap) - removed_caps.append(cap) - return removed_caps + removed_caps = [] + for caption in captions: + cap = caption + for pat in PATTERN_REPLACE: + cap = pat.sub("", cap) + if debug and cap != caption: + print(caption) + print(cap) + removed_caps.append(cap) + return removed_caps def collate_fn_remove_corrupted(batch): - """Collate function that allows to remove corrupted examples in the - dataloader. It expects that the dataloader returns 'None' when that occurs. - The 'None's in the batch are removed. - """ - # Filter out all the Nones (corrupted examples) - batch = list(filter(lambda x: x is not None, batch)) - return batch + """Collate function that allows to remove corrupted examples in the + dataloader. It expects that the dataloader returns 'None' when that occurs. + The 'None's in the batch are removed. + """ + # Filter out all the Nones (corrupted examples) + batch = list(filter(lambda x: x is not None, batch)) + return batch def main(args): - # GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用 - org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation - curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように + # GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用 + org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation + curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように - # input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す - # ここより上で置き換えようとするとすごく大変 - def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs): - input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs) - if input_ids.size()[0] != curr_batch_size[0]: - input_ids = input_ids.repeat(curr_batch_size[0], 1) - return input_ids - GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch + # input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す + # ここより上で置き換えようとするとすごく大変 + def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs): + input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs) + if input_ids.size()[0] != curr_batch_size[0]: + input_ids = input_ids.repeat(curr_batch_size[0], 1) + return input_ids - print(f"load images from {args.train_data_dir}") - train_data_dir_path = Path(args.train_data_dir) - image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch - # できればcacheに依存せず明示的にダウンロードしたい - print(f"loading GIT: {args.model_id}") - git_processor = AutoProcessor.from_pretrained(args.model_id) - git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE) - print("GIT loaded") + print(f"load images from {args.train_data_dir}") + train_data_dir_path = Path(args.train_data_dir) + image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) + print(f"found {len(image_paths)} images.") - # captioningする - def run_batch(path_imgs): - imgs = [im for _, im in path_imgs] + # できればcacheに依存せず明示的にダウンロードしたい + print(f"loading GIT: {args.model_id}") + git_processor = AutoProcessor.from_pretrained(args.model_id) + git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE) + print("GIT loaded") - curr_batch_size[0] = len(path_imgs) - inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式 - generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length) - captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True) + # captioningする + def run_batch(path_imgs): + imgs = [im for _, im in path_imgs] - if args.remove_words: - captions = remove_words(captions, args.debug) + curr_batch_size[0] = len(path_imgs) + inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式 + generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length) + captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True) - for (image_path, _), caption in zip(path_imgs, captions): - with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: - f.write(caption + "\n") - if args.debug: - print(image_path, caption) + if args.remove_words: + captions = remove_words(captions, args.debug) - # 読み込みの高速化のためにDataLoaderを使うオプション - if args.max_data_loader_n_workers is not None: - dataset = train_util.ImageLoadingDataset(image_paths) - data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, - num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) - else: - data = [[(None, ip)] for ip in image_paths] + for (image_path, _), caption in zip(path_imgs, captions): + with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: + f.write(caption + "\n") + if args.debug: + print(image_path, caption) - b_imgs = [] - for data_entry in tqdm(data, smoothing=0.0): - for data in data_entry: - if data is None: - continue + # 読み込みの高速化のためにDataLoaderを使うオプション + if args.max_data_loader_n_workers is not None: + dataset = train_util.ImageLoadingDataset(image_paths) + data = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.max_data_loader_n_workers, + collate_fn=collate_fn_remove_corrupted, + drop_last=False, + ) + else: + data = [[(None, ip)] for ip in image_paths] - image, image_path = data - if image is None: - try: - image = Image.open(image_path) - if image.mode != 'RGB': - image = image.convert("RGB") - except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") - continue + b_imgs = [] + for data_entry in tqdm(data, smoothing=0.0): + for data in data_entry: + if data is None: + continue - b_imgs.append((image_path, image)) - if len(b_imgs) >= args.batch_size: + image, image_path = data + if image is None: + try: + image = Image.open(image_path) + if image.mode != "RGB": + image = image.convert("RGB") + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + continue + + b_imgs.append((image_path, image)) + if len(b_imgs) >= args.batch_size: + run_batch(b_imgs) + b_imgs.clear() + + if len(b_imgs) > 0: run_batch(b_imgs) - b_imgs.clear() - if len(b_imgs) > 0: - run_batch(b_imgs) - - print("done!") + print("done!") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") - parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps", - help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID") - parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") - parser.add_argument("--max_data_loader_n_workers", type=int, default=None, - help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") - parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長") - parser.add_argument("--remove_words", action="store_true", - help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する") - parser.add_argument("--debug", action="store_true", help="debug mode") - parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") - - return parser + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") + parser.add_argument( + "--model_id", + type=str, + default="microsoft/git-large-textcaps", + help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID", + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") + parser.add_argument( + "--max_data_loader_n_workers", + type=int, + default=None, + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", + ) + parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長") + parser.add_argument( + "--remove_words", + action="store_true", + help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する", + ) + parser.add_argument("--debug", action="store_true", help="debug mode") + parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") + + return parser -if __name__ == '__main__': - parser = setup_parser() +if __name__ == "__main__": + parser = setup_parser() - args = parser.parse_args() - main(args) + args = parser.parse_args() + main(args) diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index b9c0fa50..fd289d1d 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -14,7 +14,7 @@ from torchvision import transforms import library.model_util as model_util import library.train_util as train_util -DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") IMAGE_TRANSFORMS = transforms.Compose( [ @@ -25,256 +25,299 @@ IMAGE_TRANSFORMS = transforms.Compose( def collate_fn_remove_corrupted(batch): - """Collate function that allows to remove corrupted examples in the - dataloader. It expects that the dataloader returns 'None' when that occurs. - The 'None's in the batch are removed. - """ - # Filter out all the Nones (corrupted examples) - batch = list(filter(lambda x: x is not None, batch)) - return batch + """Collate function that allows to remove corrupted examples in the + dataloader. It expects that the dataloader returns 'None' when that occurs. + The 'None's in the batch are removed. + """ + # Filter out all the Nones (corrupted examples) + batch = list(filter(lambda x: x is not None, batch)) + return batch def get_latents(vae, images, weight_dtype): - img_tensors = [IMAGE_TRANSFORMS(image) for image in images] - img_tensors = torch.stack(img_tensors) - img_tensors = img_tensors.to(DEVICE, weight_dtype) - with torch.no_grad(): - latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy() - return latents + img_tensors = [IMAGE_TRANSFORMS(image) for image in images] + img_tensors = torch.stack(img_tensors) + img_tensors = img_tensors.to(DEVICE, weight_dtype) + with torch.no_grad(): + latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy() + return latents def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive): - if is_full_path: - base_name = os.path.splitext(os.path.basename(image_key))[0] - relative_path = os.path.relpath(os.path.dirname(image_key), data_dir) - else: - base_name = image_key - relative_path = "" + if is_full_path: + base_name = os.path.splitext(os.path.basename(image_key))[0] + relative_path = os.path.relpath(os.path.dirname(image_key), data_dir) + else: + base_name = image_key + relative_path = "" - if flip: - base_name += '_flip' - - if recursive and relative_path: - return os.path.join(data_dir, relative_path, base_name) - else: - return os.path.join(data_dir, base_name) + if flip: + base_name += "_flip" + if recursive and relative_path: + return os.path.join(data_dir, relative_path, base_name) + else: + return os.path.join(data_dir, base_name) def main(args): - # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります" - if args.bucket_reso_steps % 8 > 0: - print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") + # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります" + if args.bucket_reso_steps % 8 > 0: + print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") - train_data_dir_path = Path(args.train_data_dir) - image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)] - print(f"found {len(image_paths)} images.") + train_data_dir_path = Path(args.train_data_dir) + image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)] + print(f"found {len(image_paths)} images.") - if os.path.exists(args.in_json): - print(f"loading existing metadata: {args.in_json}") - with open(args.in_json, "rt", encoding='utf-8') as f: - metadata = json.load(f) - else: - print(f"no metadata / メタデータファイルがありません: {args.in_json}") - return - - weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - vae = model_util.load_vae(args.model_name_or_path, weight_dtype) - vae.eval() - vae.to(DEVICE, dtype=weight_dtype) - - # bucketのサイズを計算する - max_reso = tuple([int(t) for t in args.max_resolution.split(',')]) - assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" - - bucket_manager = train_util.BucketManager(args.bucket_no_upscale, max_reso, - args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps) - if not args.bucket_no_upscale: - bucket_manager.make_buckets() - else: - print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます") - - # 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する - img_ar_errors = [] - - def process_batch(is_last): - for bucket in bucket_manager.buckets: - if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: - latents = get_latents(vae, [img for _, img in bucket], weight_dtype) - assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, \ - f"latent shape {latents.shape}, {bucket[0][1].shape}" - - for (image_key, _), latent in zip(bucket, latents): - npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) - np.savez(npz_file_name, latent) - - # flip - if args.flip_aug: - latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない - - for (image_key, _), latent in zip(bucket, latents): - npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) - np.savez(npz_file_name, latent) - else: - # remove existing flipped npz - for image_key, _ in bucket: - npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz" - if os.path.isfile(npz_file_name): - print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}") - os.remove(npz_file_name) - - bucket.clear() - - # 読み込みの高速化のためにDataLoaderを使うオプション - if args.max_data_loader_n_workers is not None: - dataset = train_util.ImageLoadingDataset(image_paths) - data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, - num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) - else: - data = [[(None, ip)] for ip in image_paths] - - bucket_counts = {} - for data_entry in tqdm(data, smoothing=0.0): - if data_entry[0] is None: - continue - - img_tensor, image_path = data_entry[0] - if img_tensor is not None: - image = transforms.functional.to_pil_image(img_tensor) + if os.path.exists(args.in_json): + print(f"loading existing metadata: {args.in_json}") + with open(args.in_json, "rt", encoding="utf-8") as f: + metadata = json.load(f) else: - try: - image = Image.open(image_path) - if image.mode != 'RGB': - image = image.convert("RGB") - except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") - continue + print(f"no metadata / メタデータファイルがありません: {args.in_json}") + return - image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] - if image_key not in metadata: - metadata[image_key] = {} + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 - # 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変 + vae = model_util.load_vae(args.model_name_or_path, weight_dtype) + vae.eval() + vae.to(DEVICE, dtype=weight_dtype) - reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height) - img_ar_errors.append(abs(ar_error)) - bucket_counts[reso] = bucket_counts.get(reso, 0) + 1 - - # メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て - metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8) + # bucketのサイズを計算する + max_reso = tuple([int(t) for t in args.max_resolution.split(",")]) + assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" + bucket_manager = train_util.BucketManager( + args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps + ) if not args.bucket_no_upscale: - # upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する - assert resized_size[0] == reso[0] or resized_size[1] == reso[ - 1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}" - assert resized_size[0] >= reso[0] and resized_size[1] >= reso[ - 1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}" + bucket_manager.make_buckets() + else: + print( + "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます" + ) - assert resized_size[0] >= reso[0] and resized_size[1] >= reso[ - 1], f"internal error resized size is small: {resized_size}, {reso}" + # 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する + img_ar_errors = [] - # 既に存在するファイルがあればshapeを確認して同じならskipする - if args.skip_existing: - npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"] - if args.flip_aug: - npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz") + def process_batch(is_last): + for bucket in bucket_manager.buckets: + if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: + latents = get_latents(vae, [img for _, img in bucket], weight_dtype) + assert ( + latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8 + ), f"latent shape {latents.shape}, {bucket[0][1].shape}" - found = True - for npz_file in npz_files: - if not os.path.exists(npz_file): - found = False - break + for (image_key, _), latent in zip(bucket, latents): + npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + np.savez(npz_file_name, latent) - dat = np.load(npz_file)['arr_0'] - if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認 - found = False - break - if found: - continue + # flip + if args.flip_aug: + latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない - # 画像をリサイズしてトリミングする - # PILにinter_areaがないのでcv2で…… - image = np.array(image) - if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要? - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) + for (image_key, _), latent in zip(bucket, latents): + npz_file_name = get_npz_filename_wo_ext( + args.train_data_dir, image_key, args.full_path, True, args.recursive + ) + np.savez(npz_file_name, latent) + else: + # remove existing flipped npz + for image_key, _ in bucket: + npz_file_name = ( + get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz" + ) + if os.path.isfile(npz_file_name): + print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}") + os.remove(npz_file_name) - if resized_size[0] > reso[0]: - trim_size = resized_size[0] - reso[0] - image = image[:, trim_size//2:trim_size//2 + reso[0]] + bucket.clear() - if resized_size[1] > reso[1]: - trim_size = resized_size[1] - reso[1] - image = image[trim_size//2:trim_size//2 + reso[1]] + # 読み込みの高速化のためにDataLoaderを使うオプション + if args.max_data_loader_n_workers is not None: + dataset = train_util.ImageLoadingDataset(image_paths) + data = torch.utils.data.DataLoader( + dataset, + batch_size=1, + shuffle=False, + num_workers=args.max_data_loader_n_workers, + collate_fn=collate_fn_remove_corrupted, + drop_last=False, + ) + else: + data = [[(None, ip)] for ip in image_paths] - assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}" + bucket_counts = {} + for data_entry in tqdm(data, smoothing=0.0): + if data_entry[0] is None: + continue - # # debug - # cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1]) + img_tensor, image_path = data_entry[0] + if img_tensor is not None: + image = transforms.functional.to_pil_image(img_tensor) + else: + try: + image = Image.open(image_path) + if image.mode != "RGB": + image = image.convert("RGB") + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + continue - # バッチへ追加 - bucket_manager.add_image(reso, (image_key, image)) + image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] + if image_key not in metadata: + metadata[image_key] = {} - # バッチを推論するか判定して推論する - process_batch(False) + # 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変 - # 残りを処理する - process_batch(True) + reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height) + img_ar_errors.append(abs(ar_error)) + bucket_counts[reso] = bucket_counts.get(reso, 0) + 1 - bucket_manager.sort() - for i, reso in enumerate(bucket_manager.resos): - count = bucket_counts.get(reso, 0) - if count > 0: - print(f"bucket {i} {reso}: {count}") - img_ar_errors = np.array(img_ar_errors) - print(f"mean ar error: {np.mean(img_ar_errors)}") + # メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て + metadata[image_key]["train_resolution"] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8) - # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") - with open(args.out_json, "wt", encoding='utf-8') as f: - json.dump(metadata, f, indent=2) - print("done!") + if not args.bucket_no_upscale: + # upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する + assert ( + resized_size[0] == reso[0] or resized_size[1] == reso[1] + ), f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}" + assert ( + resized_size[0] >= reso[0] and resized_size[1] >= reso[1] + ), f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}" + + assert ( + resized_size[0] >= reso[0] and resized_size[1] >= reso[1] + ), f"internal error resized size is small: {resized_size}, {reso}" + + # 既に存在するファイルがあればshapeを確認して同じならskipする + if args.skip_existing: + npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"] + if args.flip_aug: + npz_files.append( + get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz" + ) + + found = True + for npz_file in npz_files: + if not os.path.exists(npz_file): + found = False + break + + dat = np.load(npz_file)["arr_0"] + if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認 + found = False + break + if found: + continue + + # 画像をリサイズしてトリミングする + # PILにinter_areaがないのでcv2で…… + image = np.array(image) + if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要? + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) + + if resized_size[0] > reso[0]: + trim_size = resized_size[0] - reso[0] + image = image[:, trim_size // 2 : trim_size // 2 + reso[0]] + + if resized_size[1] > reso[1]: + trim_size = resized_size[1] - reso[1] + image = image[trim_size // 2 : trim_size // 2 + reso[1]] + + assert ( + image.shape[0] == reso[1] and image.shape[1] == reso[0] + ), f"internal error, illegal trimmed size: {image.shape}, {reso}" + + # # debug + # cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1]) + + # バッチへ追加 + bucket_manager.add_image(reso, (image_key, image)) + + # バッチを推論するか判定して推論する + process_batch(False) + + # 残りを処理する + process_batch(True) + + bucket_manager.sort() + for i, reso in enumerate(bucket_manager.resos): + count = bucket_counts.get(reso, 0) + if count > 0: + print(f"bucket {i} {reso}: {count}") + img_ar_errors = np.array(img_ar_errors) + print(f"mean ar error: {np.mean(img_ar_errors)}") + + # metadataを書き出して終わり + print(f"writing metadata: {args.out_json}") + with open(args.out_json, "wt", encoding="utf-8") as f: + json.dump(metadata, f, indent=2) + print("done!") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") - parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") - parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル") - parser.add_argument("--v2", action='store_true', - help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)') - parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") - parser.add_argument("--max_data_loader_n_workers", type=int, default=None, - help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") - parser.add_argument("--max_resolution", type=str, default="512,512", - help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)") - parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") - parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度") - parser.add_argument("--bucket_reso_steps", type=int, default=64, - help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します") - parser.add_argument("--bucket_no_upscale", action="store_true", - help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します") - parser.add_argument("--mixed_precision", type=str, default="no", - choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") - parser.add_argument("--full_path", action="store_true", - help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") - parser.add_argument("--flip_aug", action="store_true", - help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する") - parser.add_argument("--skip_existing", action="store_true", - help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)") - parser.add_argument("--recursive", action="store_true", - help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") + parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") + parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル") + parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)") + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") + parser.add_argument( + "--max_data_loader_n_workers", + type=int, + default=None, + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", + ) + parser.add_argument( + "--max_resolution", + type=str, + default="512,512", + help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)", + ) + parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") + parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度") + parser.add_argument( + "--bucket_reso_steps", + type=int, + default=64, + help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します", + ) + parser.add_argument( + "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" + ) + parser.add_argument( + "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度" + ) + parser.add_argument( + "--full_path", + action="store_true", + help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", + ) + parser.add_argument( + "--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する" + ) + parser.add_argument( + "--skip_existing", + action="store_true", + help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + ) + parser.add_argument( + "--recursive", + action="store_true", + help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す", + ) - return parser + return parser -if __name__ == '__main__': - parser = setup_parser() +if __name__ == "__main__": + parser = setup_parser() - args = parser.parse_args() - main(args) + args = parser.parse_args() + main(args) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index fb7ed7b4..40bf428c 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -18,15 +18,16 @@ import library.train_util as train_util IMAGE_SIZE = 448 # wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2 -DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2' +DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"] SUB_DIR = "variables" SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] CSV_FILE = FILES[-1] + def preprocess_image(image): image = np.array(image) - image = image[:, :, ::-1] # RGB->BGR + image = image[:, :, ::-1] # RGB->BGR # pad to square size = max(image.shape[0:2]) @@ -34,7 +35,7 @@ def preprocess_image(image): pad_y = size - image.shape[0] pad_l = pad_x // 2 pad_t = pad_y // 2 - image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255) + image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255) interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) @@ -42,6 +43,7 @@ def preprocess_image(image): image = image.astype(np.float32) return image + class ImageLoadingPrepDataset(torch.utils.data.Dataset): def __init__(self, image_paths): self.images = image_paths @@ -61,7 +63,8 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset): return None return (tensor, img_path) - + + def collate_fn_remove_corrupted(batch): """Collate function that allows to remove corrupted examples in the dataloader. It expects that the dataloader returns 'None' when that occurs. @@ -71,6 +74,7 @@ def collate_fn_remove_corrupted(batch): batch = list(filter(lambda x: x is not None, batch)) return batch + def main(args): # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする # depreacatedの警告が出るけどなくなったらその時 @@ -80,8 +84,14 @@ def main(args): for file in FILES: hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) for file in SUB_DIR_FILES: - hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join( - args.model_dir, SUB_DIR), force_download=True, force_filename=file) + hf_hub_download( + args.repo_id, + file, + subfolder=SUB_DIR, + cache_dir=os.path.join(args.model_dir, SUB_DIR), + force_download=True, + force_filename=file, + ) else: print("using existing wd14 tagger model") @@ -94,22 +104,22 @@ def main(args): with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: reader = csv.reader(f) l = [row for row in reader] - header = l[0] # tag_id,name,category,count + header = l[0] # tag_id,name,category,count rows = l[1:] - assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}" + assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}" - general_tags = [row[1] for row in rows[1:] if row[2] == '0'] - character_tags = [row[1] for row in rows[1:] if row[2] == '4'] + general_tags = [row[1] for row in rows[1:] if row[2] == "0"] + character_tags = [row[1] for row in rows[1:] if row[2] == "4"] # 画像を読み込む - + train_data_dir_path = Path(args.train_data_dir) image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) print(f"found {len(image_paths)} images.") tag_freq = {} - undesired_tags = set(args.undesired_tags.split(',')) + undesired_tags = set(args.undesired_tags.split(",")) def run_batch(path_imgs): imgs = np.array([im for _, im in path_imgs]) @@ -131,13 +141,17 @@ def main(args): character_tag_text = "" for i, p in enumerate(prob[4:]): if i < len(general_tags) and p >= args.general_threshold: - tag_name = general_tags[i].replace('_', ' ') if args.remove_underscore else general_tags[i] + tag_name = general_tags[i].replace("_", " ") if args.remove_underscore else general_tags[i] if tag_name not in undesired_tags: tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 general_tag_text += ", " + tag_name combined_tags.append(tag_name) elif i >= len(general_tags) and p >= args.character_threshold: - tag_name = character_tags[i - len(general_tags)].replace('_', ' ') if args.remove_underscore else character_tags[i - len(general_tags)] + tag_name = ( + character_tags[i - len(general_tags)].replace("_", " ") + if args.remove_underscore + else character_tags[i - len(general_tags)] + ) if tag_name not in undesired_tags: tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 character_tag_text += ", " + tag_name @@ -149,19 +163,24 @@ def main(args): if len(character_tag_text) > 0: character_tag_text = character_tag_text[2:] - tag_text = ', '.join(combined_tags) - - with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: - f.write(tag_text + '\n') + tag_text = ", ".join(combined_tags) + + with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: + f.write(tag_text + "\n") if args.debug: print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}") - # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: dataset = ImageLoadingPrepDataset(image_paths) - data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, - num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) + data = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.max_data_loader_n_workers, + collate_fn=collate_fn_remove_corrupted, + drop_last=False, + ) else: data = [[(None, ip)] for ip in image_paths] @@ -177,7 +196,7 @@ def main(args): else: try: image = Image.open(image_path) - if image.mode != 'RGB': + if image.mode != "RGB": image = image.convert("RGB") image = preprocess_image(image) except Exception as e: @@ -203,36 +222,72 @@ def main(args): print("done!") -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO, - help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID") - parser.add_argument("--model_dir", type=str, default="wd14_tagger_model", - help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ") - parser.add_argument("--force_download", action='store_true', - help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします") + parser.add_argument( + "--repo_id", + type=str, + default=DEFAULT_WD14_TAGGER_REPO, + help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID", + ) + parser.add_argument( + "--model_dir", + type=str, + default="wd14_tagger_model", + help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ", + ) + parser.add_argument( + "--force_download", action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします" + ) parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") - parser.add_argument("--max_data_loader_n_workers", type=int, default=None, - help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") - parser.add_argument("--caption_extention", type=str, default=None, - help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") + parser.add_argument( + "--max_data_loader_n_workers", + type=int, + default=None, + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", + ) + parser.add_argument( + "--caption_extention", + type=str, + default=None, + help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)", + ) parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値") - parser.add_argument("--general_threshold", type=float, default=None, help="threshold of confidence to add a tag for general category, same as --thresh if omitted / generalカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ") - parser.add_argument("--character_threshold", type=float, default=None, help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ") + parser.add_argument( + "--general_threshold", + type=float, + default=None, + help="threshold of confidence to add a tag for general category, same as --thresh if omitted / generalカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ", + ) + parser.add_argument( + "--character_threshold", + type=float, + default=None, + help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ", + ) parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") - parser.add_argument("--remove_underscore", action="store_true", help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える") + parser.add_argument( + "--remove_underscore", + action="store_true", + help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える", + ) parser.add_argument("--debug", action="store_true", help="debug mode") - parser.add_argument("--undesired_tags", type=str, default="", help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト") - parser.add_argument('--frequency_tags', action='store_true', help='Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する') + parser.add_argument( + "--undesired_tags", + type=str, + default="", + help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト", + ) + parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") args = parser.parse_args() # スペルミスしていたオプションを復元する if args.caption_extention is not None: args.caption_extension = args.caption_extention - + if args.general_threshold is None: args.general_threshold = args.thresh if args.character_threshold is None: diff --git a/tools/convert_diffusers20_original_sd.py b/tools/convert_diffusers20_original_sd.py index 51d5f362..15a9ca4a 100644 --- a/tools/convert_diffusers20_original_sd.py +++ b/tools/convert_diffusers20_original_sd.py @@ -9,88 +9,122 @@ import library.model_util as model_util def convert(args): - # 引数を確認する - load_dtype = torch.float16 if args.fp16 else None + # 引数を確認する + load_dtype = torch.float16 if args.fp16 else None - save_dtype = None - if args.fp16 or args.save_precision_as == "fp16": - save_dtype = torch.float16 - elif args.bf16 or args.save_precision_as == "bf16": - save_dtype = torch.bfloat16 - elif args.float or args.save_precision_as == "float": - save_dtype = torch.float + save_dtype = None + if args.fp16 or args.save_precision_as == "fp16": + save_dtype = torch.float16 + elif args.bf16 or args.save_precision_as == "bf16": + save_dtype = torch.bfloat16 + elif args.float or args.save_precision_as == "float": + save_dtype = torch.float - is_load_ckpt = os.path.isfile(args.model_to_load) - is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 + is_load_ckpt = os.path.isfile(args.model_to_load) + is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 - assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です" - assert is_save_ckpt or args.reference_model is not None, f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です" + assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です" + assert ( + is_save_ckpt or args.reference_model is not None + ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です" - # モデルを読み込む - msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) - print(f"loading {msg}: {args.model_to_load}") + # モデルを読み込む + msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) + print(f"loading {msg}: {args.model_to_load}") - if is_load_ckpt: - v2_model = args.v2 - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load) - else: - pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None) - text_encoder = pipe.text_encoder - vae = pipe.vae - unet = pipe.unet - - if args.v1 == args.v2: - # 自動判定する - v2_model = unet.config.cross_attention_dim == 1024 - print("checking model version: model is " + ('v2' if v2_model else 'v1')) + if is_load_ckpt: + v2_model = args.v2 + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load) else: - v2_model = not args.v1 + pipe = StableDiffusionPipeline.from_pretrained( + args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None + ) + text_encoder = pipe.text_encoder + vae = pipe.vae + unet = pipe.unet - # 変換して保存する - msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" - print(f"converting and saving as {msg}: {args.model_to_save}") + if args.v1 == args.v2: + # 自動判定する + v2_model = unet.config.cross_attention_dim == 1024 + print("checking model version: model is " + ("v2" if v2_model else "v1")) + else: + v2_model = not args.v1 - if is_save_ckpt: - original_model = args.model_to_load if is_load_ckpt else None - key_count = model_util.save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet, - original_model, args.epoch, args.global_step, save_dtype, vae) - print(f"model saved. total converted state_dict keys: {key_count}") - else: - print(f"copy scheduler/tokenizer config from: {args.reference_model}") - model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors) - print(f"model saved.") + # 変換して保存する + msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" + print(f"converting and saving as {msg}: {args.model_to_save}") + + if is_save_ckpt: + original_model = args.model_to_load if is_load_ckpt else None + key_count = model_util.save_stable_diffusion_checkpoint( + v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, save_dtype, vae + ) + print(f"model saved. total converted state_dict keys: {key_count}") + else: + print(f"copy scheduler/tokenizer config from: {args.reference_model}") + model_util.save_diffusers_checkpoint( + v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors + ) + print(f"model saved.") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("--v1", action='store_true', - help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む') - parser.add_argument("--v2", action='store_true', - help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む') - parser.add_argument("--fp16", action='store_true', - help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)') - parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)') - parser.add_argument("--float", action='store_true', - help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)') - parser.add_argument("--save_precision_as", type=str, default="no", choices=["fp16", "bf16", "float"], - help="save precision, do not specify with --fp16/--bf16/--float / 保存する精度、--fp16/--bf16/--floatと併用しないでください") - parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値') - parser.add_argument("--global_step", type=int, default=0, - help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値') - parser.add_argument("--reference_model", type=str, default=None, - help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要") - parser.add_argument("--use_safetensors", action='store_true', - help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)") + parser = argparse.ArgumentParser() + parser.add_argument( + "--v1", action="store_true", help="load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む" + ) + parser.add_argument( + "--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む" + ) + parser.add_argument( + "--fp16", + action="store_true", + help="load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)", + ) + parser.add_argument("--bf16", action="store_true", help="save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)") + parser.add_argument( + "--float", action="store_true", help="save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)" + ) + parser.add_argument( + "--save_precision_as", + type=str, + default="no", + choices=["fp16", "bf16", "float"], + help="save precision, do not specify with --fp16/--bf16/--float / 保存する精度、--fp16/--bf16/--floatと併用しないでください", + ) + parser.add_argument("--epoch", type=int, default=0, help="epoch to write to checkpoint / checkpointに記録するepoch数の値") + parser.add_argument( + "--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値" + ) + parser.add_argument( + "--reference_model", + type=str, + default=None, + help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要", + ) + parser.add_argument( + "--use_safetensors", + action="store_true", + help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)", + ) - parser.add_argument("model_to_load", type=str, default=None, - help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ") - parser.add_argument("model_to_save", type=str, default=None, - help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存") - return parser + parser.add_argument( + "model_to_load", + type=str, + default=None, + help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ", + ) + parser.add_argument( + "model_to_save", + type=str, + default=None, + help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存", + ) + return parser -if __name__ == '__main__': - parser = setup_parser() +if __name__ == "__main__": + parser = setup_parser() - args = parser.parse_args() - convert(args) + args = parser.parse_args() + convert(args)