mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-06 21:52:27 +00:00)
[Experimental] Add cache mechanism for dataset groups to avoid long waiting time for initialization (#1178)
* support meta cached dataset
* add cache meta scripts
* random ip_noise_gamma strength
* random noise_offset strength
* use correct settings for parser
* cache path/caption/size only
* revert mess up commit
* revert mess up commit
* Update requirements.txt
* Add arguments for meta cache.
* remove pickle implementation
* Return sizes when enable cache

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
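In short, the cache is a plain-text dataset.txt written into each subset's image directory, one line per image holding path, caption, and resolution joined by the <|##|> delimiter (see the library/train_util.py hunks below). A minimal round-trip sketch of that record format, with hypothetical sample values; the real code reads and writes f"{subset.image_dir}/dataset.txt":

records = [("train_images/0001.png", "a photo of sks dog", (512, 768))]  # (path, caption, (width, height)), hypothetical

# write side: mirrors the cache_meta branch
with open("dataset.txt", "w", encoding="utf-8") as f:
    f.write("\n".join("<|##|>".join((path, caption, " ".join(str(v) for v in size))) for path, caption, size in records))

# read side: mirrors the use_cached_meta branch
with open("dataset.txt", "r", encoding="utf-8") as f:
    metas = [line.strip().split("<|##|>") for line in f.readlines()]
img_paths = [m[0] for m in metas]
captions = [m[1] for m in metas]
sizes = [tuple(int(res) for res in m[2].split(" ")) for m in metas]
assert img_paths == ["train_images/0001.png"] and sizes == [(512, 768)]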
cache_dataset_meta.py (new file, 103 lines)
@@ -0,0 +1,103 @@
+import argparse
+import random
+
+from accelerate.utils import set_seed
+
+import library.train_util as train_util
+import library.config_util as config_util
+from library.config_util import (
+    ConfigSanitizer,
+    BlueprintGenerator,
+)
+import library.custom_train_functions as custom_train_functions
+from library.utils import setup_logging, add_logging_arguments
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def make_dataset(args):
+    train_util.prepare_dataset_args(args, True)
+    setup_logging(args, reset=True)
+
+    use_dreambooth_method = args.in_json is None
+    use_user_config = args.dataset_config is not None
+
+    if args.seed is None:
+        args.seed = random.randint(0, 2**32)
+    set_seed(args.seed)
+
+    # prepare the dataset
+    if args.dataset_class is None:
+        blueprint_generator = BlueprintGenerator(
+            ConfigSanitizer(True, True, False, True)
+        )
+        if use_user_config:
+            logger.info(f"Loading dataset config from {args.dataset_config}")
+            user_config = config_util.load_user_config(args.dataset_config)
+            ignored = ["train_data_dir", "reg_data_dir", "in_json"]
+            if any(getattr(args, attr) is not None for attr in ignored):
+                logger.warning(
+                    "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                        ", ".join(ignored)
+                    )
+                )
+        else:
+            if use_dreambooth_method:
+                logger.info("Using DreamBooth method.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
+                                args.train_data_dir, args.reg_data_dir
+                            )
+                        }
+                    ]
+                }
+            else:
+                logger.info("Training with captions.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": [
+                                {
+                                    "image_dir": args.train_data_dir,
+                                    "metadata_file": args.in_json,
+                                }
+                            ]
+                        }
+                    ]
+                }
+
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=None)
+        train_dataset_group = config_util.generate_dataset_group_by_blueprint(
+            blueprint.dataset_group
+        )
+    else:
+        # use arbitrary dataset class
+        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer=None)
+    return train_dataset_group
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+    add_logging_arguments(parser)
+    train_util.add_dataset_arguments(parser, True, True, True)
+    train_util.add_training_arguments(parser, True)
+    config_util.add_config_arguments(parser)
+    custom_train_functions.add_custom_train_arguments(parser)
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args, unknown = parser.parse_known_args()
+    args = train_util.read_config_from_file(args, parser)
+    if args.max_token_length is None:
+        args.max_token_length = 75
+    args.cache_meta = True
+
+    dataset_group = make_dataset(args)
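Usage note: assuming the repo's usual entry-point conventions, a one-off run such as python cache_dataset_meta.py --train_data_dir <dir> (the flag set is inherited from add_dataset_arguments, add_training_arguments, and the config helpers, so it mirrors the trainers) builds the dataset group once; because the script force-sets args.cache_meta = True, this writes a dataset.txt cache into each subset's image directory. Later training runs can then pass --use_cached_meta to skip the slow glob-and-caption-read phase of initialization.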
library/config_util.py
@@ -111,6 +111,8 @@ class DreamBoothDatasetParams(BaseDatasetParams):
     bucket_reso_steps: int = 64
     bucket_no_upscale: bool = False
     prior_loss_weight: float = 1.0
+    cache_meta: bool = False
+    use_cached_meta: bool = False
 
 
 @dataclass
@@ -228,6 +230,8 @@ class ConfigSanitizer:
         "min_bucket_reso": int,
         "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
         "network_multiplier": float,
+        "cache_meta": bool,
+        "use_cached_meta": bool,
     }
 
     # options handled by argparse but not handled by user config
library/train_util.py
@@ -63,6 +63,7 @@ from library.original_unet import UNet2DConditionModel
 from huggingface_hub import hf_hub_download
 import numpy as np
 from PIL import Image
+import imagesize
 import cv2
 import safetensors.torch
 from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
@@ -1080,8 +1081,7 @@ class BaseDataset(torch.utils.data.Dataset):
         )
 
     def get_image_size(self, image_path):
-        image = Image.open(image_path)
-        return image.size
+        return imagesize.get(image_path)
 
     def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
         img = load_image(image_path)
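The swap from PIL to imagesize above is the enabling micro-optimization: imagesize.get parses only the file header to recover (width, height), so collecting sizes across thousands of images stays cheap. A minimal sketch of the equivalence this change relies on (the path is a hypothetical sample file):

import imagesize
from PIL import Image

path = "train_images/0001.png"           # hypothetical sample file
w, h = imagesize.get(path)               # header-only read, no pixel decode
assert (w, h) == Image.open(path).size   # same result the removed PIL code produced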
@@ -1425,6 +1425,8 @@ class DreamBoothDataset(BaseDataset):
         bucket_no_upscale: bool,
         prior_loss_weight: float,
         debug_dataset: bool,
+        cache_meta: bool,
+        use_cached_meta: bool,
     ) -> None:
         super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
 
@@ -1484,26 +1486,43 @@
                 logger.warning(f"not directory: {subset.image_dir}")
                 return [], []
 
-            img_paths = glob_images(subset.image_dir, "*")
+            sizes = None
+            if use_cached_meta:
+                logger.info(f"using cached metadata: {subset.image_dir}/dataset.txt")
+                # [img_path, caption, resolution]
+                with open(f"{subset.image_dir}/dataset.txt", "r", encoding="utf-8") as f:
+                    metas = f.readlines()
+                metas = [x.strip().split("<|##|>") for x in metas]
+                sizes = [tuple(int(res) for res in x[2].split(" ")) for x in metas]
+
+            if use_cached_meta:
+                img_paths = [x[0] for x in metas]
+            else:
+                img_paths = glob_images(subset.image_dir, "*")
+                sizes = [None] * len(img_paths)
             logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
 
-            # read the caption for each image file; if one exists, use it
-            captions = []
-            missing_captions = []
-            for img_path in img_paths:
-                cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
-                if cap_for_img is None and subset.class_tokens is None:
-                    logger.warning(
-                        f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
-                    )
-                    captions.append("")
-                    missing_captions.append(img_path)
-                else:
-                    if cap_for_img is None:
-                        captions.append(subset.class_tokens)
-                        missing_captions.append(img_path)
-                    else:
-                        captions.append(cap_for_img)
+            if use_cached_meta:
+                captions = [x[1] for x in metas]
+                missing_captions = [x[0] for x in metas if x[1] == ""]
+            else:
+                # read the caption for each image file; if one exists, use it
+                captions = []
+                missing_captions = []
+                for img_path in img_paths:
+                    cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
+                    if cap_for_img is None and subset.class_tokens is None:
+                        logger.warning(
+                            f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
+                        )
+                        captions.append("")
+                        missing_captions.append(img_path)
+                    else:
+                        if cap_for_img is None:
+                            captions.append(subset.class_tokens)
+                            missing_captions.append(img_path)
+                        else:
+                            captions.append(cap_for_img)
 
             self.set_tag_frequency(os.path.basename(subset.image_dir), captions)  # record tag frequency
 
@@ -1520,7 +1539,21 @@
                         logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
                         break
                     logger.warning(missing_caption)
-            return img_paths, captions
+
+            if cache_meta:
+                logger.info(f"cache metadata for {subset.image_dir}")
+                if sizes is None or sizes[0] is None:
+                    sizes = [self.get_image_size(img_path) for img_path in img_paths]
+                # [img_path, caption, resolution]
+                data = [
+                    (img_path, caption, " ".join(str(x) for x in size))
+                    for img_path, caption, size in zip(img_paths, captions, sizes)
+                ]
+                with open(f"{subset.image_dir}/dataset.txt", "w", encoding="utf-8") as f:
+                    f.write("\n".join(["<|##|>".join(x) for x in data]))
+                logger.info(f"cache metadata done for {subset.image_dir}")
+
+            return img_paths, captions, sizes
 
         logger.info("prepare images.")
         num_train_images = 0
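Note that the diff adds no cache invalidation: dataset.txt is trusted as-is, so images renamed, deleted, or re-cropped after caching are read with stale metadata until the cache is rewritten with --cache_meta. A hedged sketch of a consistency check one could layer on top (verify_cache is a hypothetical helper, not part of this PR):

import os
import imagesize

def verify_cache(image_dir: str) -> list:
    # return cached entries whose file is missing or whose stored size no longer matches
    stale = []
    with open(os.path.join(image_dir, "dataset.txt"), "r", encoding="utf-8") as f:
        for line in f:
            img_path, _caption, res = line.rstrip("\n").split("<|##|>")
            size = tuple(int(x) for x in res.split(" "))
            if not os.path.exists(img_path) or imagesize.get(img_path) != size:
                stale.append(img_path)
    return stale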
@@ -1539,7 +1572,7 @@
                 )
                 continue
 
-            img_paths, captions = load_dreambooth_dir(subset)
+            img_paths, captions, sizes = load_dreambooth_dir(subset)
             if len(img_paths) < 1:
                 logger.warning(
                     f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
@@ -1551,8 +1584,10 @@
             else:
                 num_train_images += subset.num_repeats * len(img_paths)
 
-            for img_path, caption in zip(img_paths, captions):
+            for img_path, caption, size in zip(img_paths, captions, sizes):
                 info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
+                if size is not None:
+                    info.image_size = size
                 if subset.is_reg:
                     reg_infos.append((info, subset))
                 else:
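Threading sizes through to ImageInfo is where the saved startup time materializes: entries that came from dataset.txt arrive with image_size already set, so later bucketing code only has to probe the files whose size is still unknown. An illustrative sketch of that fallback pattern (ImageInfoSketch and ensure_sizes are hypothetical names, not the library's own code):

from dataclasses import dataclass
from typing import Optional, Tuple

import imagesize

@dataclass
class ImageInfoSketch:                      # hypothetical stand-in for train_util.ImageInfo
    absolute_path: str
    image_size: Optional[Tuple[int, int]] = None

def ensure_sizes(infos):
    # only touch the file system for entries the cache did not cover
    for info in infos:
        if info.image_size is None:
            info.image_size = imagesize.get(info.absolute_path)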
@@ -3355,6 +3390,12 @@ def add_dataset_arguments(
     parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
 ):
     # dataset common
+    parser.add_argument(
+        "--cache_meta", action="store_true"
+    )
+    parser.add_argument(
+        "--use_cached_meta", action="store_true"
+    )
     parser.add_argument(
         "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
     )
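Neither new flag carries a help string, but together they form a two-step workflow: --cache_meta writes dataset.txt during a scan and --use_cached_meta consumes it on later runs. A minimal sketch of that lifecycle with the parser trimmed to just the two new flags:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cache_meta", action="store_true")
parser.add_argument("--use_cached_meta", action="store_true")

first_run = parser.parse_args(["--cache_meta"])       # scan images, write dataset.txt
later_run = parser.parse_args(["--use_cached_meta"])  # skip scanning, read dataset.txt
assert first_run.cache_meta and later_run.use_cached_meta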
requirements.txt
@@ -15,6 +15,8 @@ easygui==0.98.3
 toml==0.10.2
 voluptuous==0.13.1
 huggingface-hub==0.20.1
+# for Image utils
+imagesize==1.4.1
 # for BLIP captioning
 # requests==2.28.2
 # timm==0.6.12
train_network.py
@@ -6,6 +6,7 @@ import sys
 import random
 import time
 import json
+import pickle
 from multiprocessing import Value
 import toml
 
@@ -23,7 +24,7 @@ from library import model_util
 
 import library.train_util as train_util
 from library.train_util import (
-    DreamBoothDataset,
+    DreamBoothDataset, DatasetGroup
 )
 import library.config_util as config_util
 from library.config_util import (