diff --git a/docs/config_README-en.md b/docs/config_README-en.md index 6b55a985..1a63d8ff 100644 --- a/docs/config_README-en.md +++ b/docs/config_README-en.md @@ -153,6 +153,7 @@ These options are related to subset configuration. | `keep_tokens_separator` | `“|||”` | o | o | o | | `secondary_separator` | `“;;;”` | o | o | o | | `enable_wildcard` | `true` | o | o | o | +| `enable_multiline_captions` | `true` | o | o | o | | `resize_interpolation` | (not specified) | o | o | o | * `num_repeats` @@ -167,6 +168,8 @@ These options are related to subset configuration. * Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together. * `enable_wildcard` * Enables wildcard notation. This will be explained later. +* `enable_multiline_captions` + * Enables multi-line captions. This will be explained later. * `resize_interpolation` * Specifies the interpolation method used when resizing images. Normally, there is no need to specify this. The following options can be specified: `lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box`. By default (when not specified), `area` is used for downscaling, and `lanczos` is used for upscaling. If this option is specified, the same interpolation method will be used for both upscaling and downscaling. When `lanczos` or `box` is specified, PIL is used; for other options, OpenCV is used. @@ -332,9 +335,22 @@ As a temporary measure, we will list common errors and their solutions. If you e ## Miscellaneous -### Multi-line captions +### Use multi-line captions as they are -By setting `enable_wildcard = true`, multiple-line captions are also enabled. If the caption file consists of multiple lines, one line is randomly selected as the caption. +By setting `enable_multiline_captions = true`, if the caption file consists of multiple lines, it will be used as is (including line breaks). + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage +detailed digital art of a girl Hatsune Miku with a microphone on a stage +``` + +(The default is `false`, and in that case, the first line is used as the caption, and the rest are ignored.) + +### Wildcard notation and multi-line captions + +By setting `enable_wildcard = true`, if the caption file consists of multiple lines, one line is randomly selected as the caption. + +NOTE: When `enable_wildcard` is enabled, `enable_multiline_captions` is ignored, and the caption file is always treated as a random selection of one line. ```txt 1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage diff --git a/docs/config_README-ja.md b/docs/config_README-ja.md index 61d3e251..d321bff4 100644 --- a/docs/config_README-ja.md +++ b/docs/config_README-ja.md @@ -148,6 +148,7 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 | `keep_tokens_separator` | `“|||”` | o | o | o | | `secondary_separator` | `“;;;”` | o | o | o | | `enable_wildcard` | `true` | o | o | o | +| `enable_multiline_captions` | `true` | o | o | o | | `resize_interpolation` |(通常は設定しません) | o | o | o | * `num_repeats` @@ -165,7 +166,10 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 * 追加の区切り文字を指定します。この区切り文字で区切られた部分は一つのタグとして扱われ、シャッフル、drop されます。その後、`caption_separator` に置き換えられます。たとえば `aaa;;;bbb;;;ccc` のように指定すると、`aaa,bbb,ccc` に置き換えられるか、まとめて drop されます。 * `enable_wildcard` - * ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。 + * ワイルドカード記法を有効にします。ワイルドカード記法、複数行キャプションについては後述します。 + +* `enable_multiline_captions` + * 複数行キャプションを有効にします。キャプションファイルが複数の行からなる場合でも、そのまま利用されます。 * `resize_interpolation` * 画像のリサイズ時に使用する補間方法を指定します。通常は指定しなくて構いません。`lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box` が指定可能です。デフォルト(未指定時)は、縮小時は `area`、拡大時は `lanczos` になります。このオプションを指定すると、拡大時・縮小時とも同じ補間方法が使用されます。`lanczos`、`box`を指定するとPILが、それ以外を指定するとOpenCVが使用されます。 @@ -341,7 +345,15 @@ skip_image_resolution = 1024 ### 複数行キャプション -`enable_wildcard = true` を設定することで、複数行キャプションも同時に有効になります。キャプションファイルが複数の行からなる場合、ランダムに一つの行が選ばれてキャプションとして利用されます。 +`enable_multiline_captions = true` を設定することで、キャプションファイルが複数の行からなる場合でも、そのまま利用されるようになります。 + +(未指定時のデフォルト動作は、キャプションファイルの最初の行のみが利用されます。) + +### ワイルドカード記法 + +`enable_wildcard = true` を設定することで、キャプションファイルが複数の行からなる場合、ランダムに一つの行が選ばれてキャプションとして利用されます。 + +※`enable_wildcard`が指定されると、`enable_multiline_captions` は無視されます(複数行からランダムに一行選ぶ挙動が優先される)。 ```txt 1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage diff --git a/library/config_util.py b/library/config_util.py index b31f9665..2c885318 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -61,6 +61,7 @@ class BaseSubsetParams: keep_tokens_separator: str = (None,) secondary_separator: Optional[str] = None enable_wildcard: bool = False + enable_multiline_captions: bool = False color_aug: bool = False flip_aug: bool = False face_crop_aug_range: Optional[Tuple[float, float]] = None @@ -110,6 +111,7 @@ class BaseDatasetParams: resize_interpolation: Optional[str] = None skip_image_resolution: Optional[Tuple[int, int]] = None + @dataclass class DreamBoothDatasetParams(BaseDatasetParams): batch_size: int = 1 @@ -120,6 +122,7 @@ class DreamBoothDatasetParams(BaseDatasetParams): bucket_no_upscale: bool = False prior_loss_weight: float = 1.0 + @dataclass class FineTuningDatasetParams(BaseDatasetParams): batch_size: int = 1 @@ -193,6 +196,7 @@ class ConfigSanitizer: "secondary_separator": str, "caption_separator": str, "enable_wildcard": bool, + "enable_multiline_captions": bool, "token_warmup_min": int, "token_warmup_step": Any(float, int), "caption_prefix": str, @@ -473,7 +477,10 @@ class BlueprintGenerator: return default_value -def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint) -> Tuple[DatasetGroup, Optional[DatasetGroup]]: + +def generate_dataset_group_by_blueprint( + dataset_group_blueprint: DatasetGroupBlueprint, +) -> Tuple[DatasetGroup, Optional[DatasetGroup]]: datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: @@ -498,7 +505,9 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: if dataset_blueprint.params.validation_split < 0.0 or dataset_blueprint.params.validation_split > 1.0: - logging.warning(f"Dataset param `validation_split` ({dataset_blueprint.params.validation_split}) is not a valid number between 0.0 and 1.0, skipping validation split...") + logging.warning( + f"Dataset param `validation_split` ({dataset_blueprint.params.validation_split}) is not a valid number between 0.0 and 1.0, skipping validation split..." + ) continue # if the dataset isn't setting a validation split, there is no current validation dataset @@ -527,27 +536,36 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu for i, dataset in enumerate(_datasets): is_dreambooth = isinstance(dataset, DreamBoothDataset) is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent(f"""\ + info += dedent( + f"""\ [{dataset_type} {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} skip_image_resolution: {dataset.skip_image_resolution} resize_interpolation: {dataset.resize_interpolation} enable_bucket: {dataset.enable_bucket} - """) + """ + ) if dataset.enable_bucket: - info += indent(dedent(f"""\ + info += indent( + dedent( + f"""\ min_bucket_reso: {dataset.min_bucket_reso} max_bucket_reso: {dataset.max_bucket_reso} bucket_reso_steps: {dataset.bucket_reso_steps} bucket_no_upscale: {dataset.bucket_no_upscale} - \n"""), " ") + \n""" + ), + " ", + ) else: info += "\n" for j, subset in enumerate(dataset.subsets): - info += indent(dedent(f"""\ + info += indent( + dedent( + f"""\ [Subset {j} of {dataset_type} {i}] image_dir: "{subset.image_dir}" image_count: {subset.img_count} @@ -559,6 +577,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} caption_prefix: {subset.caption_prefix} caption_suffix: {subset.caption_suffix} + enable_wildcard: {subset.enable_wildcard} + enable_multiline_captions: {subset.enable_multiline_captions} color_aug: {subset.color_aug} flip_aug: {subset.flip_aug} face_crop_aug_range: {subset.face_crop_aug_range} @@ -568,18 +588,31 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu alpha_mask: {subset.alpha_mask} resize_interpolation: {subset.resize_interpolation} custom_attributes: {subset.custom_attributes} - """), " ") + """ + ), + " ", + ) if is_dreambooth: - info += indent(dedent(f"""\ + info += indent( + dedent( + f"""\ is_reg: {subset.is_reg} class_tokens: {subset.class_tokens} caption_extension: {subset.caption_extension} - \n"""), " ") + \n""" + ), + " ", + ) elif not is_controlnet: - info += indent(dedent(f"""\ + info += indent( + dedent( + f"""\ metadata_file: {subset.metadata_file} - \n"""), " ") + \n""" + ), + " ", + ) logger.info(info) @@ -602,10 +635,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset.make_buckets() dataset.set_seed(seed) - return ( - DatasetGroup(datasets), - DatasetGroup(val_datasets) if val_datasets else None - ) + return (DatasetGroup(datasets), DatasetGroup(val_datasets) if val_datasets else None) def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): diff --git a/library/train_util.py b/library/train_util.py index 83d04f5e..733fd094 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -421,6 +421,7 @@ class BaseSubset: keep_tokens_separator: str, secondary_separator: Optional[str], enable_wildcard: bool, + enable_multiline_captions: bool, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], @@ -446,6 +447,7 @@ class BaseSubset: self.keep_tokens_separator = keep_tokens_separator self.secondary_separator = secondary_separator self.enable_wildcard = enable_wildcard + self.enable_multiline_captions = enable_multiline_captions self.color_aug = color_aug self.flip_aug = flip_aug self.face_crop_aug_range = face_crop_aug_range @@ -485,6 +487,7 @@ class DreamBoothSubset(BaseSubset): keep_tokens_separator, secondary_separator, enable_wildcard, + enable_multiline_captions, color_aug, flip_aug, face_crop_aug_range, @@ -513,6 +516,7 @@ class DreamBoothSubset(BaseSubset): keep_tokens_separator, secondary_separator, enable_wildcard, + enable_multiline_captions, color_aug, flip_aug, face_crop_aug_range, @@ -556,6 +560,7 @@ class FineTuningSubset(BaseSubset): keep_tokens_separator, secondary_separator, enable_wildcard, + enable_multiline_captions, color_aug, flip_aug, face_crop_aug_range, @@ -584,6 +589,7 @@ class FineTuningSubset(BaseSubset): keep_tokens_separator, secondary_separator, enable_wildcard, + enable_multiline_captions, color_aug, flip_aug, face_crop_aug_range, @@ -623,6 +629,7 @@ class ControlNetSubset(BaseSubset): keep_tokens_separator, secondary_separator, enable_wildcard, + enable_multiline_captions, color_aug, flip_aug, face_crop_aug_range, @@ -651,6 +658,7 @@ class ControlNetSubset(BaseSubset): keep_tokens_separator, secondary_separator, enable_wildcard, + enable_multiline_captions, color_aug, flip_aug, face_crop_aug_range, @@ -863,6 +871,9 @@ class BaseDataset(torch.utils.data.Dataset): # unescape the curly braces caption = caption.replace(replacer1, "{").replace(replacer2, "}") + elif subset.enable_multiline_captions: + # use multiline captions as they are + pass else: # if caption is multiline, use the first line caption = caption.split("\n")[0] @@ -1954,7 +1965,7 @@ class DreamBoothDataset(BaseDataset): self.bucket_reso_steps = None # この情報は使われない self.bucket_no_upscale = False - def read_caption(img_path, caption_extension, enable_wildcard): + def read_caption(img_path, caption_extension, enable_wildcard, enable_multiline_captions): # captionの候補ファイル名を作る base_name = os.path.splitext(img_path)[0] base_name_face_det = base_name @@ -1975,6 +1986,8 @@ class DreamBoothDataset(BaseDataset): assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" if enable_wildcard: caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結 + elif enable_multiline_captions: + caption = "".join(lines).strip() # 改行も含めて連結 else: caption = lines[0].strip() break @@ -2095,7 +2108,9 @@ class DreamBoothDataset(BaseDataset): captions = [] missing_captions = [] for img_path in tqdm(img_paths, desc="read caption"): - cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard) + cap_for_img = read_caption( + img_path, subset.caption_extension, subset.enable_wildcard, subset.enable_multiline_captions + ) if cap_for_img is None and subset.class_tokens is None: logger.warning( f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}" @@ -6260,7 +6275,8 @@ def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names): ) if "effective_lr" in lr_scheduler.optimizers[-1].param_groups[lr_index]: logs["lr/d*eff_lr/" + name] = ( - lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["effective_lr"] + lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] + * lr_scheduler.optimizers[-1].param_groups[lr_index]["effective_lr"] )