diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index cbc3d2d6..406f12f2 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -11,7 +11,7 @@ from PIL import Image from tqdm import tqdm import library.train_util as train_util -from library.utils import setup_logging, pil_resize +from library.utils import setup_logging, resize_image setup_logging() import logging @@ -42,10 +42,7 @@ def preprocess_image(image): pad_t = pad_y // 2 image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255) - if size > IMAGE_SIZE: - image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA) - else: - image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE)) + image = resize_image(image, image.shape[0], image.shape[1], IMAGE_SIZE, IMAGE_SIZE) image = image.astype(np.float32) return image diff --git a/library/config_util.py b/library/config_util.py index a2e07dc6..53727f25 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -75,6 +75,7 @@ class BaseSubsetParams: custom_attributes: Optional[Dict[str, Any]] = None validation_seed: int = 0 validation_split: float = 0.0 + resize_interpolation: Optional[str] = None @dataclass @@ -106,7 +107,7 @@ class BaseDatasetParams: debug_dataset: bool = False validation_seed: Optional[int] = None validation_split: float = 0.0 - + resize_interpolation: Optional[str] = None @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -196,6 +197,7 @@ class ConfigSanitizer: "caption_prefix": str, "caption_suffix": str, "custom_attributes": dict, + "resize_interpolation": str, } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { @@ -241,6 +243,7 @@ class ConfigSanitizer: "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, + "resize_interpolation": str, } # options handled by argparse but not handled by user config @@ -525,6 +528,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu [{dataset_type} {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} + resize_interpolation: {dataset.resize_interpolation} enable_bucket: {dataset.enable_bucket} """) @@ -558,6 +562,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu token_warmup_min: {subset.token_warmup_min}, token_warmup_step: {subset.token_warmup_step}, alpha_mask: {subset.alpha_mask} + resize_interpolation: {subset.resize_interpolation} custom_attributes: {subset.custom_attributes} """), " ") diff --git a/library/train_util.py b/library/train_util.py index 1f591c42..e9c50688 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -74,7 +74,7 @@ import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec import library.deepspeed_utils as deepspeed_utils -from library.utils import setup_logging, pil_resize +from library.utils import setup_logging, resize_image setup_logging() import logging @@ -205,6 +205,7 @@ class ImageInfo: self.text_encoder_pool2: Optional[torch.Tensor] = None self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime + self.resize_interpolation: Optional[str] = None class BucketManager: @@ -429,6 +430,7 @@ class BaseSubset: custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + resize_interpolation: Optional[str] = None, ) -> None: self.image_dir = image_dir self.alpha_mask = alpha_mask if alpha_mask is not None else False @@ -459,6 +461,8 @@ class BaseSubset: self.validation_seed = validation_seed self.validation_split = validation_split + self.resize_interpolation = resize_interpolation + class DreamBoothSubset(BaseSubset): def __init__( @@ -490,6 +494,7 @@ class DreamBoothSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + resize_interpolation: Optional[str] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -517,6 +522,7 @@ class DreamBoothSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, + resize_interpolation=resize_interpolation, ) self.is_reg = is_reg @@ -559,6 +565,7 @@ class FineTuningSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + resize_interpolation: Optional[str] = None, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -586,6 +593,7 @@ class FineTuningSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, + resize_interpolation=resize_interpolation, ) self.metadata_file = metadata_file @@ -624,6 +632,7 @@ class ControlNetSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + resize_interpolation: Optional[str] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -651,6 +660,7 @@ class ControlNetSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, + resize_interpolation=resize_interpolation, ) self.conditioning_data_dir = conditioning_data_dir @@ -671,6 +681,7 @@ class BaseDataset(torch.utils.data.Dataset): resolution: Optional[Tuple[int, int]], network_multiplier: float, debug_dataset: bool, + resize_interpolation: Optional[str] = None ) -> None: super().__init__() @@ -705,6 +716,10 @@ class BaseDataset(torch.utils.data.Dataset): self.image_transforms = IMAGE_TRANSFORMS + if resize_interpolation is not None: + assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation" + self.resize_interpolation = resize_interpolation + self.image_data: Dict[str, ImageInfo] = {} self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} @@ -1494,7 +1509,7 @@ class BaseDataset(torch.utils.data.Dataset): nh = int(height * scale + 0.5) nw = int(width * scale + 0.5) assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" - image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) + image = resize_image(image, width, height, nw, nh, subset.resize_interpolation) face_cx = int(face_cx * scale + 0.5) face_cy = int(face_cy * scale + 0.5) height, width = nh, nw @@ -1591,7 +1606,7 @@ class BaseDataset(torch.utils.data.Dataset): if self.enable_bucket: img, original_size, crop_ltrb = trim_and_resize_if_required( - subset.random_crop, img, image_info.bucket_reso, image_info.resized_size + subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation ) else: if face_cx > 0: # 顔位置情報あり @@ -1852,8 +1867,9 @@ class DreamBoothDataset(BaseDataset): debug_dataset: bool, validation_split: float, validation_seed: Optional[int], + resize_interpolation: Optional[str], ) -> None: - super().__init__(resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -2078,6 +2094,7 @@ class DreamBoothDataset(BaseDataset): for img_path, caption, size in zip(img_paths, captions, sizes): info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) + info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation if size is not None: info.image_size = size if subset.is_reg: @@ -2360,9 +2377,10 @@ class ControlNetDataset(BaseDataset): bucket_no_upscale: bool, debug_dataset: bool, validation_split: float, - validation_seed: Optional[int], + validation_seed: Optional[int], + resize_interpolation: Optional[str] = None, ) -> None: - super().__init__(resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) db_subsets = [] for subset in subsets: @@ -2394,6 +2412,7 @@ class ControlNetDataset(BaseDataset): subset.caption_suffix, subset.token_warmup_min, subset.token_warmup_step, + resize_interpolation=subset.resize_interpolation, ) db_subsets.append(db_subset) @@ -2412,6 +2431,7 @@ class ControlNetDataset(BaseDataset): debug_dataset, validation_split, validation_seed, + resize_interpolation, ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) @@ -2420,7 +2440,8 @@ class ControlNetDataset(BaseDataset): self.num_train_images = self.dreambooth_dataset_delegate.num_train_images self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images self.validation_split = validation_split - self.validation_seed = validation_seed + self.validation_seed = validation_seed + self.resize_interpolation = resize_interpolation # assert all conditioning data exists missing_imgs = [] @@ -2508,9 +2529,8 @@ class ControlNetDataset(BaseDataset): assert ( cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" - cond_img = cv2.resize( - cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA - ) # INTER_AREAでやりたいのでcv2でリサイズ + + cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation) # TODO support random crop # 現在サポートしているcropはrandomではなく中央のみ @@ -2524,7 +2544,7 @@ class ControlNetDataset(BaseDataset): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0]))) + cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -2921,17 +2941,13 @@ def load_image(image_path, alpha=False): # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom) def trim_and_resize_if_required( - random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int] + random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int], resize_interpolation: Optional[str] = None ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]: image_height, image_width = image.shape[0:2] original_size = (image_width, image_height) # size before resize if image_width != resized_size[0] or image_height != resized_size[1]: - # リサイズする - if image_width > resized_size[0] and image_height > resized_size[1]: - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ - else: - image = pil_resize(image, resized_size) + image = resize_image(image, image_width, image_height, resized_size[0], resized_size[1], resize_interpolation) image_height, image_width = image.shape[0:2] @@ -2976,7 +2992,7 @@ def load_images_and_masks_for_caching( for info in image_infos: image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 - image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation) original_sizes.append(original_size) crop_ltrbs.append(crop_ltrb) @@ -3017,7 +3033,7 @@ def cache_batch_latents( for info in image_infos: image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 - image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation) info.latents_original_size = original_size info.latents_crop_ltrb = crop_ltrb @@ -4495,7 +4511,13 @@ def add_dataset_arguments( action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します", ) - + parser.add_argument( + "--resize_interpolation", + type=str, + default=None, + choices=["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area"], + help="Resize interpolation when required. Default: area Options: lanczos, nearest, bilinear, bicubic, area / 必要に応じてサイズ補間を変更します。デフォルト: area オプション: lanczos, nearest, bilinear, bicubic, area", + ) parser.add_argument( "--token_warmup_min", type=int, @@ -6533,3 +6555,4 @@ class LossRecorder: if losses == 0: return 0 return self.loss_total / losses + diff --git a/library/utils.py b/library/utils.py index 4df8bd32..4fbc2627 100644 --- a/library/utils.py +++ b/library/utils.py @@ -16,7 +16,6 @@ from PIL import Image import numpy as np from safetensors.torch import load_file - def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() @@ -89,6 +88,8 @@ def setup_logging(args=None, log_level=None, reset=False): logger = logging.getLogger(__name__) logger.info(msg_init) +setup_logging() +logger = logging.getLogger(__name__) # endregion @@ -378,7 +379,7 @@ def load_safetensors( # region Image utils -def pil_resize(image, size, interpolation=Image.LANCZOS): +def pil_resize(image, size, interpolation): has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False if has_alpha: @@ -386,7 +387,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS): else: pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - resized_pil = pil_image.resize(size, interpolation) + resized_pil = pil_image.resize(size, resample=interpolation) # Convert back to cv2 format if has_alpha: @@ -397,6 +398,100 @@ def pil_resize(image, size, interpolation=Image.LANCZOS): return resized_cv2 +def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None): + """ + Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS + + Args: + image: numpy.ndarray + width: int Original image width + height: int Original image height + resized_width: int Resized image width + resized_height: int Resized image height + resize_interpolation: Optional[str] Resize interpolation method "lanczos", "area", "bilinear", "bicubic", "nearest", "box" + + Returns: + image + """ + interpolation = get_cv2_interpolation(resize_interpolation) + resized_size = (resized_width, resized_height) + if width > resized_width and height > resized_width: + image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + logger.debug(f"resize image using {resize_interpolation}") + else: + image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_LANCZOS4) # INTER_AREAでやりたいのでcv2でリサイズ + logger.debug(f"resize image using {resize_interpolation}") + + return image + + +def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]: + """ + Convert interpolation value to cv2 interpolation integer + + https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121 + """ + if interpolation is None: + return None + + if interpolation == "lanczos" or interpolation == "lanczos4": + # Lanczos interpolation over 8x8 neighborhood + return cv2.INTER_LANCZOS4 + elif interpolation == "nearest": + # Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab. + return cv2.INTER_NEAREST_EXACT + elif interpolation == "bilinear" or interpolation == "linear": + # bilinear interpolation + return cv2.INTER_LINEAR + elif interpolation == "bicubic" or interpolation == "cubic": + # bicubic interpolation + return cv2.INTER_CUBIC + elif interpolation == "area": + # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. + return cv2.INTER_AREA + elif interpolation == "box": + # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. + return cv2.INTER_AREA + else: + return None + +def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]: + """ + Convert interpolation value to PIL interpolation + + https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters + """ + if interpolation is None: + return None + + if interpolation == "lanczos": + return Image.Resampling.LANCZOS + elif interpolation == "nearest": + # Pick one nearest pixel from the input image. Ignore all other input pixels. + return Image.Resampling.NEAREST + elif interpolation == "bilinear" or interpolation == "linear": + # For resize calculate the output pixel value using linear interpolation on all pixels that may contribute to the output value. For other transformations linear interpolation over a 2x2 environment in the input image is used. + return Image.Resampling.BILINEAR + elif interpolation == "bicubic" or interpolation == "cubic": + # For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used. + return Image.Resampling.BICUBIC + elif interpolation == "area": + # Image.Resampling.BOX may be more appropriate if upscaling + # Area interpolation is related to cv2.INTER_AREA + # Produces a sharper image than Resampling.BILINEAR, doesn’t have dislocations on local level like with Resampling.BOX. + return Image.Resampling.HAMMING + elif interpolation == "box": + # Each pixel of source image contributes to one pixel of the destination image with identical weights. For upscaling is equivalent of Resampling.NEAREST. + return Image.Resampling.BOX + else: + return None + +def validate_interpolation_fn(interpolation_str: str) -> bool: + """ + Check if a interpolation function is supported + """ + return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"] + # endregion # TODO make inf_utils.py diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py index d2a4d9cf..16fd7d0b 100644 --- a/tools/detect_face_rotate.py +++ b/tools/detect_face_rotate.py @@ -15,7 +15,7 @@ import os from anime_face_detector import create_detector from tqdm import tqdm import numpy as np -from library.utils import setup_logging, pil_resize +from library.utils import setup_logging, resize_image setup_logging() import logging logger = logging.getLogger(__name__) @@ -170,12 +170,9 @@ def process(args): scale = max(cur_crop_width / w, cur_crop_height / h) if scale != 1.0: - w = int(w * scale + .5) - h = int(h * scale + .5) - if scale < 1.0: - face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA) - else: - face_img = pil_resize(face_img, (w, h)) + rw = int(w * scale + .5) + rh = int(h * scale + .5) + face_img = resize_image(face_img, w, h, rw, rh) cx = int(cx * scale + .5) cy = int(cy * scale + .5) fw = int(fw * scale + .5) diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py index 0f9e00b1..f5fbae2b 100644 --- a/tools/resize_images_to_resolution.py +++ b/tools/resize_images_to_resolution.py @@ -6,7 +6,7 @@ import shutil import math from PIL import Image import numpy as np -from library.utils import setup_logging, pil_resize +from library.utils import setup_logging, resize_image setup_logging() import logging logger = logging.getLogger(__name__) @@ -22,14 +22,6 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi if not os.path.exists(dst_img_folder): os.makedirs(dst_img_folder) - # Select interpolation method - if interpolation == 'lanczos4': - pil_interpolation = Image.LANCZOS - elif interpolation == 'cubic': - pil_interpolation = Image.BICUBIC - else: - cv2_interpolation = cv2.INTER_AREA - # Iterate through all files in src_img_folder img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py for filename in os.listdir(src_img_folder): @@ -63,11 +55,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi new_height = int(img.shape[0] * math.sqrt(scale_factor)) new_width = int(img.shape[1] * math.sqrt(scale_factor)) - # Resize image - if cv2_interpolation: - img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) - else: - img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation) + img = resize_image(img, img.shape[0], img.shape[1], new_height, new_width, interpolation) else: new_height, new_width = img.shape[0:2] @@ -113,8 +101,8 @@ def setup_parser() -> argparse.ArgumentParser: help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128") parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1) - parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'], - default='area', help='Interpolation method for resizing / リサイズ時の補完方法') + parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4', 'nearest', 'linear', 'box'], + default=None, help='Interpolation method for resizing. Default to area if smaller, lanczos if larger / サイズ変更の補間方法。小さい場合はデフォルトでエリア、大きい場合はランチョスになります。') parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存') parser.add_argument('--copy_associated_files', action='store_true', help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする') diff --git a/train_network.py b/train_network.py index 2d279b3b..9b89c98f 100644 --- a/train_network.py +++ b/train_network.py @@ -1012,11 +1012,12 @@ class NetworkTrainer: "ss_huber_c": args.huber_c, "ss_fp8_base": bool(args.fp8_base), "ss_fp8_base_unet": bool(args.fp8_base_unet), - "ss_validation_seed": args.validation_seed, - "ss_validation_split": args.validation_split, - "ss_max_validation_steps": args.max_validation_steps, - "ss_validate_every_n_epochs": args.validate_every_n_epochs, - "ss_validate_every_n_steps": args.validate_every_n_steps, + "ss_validation_seed": args.validation_seed, + "ss_validation_split": args.validation_split, + "ss_max_validation_steps": args.max_validation_steps, + "ss_validate_every_n_epochs": args.validate_every_n_epochs, + "ss_validate_every_n_steps": args.validate_every_n_steps, + "ss_resize_interpolation": args.resize_interpolation } self.update_metadata(metadata, args) # architecture specific metadata @@ -1042,6 +1043,7 @@ class NetworkTrainer: "max_bucket_reso": dataset.max_bucket_reso, "tag_frequency": dataset.tag_frequency, "bucket_info": dataset.bucket_info, + "resize_interpolation": dataset.resize_interpolation, } subsets_metadata = [] @@ -1059,6 +1061,7 @@ class NetworkTrainer: "enable_wildcard": bool(subset.enable_wildcard), "caption_prefix": subset.caption_prefix, "caption_suffix": subset.caption_suffix, + "resize_interpolation": subset.resize_interpolation, } image_dir_or_metadata_file = None