Merge pull request #1936 from rockerBOO/resize-interpolation

Add resize interpolation parameter
This commit is contained in:
Kohya S.
2025-03-30 20:37:34 +09:00
committed by GitHub
7 changed files with 165 additions and 57 deletions

View File

@@ -11,7 +11,7 @@ from PIL import Image
from tqdm import tqdm
import library.train_util as train_util
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image
setup_logging()
import logging
@@ -42,10 +42,7 @@ def preprocess_image(image):
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
if size > IMAGE_SIZE:
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
else:
image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))
image = resize_image(image, image.shape[0], image.shape[1], IMAGE_SIZE, IMAGE_SIZE)
image = image.astype(np.float32)
return image

View File

@@ -75,6 +75,7 @@ class BaseSubsetParams:
custom_attributes: Optional[Dict[str, Any]] = None
validation_seed: int = 0
validation_split: float = 0.0
resize_interpolation: Optional[str] = None
@dataclass
@@ -106,7 +107,7 @@ class BaseDatasetParams:
debug_dataset: bool = False
validation_seed: Optional[int] = None
validation_split: float = 0.0
resize_interpolation: Optional[str] = None
@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
@@ -196,6 +197,7 @@ class ConfigSanitizer:
"caption_prefix": str,
"caption_suffix": str,
"custom_attributes": dict,
"resize_interpolation": str,
}
# DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -241,6 +243,7 @@ class ConfigSanitizer:
"validation_split": float,
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
"resize_interpolation": str,
}
# options handled by argparse but not handled by user config
@@ -525,6 +528,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
[{dataset_type} {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
resize_interpolation: {dataset.resize_interpolation}
enable_bucket: {dataset.enable_bucket}
""")
@@ -558,6 +562,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
alpha_mask: {subset.alpha_mask}
resize_interpolation: {subset.resize_interpolation}
custom_attributes: {subset.custom_attributes}
"""), " ")

View File

@@ -74,7 +74,7 @@ import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
import library.deepspeed_utils as deepspeed_utils
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image
setup_logging()
import logging
@@ -205,6 +205,7 @@ class ImageInfo:
self.text_encoder_pool2: Optional[torch.Tensor] = None
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
self.resize_interpolation: Optional[str] = None
class BucketManager:
@@ -429,6 +430,7 @@ class BaseSubset:
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
resize_interpolation: Optional[str] = None,
) -> None:
self.image_dir = image_dir
self.alpha_mask = alpha_mask if alpha_mask is not None else False
@@ -459,6 +461,8 @@ class BaseSubset:
self.validation_seed = validation_seed
self.validation_split = validation_split
self.resize_interpolation = resize_interpolation
class DreamBoothSubset(BaseSubset):
def __init__(
@@ -490,6 +494,7 @@ class DreamBoothSubset(BaseSubset):
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
resize_interpolation: Optional[str] = None,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -517,6 +522,7 @@ class DreamBoothSubset(BaseSubset):
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
resize_interpolation=resize_interpolation,
)
self.is_reg = is_reg
@@ -559,6 +565,7 @@ class FineTuningSubset(BaseSubset):
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
resize_interpolation: Optional[str] = None,
) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
@@ -586,6 +593,7 @@ class FineTuningSubset(BaseSubset):
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
resize_interpolation=resize_interpolation,
)
self.metadata_file = metadata_file
@@ -624,6 +632,7 @@ class ControlNetSubset(BaseSubset):
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
resize_interpolation: Optional[str] = None,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -651,6 +660,7 @@ class ControlNetSubset(BaseSubset):
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
resize_interpolation=resize_interpolation,
)
self.conditioning_data_dir = conditioning_data_dir
@@ -671,6 +681,7 @@ class BaseDataset(torch.utils.data.Dataset):
resolution: Optional[Tuple[int, int]],
network_multiplier: float,
debug_dataset: bool,
resize_interpolation: Optional[str] = None
) -> None:
super().__init__()
@@ -705,6 +716,10 @@ class BaseDataset(torch.utils.data.Dataset):
self.image_transforms = IMAGE_TRANSFORMS
if resize_interpolation is not None:
assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation"
self.resize_interpolation = resize_interpolation
self.image_data: Dict[str, ImageInfo] = {}
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
@@ -1494,7 +1509,7 @@ class BaseDataset(torch.utils.data.Dataset):
nh = int(height * scale + 0.5)
nw = int(width * scale + 0.5)
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
image = resize_image(image, width, height, nw, nh, subset.resize_interpolation)
face_cx = int(face_cx * scale + 0.5)
face_cy = int(face_cy * scale + 0.5)
height, width = nh, nw
@@ -1591,7 +1606,7 @@ class BaseDataset(torch.utils.data.Dataset):
if self.enable_bucket:
img, original_size, crop_ltrb = trim_and_resize_if_required(
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation
)
else:
if face_cx > 0: # 顔位置情報あり
@@ -1852,8 +1867,9 @@ class DreamBoothDataset(BaseDataset):
debug_dataset: bool,
validation_split: float,
validation_seed: Optional[int],
resize_interpolation: Optional[str],
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
@@ -2078,6 +2094,7 @@ class DreamBoothDataset(BaseDataset):
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
if size is not None:
info.image_size = size
if subset.is_reg:
@@ -2360,9 +2377,10 @@ class ControlNetDataset(BaseDataset):
bucket_no_upscale: bool,
debug_dataset: bool,
validation_split: float,
validation_seed: Optional[int],
validation_seed: Optional[int],
resize_interpolation: Optional[str] = None,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
db_subsets = []
for subset in subsets:
@@ -2394,6 +2412,7 @@ class ControlNetDataset(BaseDataset):
subset.caption_suffix,
subset.token_warmup_min,
subset.token_warmup_step,
resize_interpolation=subset.resize_interpolation,
)
db_subsets.append(db_subset)
@@ -2412,6 +2431,7 @@ class ControlNetDataset(BaseDataset):
debug_dataset,
validation_split,
validation_seed,
resize_interpolation,
)
# config_util等から参照される値をいれておく若干微妙なのでなんとかしたい
@@ -2420,7 +2440,8 @@ class ControlNetDataset(BaseDataset):
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
self.validation_split = validation_split
self.validation_seed = validation_seed
self.validation_seed = validation_seed
self.resize_interpolation = resize_interpolation
# assert all conditioning data exists
missing_imgs = []
@@ -2508,9 +2529,8 @@ class ControlNetDataset(BaseDataset):
assert (
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
cond_img = cv2.resize(
cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA
) # INTER_AREAでやりたいのでcv2でリサイズ
cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
# TODO support random crop
# 現在サポートしているcropはrandomではなく中央のみ
@@ -2524,7 +2544,7 @@ class ControlNetDataset(BaseDataset):
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# resize to target
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0])))
cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
if flipped:
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
@@ -2921,17 +2941,13 @@ def load_image(image_path, alpha=False):
# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
def trim_and_resize_if_required(
random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int], resize_interpolation: Optional[str] = None
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
image_height, image_width = image.shape[0:2]
original_size = (image_width, image_height) # size before resize
if image_width != resized_size[0] or image_height != resized_size[1]:
# リサイズする
if image_width > resized_size[0] and image_height > resized_size[1]:
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
else:
image = pil_resize(image, resized_size)
image = resize_image(image, image_width, image_height, resized_size[0], resized_size[1], resize_interpolation)
image_height, image_width = image.shape[0:2]
@@ -2976,7 +2992,7 @@ def load_images_and_masks_for_caching(
for info in image_infos:
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
original_sizes.append(original_size)
crop_ltrbs.append(crop_ltrb)
@@ -3017,7 +3033,7 @@ def cache_batch_latents(
for info in image_infos:
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
@@ -4495,7 +4511,13 @@ def add_dataset_arguments(
action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
)
parser.add_argument(
"--resize_interpolation",
type=str,
default=None,
choices=["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area"],
help="Resize interpolation when required. Default: area Options: lanczos, nearest, bilinear, bicubic, area / 必要に応じてサイズ補間を変更します。デフォルト: area オプション: lanczos, nearest, bilinear, bicubic, area",
)
parser.add_argument(
"--token_warmup_min",
type=int,
@@ -6533,3 +6555,4 @@ class LossRecorder:
if losses == 0:
return 0
return self.loss_total / losses

View File

@@ -16,7 +16,6 @@ from PIL import Image
import numpy as np
from safetensors.torch import load_file
def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
@@ -89,6 +88,8 @@ def setup_logging(args=None, log_level=None, reset=False):
logger = logging.getLogger(__name__)
logger.info(msg_init)
setup_logging()
logger = logging.getLogger(__name__)
# endregion
@@ -378,7 +379,7 @@ def load_safetensors(
# region Image utils
def pil_resize(image, size, interpolation=Image.LANCZOS):
def pil_resize(image, size, interpolation):
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
if has_alpha:
@@ -386,7 +387,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
else:
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
resized_pil = pil_image.resize(size, interpolation)
resized_pil = pil_image.resize(size, resample=interpolation)
# Convert back to cv2 format
if has_alpha:
@@ -397,6 +398,100 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
return resized_cv2
def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None):
"""
Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS
Args:
image: numpy.ndarray
width: int Original image width
height: int Original image height
resized_width: int Resized image width
resized_height: int Resized image height
resize_interpolation: Optional[str] Resize interpolation method "lanczos", "area", "bilinear", "bicubic", "nearest", "box"
Returns:
image
"""
interpolation = get_cv2_interpolation(resize_interpolation)
resized_size = (resized_width, resized_height)
if width > resized_width and height > resized_width:
image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
logger.debug(f"resize image using {resize_interpolation}")
else:
image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_LANCZOS4) # INTER_AREAでやりたいのでcv2でリサイズ
logger.debug(f"resize image using {resize_interpolation}")
return image
def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
"""
Convert interpolation value to cv2 interpolation integer
https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
"""
if interpolation is None:
return None
if interpolation == "lanczos" or interpolation == "lanczos4":
# Lanczos interpolation over 8x8 neighborhood
return cv2.INTER_LANCZOS4
elif interpolation == "nearest":
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
return cv2.INTER_NEAREST_EXACT
elif interpolation == "bilinear" or interpolation == "linear":
# bilinear interpolation
return cv2.INTER_LINEAR
elif interpolation == "bicubic" or interpolation == "cubic":
# bicubic interpolation
return cv2.INTER_CUBIC
elif interpolation == "area":
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
return cv2.INTER_AREA
elif interpolation == "box":
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
return cv2.INTER_AREA
else:
return None
def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
"""
Convert interpolation value to PIL interpolation
https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters
"""
if interpolation is None:
return None
if interpolation == "lanczos":
return Image.Resampling.LANCZOS
elif interpolation == "nearest":
# Pick one nearest pixel from the input image. Ignore all other input pixels.
return Image.Resampling.NEAREST
elif interpolation == "bilinear" or interpolation == "linear":
# For resize calculate the output pixel value using linear interpolation on all pixels that may contribute to the output value. For other transformations linear interpolation over a 2x2 environment in the input image is used.
return Image.Resampling.BILINEAR
elif interpolation == "bicubic" or interpolation == "cubic":
# For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used.
return Image.Resampling.BICUBIC
elif interpolation == "area":
# Image.Resampling.BOX may be more appropriate if upscaling
# Area interpolation is related to cv2.INTER_AREA
# Produces a sharper image than Resampling.BILINEAR, doesnt have dislocations on local level like with Resampling.BOX.
return Image.Resampling.HAMMING
elif interpolation == "box":
# Each pixel of source image contributes to one pixel of the destination image with identical weights. For upscaling is equivalent of Resampling.NEAREST.
return Image.Resampling.BOX
else:
return None
def validate_interpolation_fn(interpolation_str: str) -> bool:
"""
Check if a interpolation function is supported
"""
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]
# endregion
# TODO make inf_utils.py

View File

@@ -15,7 +15,7 @@ import os
from anime_face_detector import create_detector
from tqdm import tqdm
import numpy as np
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image
setup_logging()
import logging
logger = logging.getLogger(__name__)
@@ -170,12 +170,9 @@ def process(args):
scale = max(cur_crop_width / w, cur_crop_height / h)
if scale != 1.0:
w = int(w * scale + .5)
h = int(h * scale + .5)
if scale < 1.0:
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA)
else:
face_img = pil_resize(face_img, (w, h))
rw = int(w * scale + .5)
rh = int(h * scale + .5)
face_img = resize_image(face_img, w, h, rw, rh)
cx = int(cx * scale + .5)
cy = int(cy * scale + .5)
fw = int(fw * scale + .5)

View File

@@ -6,7 +6,7 @@ import shutil
import math
from PIL import Image
import numpy as np
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image
setup_logging()
import logging
logger = logging.getLogger(__name__)
@@ -22,14 +22,6 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
if not os.path.exists(dst_img_folder):
os.makedirs(dst_img_folder)
# Select interpolation method
if interpolation == 'lanczos4':
pil_interpolation = Image.LANCZOS
elif interpolation == 'cubic':
pil_interpolation = Image.BICUBIC
else:
cv2_interpolation = cv2.INTER_AREA
# Iterate through all files in src_img_folder
img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
for filename in os.listdir(src_img_folder):
@@ -63,11 +55,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
new_height = int(img.shape[0] * math.sqrt(scale_factor))
new_width = int(img.shape[1] * math.sqrt(scale_factor))
# Resize image
if cv2_interpolation:
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
else:
img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation)
img = resize_image(img, img.shape[0], img.shape[1], new_height, new_width, interpolation)
else:
new_height, new_width = img.shape[0:2]
@@ -113,8 +101,8 @@ def setup_parser() -> argparse.ArgumentParser:
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
parser.add_argument('--divisible_by', type=int,
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
default='area', help='Interpolation method for resizing / サイズの補方法')
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4', 'nearest', 'linear', 'box'],
default=None, help='Interpolation method for resizing. Default to area if smaller, lanczos if larger / サイズ変更の補方法。小さい場合はデフォルトでエリア、大きい場合はランチョスになります。')
parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
parser.add_argument('--copy_associated_files', action='store_true',
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')

View File

@@ -1012,11 +1012,12 @@ class NetworkTrainer:
"ss_huber_c": args.huber_c,
"ss_fp8_base": bool(args.fp8_base),
"ss_fp8_base_unet": bool(args.fp8_base_unet),
"ss_validation_seed": args.validation_seed,
"ss_validation_split": args.validation_split,
"ss_max_validation_steps": args.max_validation_steps,
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps,
"ss_validation_seed": args.validation_seed,
"ss_validation_split": args.validation_split,
"ss_max_validation_steps": args.max_validation_steps,
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps,
"ss_resize_interpolation": args.resize_interpolation
}
self.update_metadata(metadata, args) # architecture specific metadata
@@ -1042,6 +1043,7 @@ class NetworkTrainer:
"max_bucket_reso": dataset.max_bucket_reso,
"tag_frequency": dataset.tag_frequency,
"bucket_info": dataset.bucket_info,
"resize_interpolation": dataset.resize_interpolation,
}
subsets_metadata = []
@@ -1059,6 +1061,7 @@ class NetworkTrainer:
"enable_wildcard": bool(subset.enable_wildcard),
"caption_prefix": subset.caption_prefix,
"caption_suffix": subset.caption_suffix,
"resize_interpolation": subset.resize_interpolation,
}
image_dir_or_metadata_file = None