diff --git a/finetune/clean_captions_and_tags.py b/finetune/clean_captions_and_tags.py index 8f53737d..11a59b1f 100644 --- a/finetune/clean_captions_and_tags.py +++ b/finetune/clean_captions_and_tags.py @@ -5,13 +5,32 @@ import argparse import glob import os import json +import re from tqdm import tqdm +PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ') +PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ') +PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ') +PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ') + +# 複数人がいるとき、複数の髪色や目の色が定義されていれば削除する +PATTERNS_REMOVE_IN_MULTI = [ + PATTERN_HAIR_LENGTH, + PATTERN_HAIR_CUT, + re.compile(r', [\w\-]+ eyes, '), + re.compile(r', ([\w\-]+ sleeves|sleeveless), '), + # 複数の髪型定義がある場合は削除する + re.compile( + r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '), +] + def clean_tags(image_key, tags): # replace '_' to ' ' + tags = tags.replace('^_^', '^@@@^') tags = tags.replace('_', ' ') + tags = tags.replace('^@@@^', '^_^') # remove rating: deepdanbooruのみ tokens = tags.split(", rating") @@ -26,6 +45,37 @@ def clean_tags(image_key, tags): print(f"{image_key} {tags}") tags = tokens[0] + tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策 + + # 複数の人物がいる場合は髪色等のタグを削除する + if 'girls' in tags or 'boys' in tags: + for pat in PATTERNS_REMOVE_IN_MULTI: + found = pat.findall(tags) + if len(found) > 1: # 二つ以上、タグがある + tags = pat.sub("", tags) + + # 髪の特殊対応 + srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは例外なので避けておく(全員が同じ髪の長さの場合) + if srch_hair_len: + org = srch_hair_len.group() + tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags) + + found = PATTERN_HAIR.findall(tags) + if len(found) > 1: + tags = PATTERN_HAIR.sub("", tags) + + if srch_hair_len: + tags = tags.replace(", @@@, ", org) # 戻す + + # white shirtとshirtみたいな重複タグの削除 + found = PATTERN_WORD.findall(tags) + for word in found: + if re.search(f", ((\w+) )+{word}, ", tags): + tags = tags.replace(f", {word}, ", "") + + tags = tags.replace(", , ", ", ") + assert tags.startswith(", ") and tags.endswith(", ") + tags = tags[2:-2] return tags @@ -88,13 +138,23 @@ def main(args): if tags is None: print(f"image does not have tags / メタデータにタグがありません: {image_key}") else: - metadata[image_key]['tags'] = clean_tags(image_key, tags) + org = tags + tags = clean_tags(image_key, tags) + metadata[image_key]['tags'] = tags + if args.debug and org != tags: + print("FROM: " + org) + print("TO: " + tags) caption = metadata[image_key].get('caption') if caption is None: print(f"image does not have caption / メタデータにキャプションがありません: {image_key}") else: - metadata[image_key]['caption'] = clean_caption(caption) + org = caption + caption = clean_caption(caption) + metadata[image_key]['caption'] = caption + if args.debug and org != caption: + print("FROM: " + org) + print("TO: " + caption) # metadataを書き出して終わり print(f"writing metadata: {args.out_json}") @@ -108,6 +168,7 @@ if __name__ == '__main__': # parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") + parser.add_argument("--debug", action="store_true", help="debug mode") args, unknown = parser.parse_known_args() if len(unknown) == 1: