diff --git a/docs/config_README-ja.md b/docs/config_README-ja.md index 7f2b6c4c..69a03f6c 100644 --- a/docs/config_README-ja.md +++ b/docs/config_README-ja.md @@ -138,9 +138,13 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 | `num_repeats` | `10` | o | o | o | | `random_crop` | `false` | o | o | o | | `shuffle_caption` | `true` | o | o | o | +| `caption_prefix` | `“masterpiece, best quality, ”` | o | o | o | +| `caption_suffix` | `“, from side”` | o | o | o | * `num_repeats` * サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。 +* `caption_prefix`, `caption_suffix` + * キャプションの前、後に付与する文字列を指定します。シャッフルはこれらの文字列を含めた状態で行われます。`keep_tokens` を指定する場合には注意してください。 ### DreamBooth 方式専用のオプション diff --git a/library/config_util.py b/library/config_util.py index 3604ea57..813483e7 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -56,6 +56,8 @@ class BaseSubsetParams: flip_aug: bool = False face_crop_aug_range: Optional[Tuple[float, float]] = None random_crop: bool = False + caption_prefix: Optional[str] = None + caption_suffix: Optional[str] = None caption_dropout_rate: float = 0.0 caption_dropout_every_n_epochs: int = 0 caption_tag_dropout_rate: float = 0.0 @@ -159,6 +161,8 @@ class ConfigSanitizer: "keep_tokens": int, "token_warmup_min": int, "token_warmup_step": Any(float,int), + "caption_prefix": str, + "caption_suffix": str, } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { @@ -459,6 +463,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu caption_dropout_rate: {subset.caption_dropout_rate} caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} color_aug: {subset.color_aug} flip_aug: {subset.flip_aug} face_crop_aug_range: {subset.face_crop_aug_range} diff --git a/library/train_util.py b/library/train_util.py index e1466179..4ae9201d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -340,6 +340,8 @@ class BaseSubset: caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float, + caption_prefix: Optional[str], + caption_suffix: Optional[str], token_warmup_min: int, token_warmup_step: Union[float, int], ) -> None: @@ -354,6 +356,8 @@ class BaseSubset: self.caption_dropout_rate = caption_dropout_rate self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs self.caption_tag_dropout_rate = caption_tag_dropout_rate + self.caption_prefix = caption_prefix + self.caption_suffix = caption_suffix self.token_warmup_min = token_warmup_min # step=0におけるタグの数 self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる @@ -378,6 +382,8 @@ class DreamBoothSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + caption_prefix, + caption_suffix, token_warmup_min, token_warmup_step, ) -> None: @@ -395,6 +401,8 @@ class DreamBoothSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + caption_prefix, + caption_suffix, token_warmup_min, token_warmup_step, ) @@ -426,6 +434,8 @@ class FineTuningSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + caption_prefix, + caption_suffix, token_warmup_min, token_warmup_step, ) -> None: @@ -443,6 +453,8 @@ class FineTuningSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + caption_prefix, + caption_suffix, token_warmup_min, token_warmup_step, ) @@ -471,6 +483,8 @@ class ControlNetSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + caption_prefix, + caption_suffix, token_warmup_min, token_warmup_step, ) -> None: @@ -488,6 +502,8 @@ class ControlNetSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + caption_prefix, + caption_suffix, token_warmup_min, token_warmup_step, ) @@ -595,6 +611,12 @@ class BaseDataset(torch.utils.data.Dataset): self.replacements[str_from] = str_to def process_caption(self, subset: BaseSubset, caption): + # caption に prefix/suffix を付ける + if subset.caption_prefix: + caption = subset.caption_prefix + " " + caption + if subset.caption_suffix: + caption = caption + " " + subset.caption_suffix + # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate is_drop_out = ( @@ -3068,6 +3090,18 @@ def add_dataset_arguments( default=0, help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)", ) + parser.add_argument( + "--caption_prefix", + type=str, + default=None, + help="prefix for caption text / captionのテキストの先頭に付ける文字列", + ) + parser.add_argument( + "--caption_suffix", + type=str, + default=None, + help="suffix for caption text / captionのテキストの末尾に付ける文字列", + ) parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") parser.add_argument(