Fix regularization images with validation

Add metadata recording for validation arguments.
Add comments about the validation split to clarify intent.
rockerBOO
2025-01-12 14:29:50 -05:00
parent 4c61adc996
commit 2bbb40ce51
2 changed files with 38 additions and 2 deletions


@@ -146,7 +146,12 @@ IMAGE_TRANSFORMS = transforms.Compose(
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
 
-def split_train_val(paths: List[str], is_training_dataset: bool, validation_split: float, validation_seed: int) -> List[str]:
+def split_train_val(
+    paths: List[str],
+    is_training_dataset: bool,
+    validation_split: float,
+    validation_seed: int | None
+) -> List[str]:
     """
     Split the dataset into train and validation
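
This hunk only reflows the signature and widens validation_seed to int | None; the function body is unchanged and not shown in this diff. For orientation, here is a minimal sketch of what a seeded train/val split of this shape can look like (split_train_val_sketch and its body are illustrative assumptions, not code from this commit):

import math
import random
from typing import List, Optional

def split_train_val_sketch(
    paths: List[str],
    is_training_dataset: bool,
    validation_split: float,
    validation_seed: Optional[int],
) -> List[str]:
    # Shuffle deterministically when a seed is given, restoring the global
    # RNG state afterwards so training randomness is unaffected.
    if validation_seed is not None:
        state = random.getstate()
        random.seed(validation_seed)
        random.shuffle(paths)
        random.setstate(state)
    else:
        random.shuffle(paths)
    # The head of the shuffled list trains; the tail validates.
    split_index = math.ceil(len(paths) * (1 - validation_split))
    return paths[:split_index] if is_training_dataset else paths[split_index:]

train = split_train_val_sketch([f"{i}.png" for i in range(10)], True, 0.2, 42)
val = split_train_val_sketch([f"{i}.png" for i in range(10)], False, 0.2, 42)
assert len(train) == 8 and len(val) == 2 and not set(train) & set(val)

Seeding (and restoring the RNG state) matters because the training and validation datasets call the split independently and must agree on the shuffle order for the two partitions to be disjoint.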
@@ -1830,6 +1835,9 @@ class BaseDataset(torch.utils.data.Dataset):
 class DreamBoothDataset(BaseDataset):
     IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
 
+    # The is_training_dataset defines the type of dataset, training or validation
+    # if is_training_dataset is True -> training dataset
+    # if is_training_dataset is False -> validation dataset
     def __init__(
         self,
         subsets: Sequence[DreamBoothSubset],
@@ -1965,8 +1973,29 @@ class DreamBoothDataset(BaseDataset):
                         size_set_count += 1
                 logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
 
+            # We want to create a training and validation split. This should be improved in the future
+            # to allow a clearer distinction between training and validation. This can be seen as a
+            # short-term solution to limit what is necessary to implement validation datasets
+            #
+            # We split the dataset for the subset based on if we are doing a validation split
+            # The self.is_training_dataset defines the type of dataset, training or validation
+            # if self.is_training_dataset is True -> training dataset
+            # if self.is_training_dataset is False -> validation dataset
             if self.validation_split > 0.0:
-                img_paths = split_train_val(img_paths, self.is_training_dataset, self.validation_split, self.validation_seed)
+                # For regularization images we do not want to split this dataset.
+                if subset.is_reg is True:
+                    # Skip any validation dataset for regularization images
+                    if self.is_training_dataset is False:
+                        img_paths = []
+                    # Otherwise the img_paths remain as original img_paths and no split
+                    # required for training images dataset of regularization images
+                else:
+                    img_paths = split_train_val(
+                        img_paths,
+                        self.is_training_dataset,
+                        self.validation_split,
+                        self.validation_seed
+                    )
 
             logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")


@@ -898,6 +898,7 @@ class NetworkTrainer:
         accelerator.print("running training / 学習開始")
         accelerator.print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+        accelerator.print(f"  num validation images * repeats / 検証画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}")
         accelerator.print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
         accelerator.print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
         accelerator.print(f"  num epochs / epoch数: {num_train_epochs}")
@@ -917,6 +918,7 @@ class NetworkTrainer:
             "ss_text_encoder_lr": text_encoder_lr,
             "ss_unet_lr": args.unet_lr,
             "ss_num_train_images": train_dataset_group.num_train_images,
+            "ss_num_validation_images": val_dataset_group.num_train_images if val_dataset_group is not None else 0,
             "ss_num_reg_images": train_dataset_group.num_reg_images,
             "ss_num_batches_per_epoch": len(train_dataloader),
             "ss_num_epochs": num_train_epochs,
@@ -964,6 +966,11 @@ class NetworkTrainer:
             "ss_huber_c": args.huber_c,
             "ss_fp8_base": bool(args.fp8_base),
             "ss_fp8_base_unet": bool(args.fp8_base_unet),
+            "ss_validation_seed": args.validation_seed,
+            "ss_validation_split": args.validation_split,
+            "ss_max_validation_steps": args.max_validation_steps,
+            "ss_validate_every_n_epochs": args.validate_every_n_epochs,
+            "ss_validate_every_n_steps": args.validate_every_n_steps,
         }
 
         self.update_metadata(metadata, args)  # architecture specific metadata
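
These ss_validation_* entries travel with the rest of the ss_* metadata embedded in the saved network file, so downstream tools can recover the validation configuration later. A hedged sketch of reading them back from a safetensors header (the filename is hypothetical; the key names come from the diff above, and safetensors metadata values are stored as strings):

from safetensors import safe_open

# Hypothetical path: any network saved with this metadata after the change.
with safe_open("my_lora.safetensors", framework="pt") as f:
    metadata = f.metadata() or {}

for key in (
    "ss_num_validation_images",
    "ss_validation_seed",
    "ss_validation_split",
    "ss_max_validation_steps",
    "ss_validate_every_n_epochs",
    "ss_validate_every_n_steps",
):
    print(key, metadata.get(key))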