Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-17 09:18:00 +00:00
feat: add support for using multi-line captions as they are and update related configurations
@@ -153,6 +153,7 @@ These options are related to subset configuration.
 | `keep_tokens_separator` | `“|||”` | o | o | o |
 | `secondary_separator` | `“;;;”` | o | o | o |
 | `enable_wildcard` | `true` | o | o | o |
+| `enable_multiline_captions` | `true` | o | o | o |
 | `resize_interpolation` | (not specified) | o | o | o |
 
 * `num_repeats`
@@ -167,6 +168,8 @@ These options are related to subset configuration.
     * Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together.
 * `enable_wildcard`
     * Enables wildcard notation. This will be explained later.
+* `enable_multiline_captions`
+    * Enables multi-line captions. This will be explained later.
 * `resize_interpolation`
     * Specifies the interpolation method used when resizing images. Normally, there is no need to specify this. The following options can be specified: `lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box`. By default (when not specified), `area` is used for downscaling, and `lanczos` is used for upscaling. If this option is specified, the same interpolation method will be used for both upscaling and downscaling. When `lanczos` or `box` is specified, PIL is used; for other options, OpenCV is used.
 
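The selection rule that the `resize_interpolation` entry describes can be sketched in Python as follows. This is an illustration only, not the repository's code: the helper name `choose_interpolation` and the use of pixel area to tell downscaling from upscaling are assumptions made for the example.

```python
from typing import Optional, Tuple

# Methods the README says are handled by PIL; all others go to OpenCV.
PIL_METHODS = {"lanczos", "box"}
VALID_METHODS = {"lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"}


def choose_interpolation(
    specified: Optional[str],
    src_size: Tuple[int, int],
    dst_size: Tuple[int, int],
) -> Tuple[str, str]:
    """Return (method, backend) following the documented rule.

    Hypothetical helper: comparing pixel areas to detect downscaling
    is an assumption of this sketch.
    """
    if specified is not None:
        if specified not in VALID_METHODS:
            raise ValueError(f"unknown interpolation: {specified}")
        method = specified  # one method for both upscaling and downscaling
    else:
        downscaling = dst_size[0] * dst_size[1] < src_size[0] * src_size[1]
        method = "area" if downscaling else "lanczos"
    backend = "PIL" if method in PIL_METHODS else "OpenCV"
    return method, backend
```

With nothing specified, shrinking a 1024x1024 image to 512x512 would resolve to `area` on OpenCV, while enlarging it would resolve to `lanczos` on PIL.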
@@ -332,9 +335,22 @@ As a temporary measure, we will list common errors and their solutions. If you e
 
 ## Miscellaneous
 
-### Multi-line captions
+### Use multi-line captions as they are
 
-By setting `enable_wildcard = true`, multiple-line captions are also enabled. If the caption file consists of multiple lines, one line is randomly selected as the caption.
+By setting `enable_multiline_captions = true`, if the caption file consists of multiple lines, it will be used as is (including line breaks).
+
+```txt
+1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage
+detailed digital art of a girl Hatsune Miku with a microphone on a stage
+```
+
+(The default is `false`, and in that case, the first line is used as the caption, and the rest are ignored.)
+
+### Wildcard notation and multi-line captions
+
+By setting `enable_wildcard = true`, if the caption file consists of multiple lines, one line is randomly selected as the caption.
+
+NOTE: When `enable_wildcard` is enabled, `enable_multiline_captions` is ignored, and the caption file is always treated as a random selection of one line.
 
 ```txt
 1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage
@@ -148,6 +148,7 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
 | `keep_tokens_separator` | `“|||”` | o | o | o |
 | `secondary_separator` | `“;;;”` | o | o | o |
 | `enable_wildcard` | `true` | o | o | o |
+| `enable_multiline_captions` | `true` | o | o | o |
 | `resize_interpolation` | (normally not set) | o | o | o |
 
 * `num_repeats`
@@ -165,7 +166,10 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
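     * Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together.
 
 * `enable_wildcard`
-    * Enables wildcard notation and multi-line captions. Wildcard notation and multi-line captions are explained later.
+    * Enables wildcard notation. Wildcard notation and multi-line captions are explained later.
+
+* `enable_multiline_captions`
+    * Enables multi-line captions. Even if the caption file consists of multiple lines, it is used as is.
 
 * `resize_interpolation`
     * Specifies the interpolation method used when resizing images. Normally, there is no need to specify this. `lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, and `box` can be specified. By default (when not specified), `area` is used for downscaling and `lanczos` for upscaling. If this option is specified, the same interpolation method is used for both upscaling and downscaling. When `lanczos` or `box` is specified, PIL is used; for the others, OpenCV is used.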
@@ -341,7 +345,15 @@ skip_image_resolution = 1024
 
 ### Multi-line captions
 
-By setting `enable_wildcard = true`, multi-line captions are enabled at the same time. If the caption file consists of multiple lines, one line is randomly selected and used as the caption.
+By setting `enable_multiline_captions = true`, the caption file is used as is even if it consists of multiple lines.
+
+(When not specified, the default behavior is that only the first line of the caption file is used.)
+
+### Wildcard notation
+
+By setting `enable_wildcard = true`, if the caption file consists of multiple lines, one line is randomly selected and used as the caption.
+
+Note: When `enable_wildcard` is specified, `enable_multiline_captions` is ignored (the behavior of randomly selecting one line from multiple lines takes precedence).
 
 ```txt
 1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage
@@ -61,6 +61,7 @@ class BaseSubsetParams:
     keep_tokens_separator: str = (None,)
     secondary_separator: Optional[str] = None
     enable_wildcard: bool = False
+    enable_multiline_captions: bool = False
     color_aug: bool = False
     flip_aug: bool = False
     face_crop_aug_range: Optional[Tuple[float, float]] = None
@@ -110,6 +111,7 @@ class BaseDatasetParams:
     resize_interpolation: Optional[str] = None
     skip_image_resolution: Optional[Tuple[int, int]] = None
 
+
 @dataclass
 class DreamBoothDatasetParams(BaseDatasetParams):
     batch_size: int = 1
@@ -120,6 +122,7 @@ class DreamBoothDatasetParams(BaseDatasetParams):
     bucket_no_upscale: bool = False
     prior_loss_weight: float = 1.0
 
+
 @dataclass
 class FineTuningDatasetParams(BaseDatasetParams):
     batch_size: int = 1
@@ -193,6 +196,7 @@ class ConfigSanitizer:
         "secondary_separator": str,
         "caption_separator": str,
         "enable_wildcard": bool,
+        "enable_multiline_captions": bool,
         "token_warmup_min": int,
         "token_warmup_step": Any(float, int),
         "caption_prefix": str,
@@ -473,7 +477,10 @@ class BlueprintGenerator:
 
         return default_value
 
-def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint) -> Tuple[DatasetGroup, Optional[DatasetGroup]]:
+
+def generate_dataset_group_by_blueprint(
+    dataset_group_blueprint: DatasetGroupBlueprint,
+) -> Tuple[DatasetGroup, Optional[DatasetGroup]]:
     datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
 
     for dataset_blueprint in dataset_group_blueprint.datasets:
@@ -498,7 +505,9 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
     val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
     for dataset_blueprint in dataset_group_blueprint.datasets:
         if dataset_blueprint.params.validation_split < 0.0 or dataset_blueprint.params.validation_split > 1.0:
-            logging.warning(f"Dataset param `validation_split` ({dataset_blueprint.params.validation_split}) is not a valid number between 0.0 and 1.0, skipping validation split...")
+            logging.warning(
+                f"Dataset param `validation_split` ({dataset_blueprint.params.validation_split}) is not a valid number between 0.0 and 1.0, skipping validation split..."
+            )
             continue
 
         # if the dataset isn't setting a validation split, there is no current validation dataset
@@ -527,27 +536,36 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
         for i, dataset in enumerate(_datasets):
             is_dreambooth = isinstance(dataset, DreamBoothDataset)
             is_controlnet = isinstance(dataset, ControlNetDataset)
-            info += dedent(f"""\
+            info += dedent(
+                f"""\
       [{dataset_type} {i}]
         batch_size: {dataset.batch_size}
         resolution: {(dataset.width, dataset.height)}
         skip_image_resolution: {dataset.skip_image_resolution}
         resize_interpolation: {dataset.resize_interpolation}
         enable_bucket: {dataset.enable_bucket}
-    """)
+    """
+            )
 
             if dataset.enable_bucket:
-                info += indent(dedent(f"""\
+                info += indent(
+                    dedent(
+                        f"""\
         min_bucket_reso: {dataset.min_bucket_reso}
         max_bucket_reso: {dataset.max_bucket_reso}
         bucket_reso_steps: {dataset.bucket_reso_steps}
         bucket_no_upscale: {dataset.bucket_no_upscale}
-      \n"""), " ")
+      \n"""
+                    ),
+                    " ",
+                )
             else:
                 info += "\n"
 
             for j, subset in enumerate(dataset.subsets):
-                info += indent(dedent(f"""\
+                info += indent(
+                    dedent(
+                        f"""\
         [Subset {j} of {dataset_type} {i}]
           image_dir: "{subset.image_dir}"
           image_count: {subset.img_count}
@@ -559,6 +577,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
           caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
           caption_prefix: {subset.caption_prefix}
           caption_suffix: {subset.caption_suffix}
+          enable_wildcard: {subset.enable_wildcard}
+          enable_multiline_captions: {subset.enable_multiline_captions}
           color_aug: {subset.color_aug}
           flip_aug: {subset.flip_aug}
           face_crop_aug_range: {subset.face_crop_aug_range}
@@ -568,18 +588,31 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
           alpha_mask: {subset.alpha_mask}
           resize_interpolation: {subset.resize_interpolation}
           custom_attributes: {subset.custom_attributes}
-      """), " ")
+      """
+                    ),
+                    " ",
+                )
 
                 if is_dreambooth:
-                    info += indent(dedent(f"""\
+                    info += indent(
+                        dedent(
+                            f"""\
           is_reg: {subset.is_reg}
           class_tokens: {subset.class_tokens}
           caption_extension: {subset.caption_extension}
-        \n"""), " ")
+        \n"""
+                        ),
+                        " ",
+                    )
                 elif not is_controlnet:
-                    info += indent(dedent(f"""\
+                    info += indent(
+                        dedent(
+                            f"""\
           metadata_file: {subset.metadata_file}
-        \n"""), " ")
+        \n"""
+                        ),
+                        " ",
+                    )
 
         logger.info(info)
 
@@ -602,10 +635,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
         dataset.make_buckets()
         dataset.set_seed(seed)
 
-    return (
-        DatasetGroup(datasets),
-        DatasetGroup(val_datasets) if val_datasets else None
-    )
+    return (DatasetGroup(datasets), DatasetGroup(val_datasets) if val_datasets else None)
 
 
 def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
@@ -421,6 +421,7 @@ class BaseSubset:
         keep_tokens_separator: str,
         secondary_separator: Optional[str],
         enable_wildcard: bool,
+        enable_multiline_captions: bool,
         color_aug: bool,
         flip_aug: bool,
         face_crop_aug_range: Optional[Tuple[float, float]],
@@ -446,6 +447,7 @@ class BaseSubset:
         self.keep_tokens_separator = keep_tokens_separator
         self.secondary_separator = secondary_separator
         self.enable_wildcard = enable_wildcard
+        self.enable_multiline_captions = enable_multiline_captions
         self.color_aug = color_aug
         self.flip_aug = flip_aug
         self.face_crop_aug_range = face_crop_aug_range
@@ -485,6 +487,7 @@ class DreamBoothSubset(BaseSubset):
         keep_tokens_separator,
         secondary_separator,
         enable_wildcard,
+        enable_multiline_captions,
         color_aug,
         flip_aug,
         face_crop_aug_range,
@@ -513,6 +516,7 @@ class DreamBoothSubset(BaseSubset):
             keep_tokens_separator,
             secondary_separator,
             enable_wildcard,
+            enable_multiline_captions,
             color_aug,
             flip_aug,
             face_crop_aug_range,
@@ -556,6 +560,7 @@ class FineTuningSubset(BaseSubset):
         keep_tokens_separator,
         secondary_separator,
         enable_wildcard,
+        enable_multiline_captions,
         color_aug,
         flip_aug,
         face_crop_aug_range,
@@ -584,6 +589,7 @@ class FineTuningSubset(BaseSubset):
             keep_tokens_separator,
             secondary_separator,
             enable_wildcard,
+            enable_multiline_captions,
             color_aug,
             flip_aug,
             face_crop_aug_range,
@@ -623,6 +629,7 @@ class ControlNetSubset(BaseSubset):
         keep_tokens_separator,
         secondary_separator,
         enable_wildcard,
+        enable_multiline_captions,
         color_aug,
         flip_aug,
         face_crop_aug_range,
@@ -651,6 +658,7 @@ class ControlNetSubset(BaseSubset):
             keep_tokens_separator,
             secondary_separator,
             enable_wildcard,
+            enable_multiline_captions,
             color_aug,
             flip_aug,
             face_crop_aug_range,
@@ -863,6 +871,9 @@ class BaseDataset(torch.utils.data.Dataset):
 
                     # unescape the curly braces
                     caption = caption.replace(replacer1, "{").replace(replacer2, "}")
+                elif subset.enable_multiline_captions:
+                    # use multiline captions as they are
+                    pass
                 else:
                     # if caption is multiline, use the first line
                     caption = caption.split("\n")[0]
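The precedence this hunk sets up at caption-processing time (wildcard first, then multi-line, then first line only) can be sketched as a standalone function. This is a simplified illustration, not the repository's code: `select_caption` and its `rng` parameter are hypothetical names, and the wildcard expansion that the real branch also performs is left out.

```python
import random
from typing import Optional


def select_caption(
    caption: str,
    enable_wildcard: bool,
    enable_multiline_captions: bool,
    rng: Optional[random.Random] = None,
) -> str:
    """Pick the caption text according to the two flags (sketch only)."""
    chooser = rng if rng is not None else random
    if enable_wildcard:
        # Wildcard mode takes precedence: one line is chosen at random.
        return chooser.choice(caption.split("\n"))
    if enable_multiline_captions:
        # Multi-line mode: keep the caption unchanged, line breaks included.
        return caption
    # Default: only the first line is used.
    return caption.split("\n")[0]
```

Passing a seeded `random.Random` makes the wildcard branch reproducible in tests. When both flags are set, the result is still a single line, which matches the README note that `enable_multiline_captions` is ignored in that case.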
@@ -1954,7 +1965,7 @@ class DreamBoothDataset(BaseDataset):
             self.bucket_reso_steps = None  # this information is not used
             self.bucket_no_upscale = False
 
-        def read_caption(img_path, caption_extension, enable_wildcard):
+        def read_caption(img_path, caption_extension, enable_wildcard, enable_multiline_captions):
             # build the candidate caption file names
             base_name = os.path.splitext(img_path)[0]
             base_name_face_det = base_name
@@ -1975,6 +1986,8 @@ class DreamBoothDataset(BaseDataset):
                     assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
                     if enable_wildcard:
                         caption = "\n".join([line.strip() for line in lines if line.strip() != ""])  # drop empty lines, join with newlines
+                    elif enable_multiline_captions:
+                        caption = "".join(lines).strip()  # join the lines, keeping line breaks
                     else:
                         caption = lines[0].strip()
                     break
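The three read-time branches in the hunk above differ in how they treat blank lines and surrounding whitespace. A minimal sketch that isolates that logic follows; `join_caption_lines` is a hypothetical helper name, and `lines` is assumed to be the output of `readlines()` (each entry keeps its trailing newline).

```python
from typing import List


def join_caption_lines(
    lines: List[str],
    enable_wildcard: bool,
    enable_multiline_captions: bool,
) -> str:
    """Combine raw caption-file lines the way the hunk above does (sketch only)."""
    assert len(lines) > 0, "caption file is empty"
    if enable_wildcard:
        # Strip each line, drop empty ones, and join with newlines.
        return "\n".join(line.strip() for line in lines if line.strip() != "")
    if enable_multiline_captions:
        # Keep interior line breaks and blank lines; trim only the outer edges.
        return "".join(lines).strip()
    # Default: first line only.
    return lines[0].strip()


raw = ["1girl, stage\n", "\n", "  detailed art \n"]
wildcard_text = join_caption_lines(raw, True, False)
multiline_text = join_caption_lines(raw, False, True)
first_line_text = join_caption_lines(raw, False, False)
```

For the sample `raw` input, the wildcard branch discards the blank line and per-line padding, while the multi-line branch preserves both in the interior of the caption.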
@@ -2095,7 +2108,9 @@ class DreamBoothDataset(BaseDataset):
             captions = []
             missing_captions = []
             for img_path in tqdm(img_paths, desc="read caption"):
-                cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
+                cap_for_img = read_caption(
+                    img_path, subset.caption_extension, subset.enable_wildcard, subset.enable_multiline_captions
+                )
                 if cap_for_img is None and subset.class_tokens is None:
                     logger.warning(
                         f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
@@ -6260,7 +6275,8 @@ def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names):
             )
         if "effective_lr" in lr_scheduler.optimizers[-1].param_groups[lr_index]:
             logs["lr/d*eff_lr/" + name] = (
-                lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["effective_lr"]
+                lr_scheduler.optimizers[-1].param_groups[lr_index]["d"]
+                * lr_scheduler.optimizers[-1].param_groups[lr_index]["effective_lr"]
             )
 
 