add caption_prefix/suffix to dataset

This commit is contained in:
Kohya S
2023-09-02 16:17:12 +09:00
parent cd59003003
commit 948cf17499
3 changed files with 44 additions and 0 deletions

View File

@@ -138,9 +138,13 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
| `num_repeats` | `10` | o | o | o |
| `random_crop` | `false` | o | o | o |
| `shuffle_caption` | `true` | o | o | o |
| `caption_prefix` | `“masterpiece, best quality, ”` | o | o | o |
| `caption_suffix` | `“, from side”` | o | o | o |
* `num_repeats`
* サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
* `caption_prefix`, `caption_suffix`
* キャプションの前、後に付与する文字列を指定します。シャッフルはこれらの文字列を含めた状態で行われます。`keep_tokens` を指定する場合には注意してください。
### DreamBooth 方式専用のオプション

View File

@@ -56,6 +56,8 @@ class BaseSubsetParams:
flip_aug: bool = False
face_crop_aug_range: Optional[Tuple[float, float]] = None
random_crop: bool = False
caption_prefix: Optional[str] = None
caption_suffix: Optional[str] = None
caption_dropout_rate: float = 0.0
caption_dropout_every_n_epochs: int = 0
caption_tag_dropout_rate: float = 0.0
@@ -159,6 +161,8 @@ class ConfigSanitizer:
"keep_tokens": int,
"token_warmup_min": int,
"token_warmup_step": Any(float,int),
"caption_prefix": str,
"caption_suffix": str,
}
# DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -459,6 +463,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
color_aug: {subset.color_aug}
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}

View File

@@ -340,6 +340,8 @@ class BaseSubset:
caption_dropout_rate: float,
caption_dropout_every_n_epochs: int,
caption_tag_dropout_rate: float,
caption_prefix: Optional[str],
caption_suffix: Optional[str],
token_warmup_min: int,
token_warmup_step: Union[float, int],
) -> None:
@@ -354,6 +356,8 @@ class BaseSubset:
self.caption_dropout_rate = caption_dropout_rate
self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
self.caption_tag_dropout_rate = caption_tag_dropout_rate
self.caption_prefix = caption_prefix
self.caption_suffix = caption_suffix
self.token_warmup_min = token_warmup_min # step=0におけるタグの数
self.token_warmup_step = token_warmup_step # NN<1ならN*max_train_stepsステップ目でタグの数が最大になる
@@ -378,6 +382,8 @@ class DreamBoothSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
caption_prefix,
caption_suffix,
token_warmup_min,
token_warmup_step,
) -> None:
@@ -395,6 +401,8 @@ class DreamBoothSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
caption_prefix,
caption_suffix,
token_warmup_min,
token_warmup_step,
)
@@ -426,6 +434,8 @@ class FineTuningSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
caption_prefix,
caption_suffix,
token_warmup_min,
token_warmup_step,
) -> None:
@@ -443,6 +453,8 @@ class FineTuningSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
caption_prefix,
caption_suffix,
token_warmup_min,
token_warmup_step,
)
@@ -471,6 +483,8 @@ class ControlNetSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
caption_prefix,
caption_suffix,
token_warmup_min,
token_warmup_step,
) -> None:
@@ -488,6 +502,8 @@ class ControlNetSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
caption_prefix,
caption_suffix,
token_warmup_min,
token_warmup_step,
)
@@ -595,6 +611,12 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements[str_from] = str_to
def process_caption(self, subset: BaseSubset, caption):
# caption に prefix/suffix を付ける
if subset.caption_prefix:
caption = subset.caption_prefix + " " + caption
if subset.caption_suffix:
caption = caption + " " + subset.caption_suffix
# dropoutの決定tag dropがこのメソッド内にあるのでここで行うのが良い
is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate
is_drop_out = (
@@ -3068,6 +3090,18 @@ def add_dataset_arguments(
default=0,
help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残すトークンはカンマ区切りの各部分を意味する",
)
parser.add_argument(
"--caption_prefix",
type=str,
default=None,
help="prefix for caption text / captionのテキストの先頭に付ける文字列",
)
parser.add_argument(
"--caption_suffix",
type=str,
default=None,
help="suffix for caption text / captionのテキストの末尾に付ける文字列",
)
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
parser.add_argument(