diff --git a/fine_tune.py b/fine_tune.py
index 17b89852..e743a349 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -36,6 +36,10 @@ def train(args):
                                     args.bucket_reso_steps, args.bucket_no_upscale,
                                     args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
                                     args.dataset_repeats, args.debug_dataset)
+
+  # 学習データのdropout率を設定する
+  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)
+
   train_dataset.make_buckets()
 
   if args.debug_dataset:
@@ -171,10 +175,6 @@ def train(args):
     args.max_train_steps = args.max_train_epochs * len(train_dataloader)
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
-  # 学習データのdropout率を設定する
-  train_dataset.dropout_rate = args.dropout_rate
-  train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs
-
   # lr schedulerを用意する
   lr_scheduler = diffusers.optimization.get_scheduler(
       args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
@@ -339,7 +339,7 @@ if __name__ == '__main__':
   parser = argparse.ArgumentParser()
 
   train_util.add_sd_models_arguments(parser)
-  train_util.add_dataset_arguments(parser, False, True)
+  train_util.add_dataset_arguments(parser, False, True, True)
   train_util.add_training_arguments(parser, False)
   train_util.add_sd_saving_arguments(parser)
 
diff --git a/library/train_util.py b/library/train_util.py
index 60da9143..612eba2d 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -113,7 +113,7 @@ class BucketManager():
     # 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく
     self.predefined_resos = resos.copy()
     self.predefined_resos_set = set(resos)
-    self.predifined_aspect_ratios = np.array([w / h for w, h in resos])
+    self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
 
   def add_if_new_reso(self, reso):
     if reso not in self.reso_to_id:
@@ -135,7 +135,7 @@ class BucketManager():
       if reso in self.predefined_resos_set:
         pass
       else:
-        ar_errors = self.predifined_aspect_ratios - aspect_ratio
+        ar_errors = self.predefined_aspect_ratios - aspect_ratio
         predefined_bucket_id = np.abs(ar_errors).argmin()  # 当該解像度以外でaspect ratio errorが最も少ないもの
         reso = self.predefined_resos[predefined_bucket_id]
 
@@ -223,9 +223,10 @@ class BaseDataset(torch.utils.data.Dataset):
     self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
 
-    self.epoch_current:int = int(0)
-    self.dropout_rate:float = 0
-    self.dropout_every_n_epochs:int = 0
+    # TODO 外から渡したほうが安心だが自動で計算したほうが呼ぶ側に余分なコードがいらないのでよさそう
+    self.epoch_current: int = int(0)
+    self.dropout_rate: float = 0
+    self.dropout_every_n_epochs: int = None
 
     # augmentation
     flip_p = 0.5 if flip_aug else 0.0
@@ -251,6 +252,12 @@ class BaseDataset(torch.utils.data.Dataset):
 
     self.replacements = {}
 
+  def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs):
+    # 将来的にタグのドロップアウトも対応したいのでメソッドを生やしておく
+    # コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
+    self.dropout_rate = dropout_rate
+    self.dropout_every_n_epochs = dropout_every_n_epochs
+
   def set_tag_frequency(self, dir_name, captions):
     frequency_for_dir = self.tag_frequency.get(dir_name, {})
     self.tag_frequency[dir_name] = frequency_for_dir
@@ -604,9 +611,9 @@ class BaseDataset(torch.utils.data.Dataset):
 
       # dropoutの決定
       is_drop_out = False
-      if self.dropout_rate > 0 and self.dropout_rate < random.random() :
+      if self.dropout_rate > 0 and random.random() < self.dropout_rate:
         is_drop_out = True
-      if self.dropout_every_n_epochs > 0 and self.epoch_current % self.dropout_every_n_epochs == 0 :
+      if self.dropout_every_n_epochs and self.epoch_current % self.dropout_every_n_epochs == 0:
         is_drop_out = True
 
       if is_drop_out:
@@ -1391,7 +1398,7 @@ def verify_training_args(args: argparse.Namespace):
     print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
 
 
-def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool):
+def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool):
   # dataset common
   parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
   parser.add_argument("--shuffle_caption", action="store_true",
@@ -1421,10 +1428,14 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
                       help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
   parser.add_argument("--bucket_no_upscale", action="store_true",
                       help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
-  parser.add_argument("--dropout_rate", type=float, default=0,
-                      help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
-  parser.add_argument("--dropout_every_n_epochs", type=int, default=0,
-                      help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
+
+  if support_caption_dropout:
+    # Textual Inversion はcaptionのdropoutをsupportしない
+    # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
+    parser.add_argument("--caption_dropout_rate", type=float, default=0,
+                        help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
+    parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
+                        help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
 
   if support_dreambooth:
     # DreamBooth dataset
diff --git a/train_db.py b/train_db.py
index 96a4dde6..51f5038b 100644
--- a/train_db.py
+++ b/train_db.py
@@ -38,8 +38,13 @@ def train(args):
                                   args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
                                   args.bucket_reso_steps, args.bucket_no_upscale, args.prior_loss_weight,
                                   args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
+
   if args.no_token_padding:
     train_dataset.disable_token_padding()
+
+  # 学習データのdropout率を設定する
+  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)
+
   train_dataset.make_buckets()
 
   if args.debug_dataset:
@@ -136,10 +141,6 @@ def train(args):
   train_dataloader = torch.utils.data.DataLoader(
       train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
 
-  # 学習データのdropout率を設定する
-  train_dataset.dropout_rate = args.dropout_rate
-  train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs
-
   # 学習ステップ数を計算する
   if args.max_train_epochs is not None:
     args.max_train_steps = args.max_train_epochs * len(train_dataloader)
@@ -333,7 +334,7 @@ if __name__ == '__main__':
   parser = argparse.ArgumentParser()
 
   train_util.add_sd_models_arguments(parser)
-  train_util.add_dataset_arguments(parser, True, False)
+  train_util.add_dataset_arguments(parser, True, False, True)
   train_util.add_training_arguments(parser, True)
   train_util.add_sd_saving_arguments(parser)
 
diff --git a/train_network.py b/train_network.py
index 82ebeaf1..f3ca417c 100644
--- a/train_network.py
+++ b/train_network.py
@@ -132,6 +132,10 @@ def train(args):
                                     args.bucket_reso_steps, args.bucket_no_upscale,
                                     args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
                                     args.dataset_repeats, args.debug_dataset)
+
+  # 学習データのdropout率を設定する
+  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)
+
   train_dataset.make_buckets()
 
   if args.debug_dataset:
@@ -219,10 +223,6 @@ def train(args):
   train_dataloader = torch.utils.data.DataLoader(
       train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
 
-  # 学習データのdropout率を設定する
-  train_dataset.dropout_rate = args.dropout_rate
-  train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs
-
   # 学習ステップ数を計算する
   if args.max_train_epochs is not None:
     args.max_train_steps = args.max_train_epochs * len(train_dataloader)
@@ -516,7 +516,7 @@ if __name__ == '__main__':
   parser = argparse.ArgumentParser()
 
   train_util.add_sd_models_arguments(parser)
-  train_util.add_dataset_arguments(parser, True, True)
+  train_util.add_dataset_arguments(parser, True, True, True)
   train_util.add_training_arguments(parser, True)
 
   parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 7a8370cd..d3e558a3 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -478,7 +478,7 @@ if __name__ == '__main__':
   parser = argparse.ArgumentParser()
 
   train_util.add_sd_models_arguments(parser)
-  train_util.add_dataset_arguments(parser, True, True)
+  train_util.add_dataset_arguments(parser, True, True, False)
   train_util.add_training_arguments(parser, True)
 
   parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
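
For reference, the per-item decision that the patched BaseDataset.__getitem__ now makes can be restated as the standalone sketch below. should_drop_caption is a hypothetical helper written only for illustration (it is not part of the patch); its parameters mirror the dropout_rate, dropout_every_n_epochs, and epoch_current fields from the diff.

import random

def should_drop_caption(dropout_rate: float, dropout_every_n_epochs, epoch_current: int) -> bool:
    """Hypothetical restatement of the caption-dropout decision in the patch."""
    # Probabilistic dropout: drop this caption with probability dropout_rate
    # (the patch fixes the comparison to `random.random() < dropout_rate`).
    if dropout_rate > 0 and random.random() < dropout_rate:
        return True
    # Periodic dropout: drop every caption on epochs where
    # epoch_current % dropout_every_n_epochs == 0; None (the new default) or 0
    # disables this branch, matching the truthiness check in the patch.
    if dropout_every_n_epochs and epoch_current % dropout_every_n_epochs == 0:
        return True
    return False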