mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Fix sizes for validation split
This commit is contained in:
@@ -148,10 +148,11 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
|
||||
|
||||
def split_train_val(
|
||||
paths: List[str],
|
||||
sizes: List[Optional[Tuple[int, int]]],
|
||||
is_training_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: int | None
|
||||
) -> List[str]:
|
||||
) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]:
|
||||
"""
|
||||
Split the dataset into train and validation
|
||||
|
||||
@@ -172,10 +173,12 @@ def split_train_val(
|
||||
# Split the dataset between training and validation
|
||||
if is_training_dataset:
|
||||
# Training dataset we split to the first part
|
||||
return paths[0:math.ceil(len(paths) * (1 - validation_split))]
|
||||
split = math.ceil(len(paths) * (1 - validation_split))
|
||||
return paths[0:split], sizes[0:split]
|
||||
else:
|
||||
# Validation dataset we split to the second part
|
||||
return paths[len(paths) - round(len(paths) * validation_split):]
|
||||
split = len(paths) - round(len(paths) * validation_split)
|
||||
return paths[split:], sizes[split:]
|
||||
|
||||
|
||||
class ImageInfo:
|
||||
@@ -1931,12 +1934,12 @@ class DreamBoothDataset(BaseDataset):
|
||||
with open(info_cache_file, "r", encoding="utf-8") as f:
|
||||
metas = json.load(f)
|
||||
img_paths = list(metas.keys())
|
||||
sizes = [meta["resolution"] for meta in metas.values()]
|
||||
sizes: List[Optional[Tuple[int, int]]] = [meta["resolution"] for meta in metas.values()]
|
||||
|
||||
# we may need to check image size and existence of image files, but it takes time, so user should check it before training
|
||||
else:
|
||||
img_paths = glob_images(subset.image_dir, "*")
|
||||
sizes = [None] * len(img_paths)
|
||||
sizes: List[Optional[Tuple[int, int]]] = [None] * len(img_paths)
|
||||
|
||||
# new caching: get image size from cache files
|
||||
strategy = LatentsCachingStrategy.get_strategy()
|
||||
@@ -1969,7 +1972,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
w, h = None, None
|
||||
|
||||
if w is not None and h is not None:
|
||||
sizes[i] = [w, h]
|
||||
sizes[i] = (w, h)
|
||||
size_set_count += 1
|
||||
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
|
||||
|
||||
@@ -1990,8 +1993,9 @@ class DreamBoothDataset(BaseDataset):
|
||||
# Otherwise the img_paths remain as original img_paths and no split
|
||||
# required for training images dataset of regularization images
|
||||
else:
|
||||
img_paths = split_train_val(
|
||||
img_paths, sizes = split_train_val(
|
||||
img_paths,
|
||||
sizes,
|
||||
self.is_training_dataset,
|
||||
self.validation_split,
|
||||
self.validation_seed
|
||||
|
||||
Reference in New Issue
Block a user