diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml new file mode 100644 index 00000000..e3783839 --- /dev/null +++ b/.github/workflows/typos.yml @@ -0,0 +1,21 @@ +--- +# yamllint disable rule:line-length +name: Typos + +on: # yamllint disable-line rule:truthy + push: + pull_request: + types: + - opened + - synchronize + - reopened + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: typos-action + uses: crate-ci/typos@v1.13.10 diff --git a/_typos.toml b/_typos.toml new file mode 100644 index 00000000..396ee5c5 --- /dev/null +++ b/_typos.toml @@ -0,0 +1,15 @@ +# Files for typos +# Instruction: https://github.com/marketplace/actions/typos-action#getting-started + +[default.extend-identifiers] + +[default.extend-words] +NIN="NIN" +parms="parms" +nin="nin" +extention="extention" # Intentionally left +nd="nd" + + +[files] +extend-exclude = ["_typos.toml"] diff --git a/fine_tune.py b/fine_tune.py index 6a95886c..52921530 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -36,6 +36,10 @@ def train(args): args.bucket_reso_steps, args.bucket_no_upscale, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.dataset_repeats, args.debug_dataset) + + # 学習データのdropout率を設定する + train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate) + train_dataset.make_buckets() if args.debug_dataset: @@ -226,6 +230,8 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset.set_current_epoch(epoch + 1) + for m in training_models: m.train() @@ -332,7 +338,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, False, True) + train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) train_util.add_sd_saving_arguments(parser) diff --git a/library/train_util.py b/library/train_util.py index 6f809deb..df6e24e2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -113,7 +113,7 @@ class BucketManager(): # 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく self.predefined_resos = resos.copy() self.predefined_resos_set = set(resos) - self.predifined_aspect_ratios = np.array([w / h for w, h in resos]) + self.predefined_aspect_ratios = np.array([w / h for w, h in resos]) def add_if_new_reso(self, reso): if reso not in self.reso_to_id: @@ -135,7 +135,7 @@ class BucketManager(): if reso in self.predefined_resos_set: pass else: - ar_errors = self.predifined_aspect_ratios - aspect_ratio + ar_errors = self.predefined_aspect_ratios - aspect_ratio predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの reso = self.predefined_resos[predefined_bucket_id] @@ -223,6 +223,10 @@ class BaseDataset(torch.utils.data.Dataset): self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 + self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ + self.dropout_rate: float = 0 + self.dropout_every_n_epochs: int = None + # augmentation flip_p = 0.5 if flip_aug else 0.0 if color_aug: @@ -247,6 +251,15 @@ class BaseDataset(torch.utils.data.Dataset): self.replacements = {} + def set_current_epoch(self, epoch): + self.current_epoch = epoch + + def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate): + # コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく) + self.dropout_rate = dropout_rate + self.dropout_every_n_epochs = dropout_every_n_epochs + self.tag_dropout_rate = tag_dropout_rate + def set_tag_frequency(self, dir_name, captions): frequency_for_dir = self.tag_frequency.get(dir_name, {}) self.tag_frequency[dir_name] = frequency_for_dir @@ -264,27 +277,47 @@ class BaseDataset(torch.utils.data.Dataset): self.replacements[str_from] = str_to def process_caption(self, caption): - if self.shuffle_caption: - tokens = caption.strip().split(",") - if self.shuffle_keep_tokens is None: - random.shuffle(tokens) - else: - if len(tokens) > self.shuffle_keep_tokens: - keep_tokens = tokens[:self.shuffle_keep_tokens] - tokens = tokens[self.shuffle_keep_tokens:] - random.shuffle(tokens) - tokens = keep_tokens + tokens - caption = ",".join(tokens).strip() + # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い + is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate + is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0 - for str_from, str_to in self.replacements.items(): - if str_from == "": - # replace all - if type(str_to) == list: - caption = random.choice(str_to) + if is_drop_out: + caption = "" + else: + if self.shuffle_caption: + def dropout_tags(tokens): + if self.tag_dropout_rate <= 0: + return tokens + l = [] + for token in tokens: + if random.random() >= self.tag_dropout_rate: + l.append(token) + return l + + tokens = [t.strip() for t in caption.strip().split(",")] + if self.shuffle_keep_tokens is None: + random.shuffle(tokens) + tokens = dropout_tags(tokens) else: - caption = str_to - else: - caption = caption.replace(str_from, str_to) + if len(tokens) > self.shuffle_keep_tokens: + keep_tokens = tokens[:self.shuffle_keep_tokens] + tokens = tokens[self.shuffle_keep_tokens:] + random.shuffle(tokens) + tokens = dropout_tags(tokens) + + tokens = keep_tokens + tokens + caption = ", ".join(tokens) + + # textual inversion対応 + for str_from, str_to in self.replacements.items(): + if str_from == "": + # replace all + if type(str_to) == list: + caption = random.choice(str_to) + else: + caption = str_to + else: + caption = caption.replace(str_from, str_to) return caption @@ -907,6 +940,8 @@ class FineTuningDataset(BaseDataset): def debug_dataset(train_dataset, show_input_ids=False): print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") print("Escape for exit. / Escキーで中断、終了します") + + train_dataset.set_current_epoch(1) k = 0 for i, example in enumerate(train_dataset): if example['latents'] is not None: @@ -1377,7 +1412,7 @@ def verify_training_args(args: argparse.Namespace): print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") -def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool): +def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool): # dataset common parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("--shuffle_caption", action="store_true", @@ -1408,6 +1443,16 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b parser.add_argument("--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します") + if support_caption_dropout: + # Textual Inversion はcaptionのdropoutをsupportしない + # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに + parser.add_argument("--caption_dropout_rate", type=float, default=0, + help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合") + parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None, + help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする") + parser.add_argument("--caption_tag_dropout_rate", type=float, default=0, + help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合") + if support_dreambooth: # DreamBooth dataset parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") diff --git a/networks/lora.py b/networks/lora.py index 174feda5..a1f38c16 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -5,6 +5,7 @@ import math import os +from typing import List import torch from library import train_util @@ -98,7 +99,7 @@ class LoRANetwork(torch.nn.Module): self.alpha = alpha # create module instances - def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]: + def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: loras = [] for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py new file mode 100644 index 00000000..0876a4d3 --- /dev/null +++ b/tools/resize_images_to_resolution.py @@ -0,0 +1,113 @@ +import glob +import os +import cv2 +import argparse +import shutil +import math + + +def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False): + # Split the max_resolution string by "," and strip any whitespaces + max_resolutions = [res.strip() for res in max_resolution.split(',')] + + # # Calculate max_pixels from max_resolution string + # max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1]) + + # Create destination folder if it does not exist + if not os.path.exists(dst_img_folder): + os.makedirs(dst_img_folder) + + # Select interpolation method + if interpolation == 'lanczos4': + cv2_interpolation = cv2.INTER_LANCZOS4 + elif interpolation == 'cubic': + cv2_interpolation = cv2.INTER_CUBIC + else: + cv2_interpolation = cv2.INTER_AREA + + # Iterate through all files in src_img_folder + img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py + for filename in os.listdir(src_img_folder): + # Check if the image is png, jpg or webp etc... + if not filename.endswith(img_exts): + # Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.) + shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename)) + continue + + # Load image + img = cv2.imread(os.path.join(src_img_folder, filename)) + + base, _ = os.path.splitext(filename) + for max_resolution in max_resolutions: + # Calculate max_pixels from max_resolution string + max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1]) + + # Calculate current number of pixels + current_pixels = img.shape[0] * img.shape[1] + + # Check if the image needs resizing + if current_pixels > max_pixels: + # Calculate scaling factor + scale_factor = max_pixels / current_pixels + + # Calculate new dimensions + new_height = int(img.shape[0] * math.sqrt(scale_factor)) + new_width = int(img.shape[1] * math.sqrt(scale_factor)) + + # Resize image + img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) + else: + new_height, new_width = img.shape[0:2] + + # Calculate the new height and width that are divisible by divisible_by (with/without resizing) + new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by + new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by + + # Center crop the image to the calculated dimensions + y = int((img.shape[0] - new_height) / 2) + x = int((img.shape[1] - new_width) / 2) + img = img[y:y + new_height, x:x + new_width] + + # Split filename into base and extension + new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg') + + # Save resized image in dst_img_folder + cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100]) + proc = "Resized" if current_pixels > max_pixels else "Saved" + print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") + + # If other files with same basename, copy them with resolution suffix + if copy_associated_files: + asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*")) + for asoc_file in asoc_files: + ext = os.path.splitext(asoc_file)[1] + if ext in img_exts: + continue + for max_resolution in max_resolutions: + new_asoc_file = base + '+' + max_resolution + ext + print(f"Copy {asoc_file} as {new_asoc_file}") + shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file)) + + +def main(): + parser = argparse.ArgumentParser( + description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします') + parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ') + parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ') + parser.add_argument('--max_resolution', type=str, + help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128") + parser.add_argument('--divisible_by', type=int, + help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1) + parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'], + default='area', help='Interpolation method for resizing / リサイズ時の補完方法') + parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存') + parser.add_argument('--copy_associated_files', action='store_true', + help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする') + + args = parser.parse_args() + resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution, + args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files) + + +if __name__ == '__main__': + main() diff --git a/train_db.py b/train_db.py index d1bbc07f..c210767b 100644 --- a/train_db.py +++ b/train_db.py @@ -38,8 +38,13 @@ def train(args): args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps, args.bucket_no_upscale, args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) + if args.no_token_padding: train_dataset.disable_token_padding() + + # 学習データのdropout率を設定する + train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate) + train_dataset.make_buckets() if args.debug_dataset: @@ -203,6 +208,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset.set_current_epoch(epoch + 1) # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() @@ -327,7 +333,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, False) + train_util.add_dataset_arguments(parser, True, False, True) train_util.add_training_arguments(parser, True) train_util.add_sd_saving_arguments(parser) diff --git a/train_network.py b/train_network.py index 3e8f4e7d..bb3159fd 100644 --- a/train_network.py +++ b/train_network.py @@ -120,18 +120,22 @@ def train(args): print("Use DreamBooth method.") train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir, tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens, - args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, - args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, + args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, + args.bucket_reso_steps, args.bucket_no_upscale, + args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) else: print("Train with captions.") train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, + args.bucket_reso_steps, args.bucket_no_upscale, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.dataset_repeats, args.debug_dataset) + + # 学習データのdropout率を設定する + train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate) + train_dataset.make_buckets() if args.debug_dataset: @@ -376,6 +380,8 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset.set_current_epoch(epoch + 1) + metadata["ss_epoch"] = str(epoch+1) network.on_epoch_start(text_encoder, unet) @@ -509,7 +515,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, True) + train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない") diff --git a/train_network_README-ja.md b/train_network_README-ja.md index 794cdd0f..e0ebaf76 100644 --- a/train_network_README-ja.md +++ b/train_network_README-ja.md @@ -55,7 +55,7 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py --network_module=networks.lora ``` ---output_dirオプションで指定したディレクトリに、LoRAのモデルが保存されます。 +--output_dirオプションで指定したフォルダに、LoRAのモデルが保存されます。 その他、以下のオプションが指定できます。 @@ -178,6 +178,38 @@ Text Encoderが二つのモデルで同じ場合にはLoRAはU-NetのみのLoRA - --save_precision - LoRAの保存形式を"float", "fp16", "bf16"から指定します。省略時はfloatになります。 +## 画像リサイズスクリプト + +(のちほどドキュメントを整理しますがとりあえずここに説明を書いておきます。) + +Aspect Ratio Bucketingの機能拡張で、小さな画像については拡大しないでそのまま教師データとすることが可能になりました。元の教師画像を縮小した画像を、教師データに加えると精度が向上したという報告とともに前処理用のスクリプトをいただきましたので整備して追加しました。bmaltais氏に感謝します。 + +### スクリプトの実行方法 + +以下のように指定してください。元の画像そのまま、およびリサイズ後の画像が変換先フォルダに保存されます。リサイズ後の画像には、ファイル名に ``+512x512`` のようにリサイズ先の解像度が付け加えられます(画像サイズとは異なります)。リサイズ先の解像度より小さい画像は拡大されることはありません。 + +``` +python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256x256 --save_as_png + --copy_associated_files 元画像フォルダ 変換先フォルダ +``` + +元画像フォルダ内の画像ファイルが、指定した解像度(複数指定可)と同じ面積になるようにリサイズされ、変換先フォルダに保存されます。画像以外のファイルはそのままコピーされます。 + +``--max_resolution`` オプションにリサイズ先のサイズを例のように指定してください。面積がそのサイズになるようにリサイズします。複数指定すると、それぞれの解像度でリサイズされます。``512x512,384x384,256x256``なら、変換先フォルダの画像は、元サイズとリサイズ後サイズ×3の計4枚になります。 + +``--save_as_png`` オプションを指定するとpng形式で保存します。省略するとjpeg形式(quality=100)で保存されます。 + +``--copy_associated_files`` オプションを指定すると、拡張子を除き画像と同じファイル名(たとえばキャプションなど)のファイルが、リサイズ後の画像のファイル名と同じ名前でコピーされます。 + + +### その他のオプション + +- divisible_by + - リサイズ後の画像のサイズ(縦、横のそれぞれ)がこの値で割り切れるように、画像中心を切り出します。 +- interpolation + - 縮小時の補完方法を指定します。``area, cubic, lanczos4``から選択可能で、デフォルトは``area``です。 + + ## 追加情報 ### cloneofsimo氏のリポジトリとの違い diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 7a8370cd..ba2e7145 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -235,7 +235,7 @@ def train(args): text_encoder, optimizer, train_dataloader, lr_scheduler) index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] - print(len(index_no_updates), torch.sum(index_no_updates)) + # print(len(index_no_updates), torch.sum(index_no_updates)) orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() # Freeze all parameters except for the token embeddings in text encoder @@ -296,6 +296,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset.set_current_epoch(epoch + 1) text_encoder.train() @@ -383,8 +384,8 @@ def train(args): accelerator.wait_for_everyone() updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() - d = updated_embs - bef_epo_embs - print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min()) + # d = updated_embs - bef_epo_embs + # print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min()) if args.save_every_n_epochs is not None: model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name @@ -478,7 +479,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, True) + train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],