Merge pull request #1864 from rockerBOO/validation-loss-upstream

Validation loss
This commit is contained in:
Kohya S.
2025-01-26 21:08:14 +09:00
committed by GitHub
24 changed files with 668 additions and 265 deletions

View File

@@ -91,9 +91,10 @@ def train(args):
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
current_epoch = Value("i", 0)
current_step = Value("i", 0)

View File

@@ -138,9 +138,10 @@ def train(args):
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
current_epoch = Value("i", 0)
current_step = Value("i", 0)

View File

@@ -126,9 +126,10 @@ def train(args):
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
current_epoch = Value("i", 0)
current_step = Value("i", 0)

View File

@@ -2,7 +2,7 @@ import argparse
import copy
import math
import random
from typing import Any, Optional
from typing import Any, Optional, Union
import torch
from accelerate import Accelerator
@@ -36,8 +36,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)
if args.fp8_base_unet:
@@ -80,6 +80,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
args.blocks_to_swap = 18 # 18 is safe for most cases
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models
@@ -339,6 +341,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
network,
weight_dtype,
train_unet,
is_train=True
):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
@@ -375,7 +378,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
# if not args.split_mode:
# normal forward
with accelerator.autocast():
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet(
img=img,
@@ -420,7 +423,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
intermediate_txt.requires_grad_(True)
vec.requires_grad_(True)
pe.requires_grad_(True)
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
with torch.set_grad_enabled(is_train and train_unet):
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
"""
return model_pred

View File

@@ -73,6 +73,8 @@ class BaseSubsetParams:
token_warmup_min: int = 1
token_warmup_step: float = 0
custom_attributes: Optional[Dict[str, Any]] = None
validation_seed: int = 0
validation_split: float = 0.0
@dataclass
@@ -102,6 +104,8 @@ class BaseDatasetParams:
resolution: Optional[Tuple[int, int]] = None
network_multiplier: float = 1.0
debug_dataset: bool = False
validation_seed: Optional[int] = None
validation_split: float = 0.0
@dataclass
@@ -113,8 +117,7 @@ class DreamBoothDatasetParams(BaseDatasetParams):
bucket_reso_steps: int = 64
bucket_no_upscale: bool = False
prior_loss_weight: float = 1.0
@dataclass
class FineTuningDatasetParams(BaseDatasetParams):
batch_size: int = 1
@@ -234,6 +237,8 @@ class ConfigSanitizer:
"enable_bucket": bool,
"max_bucket_reso": int,
"min_bucket_reso": int,
"validation_seed": int,
"validation_split": float,
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
}
@@ -462,119 +467,136 @@ class BlueprintGenerator:
return default_value
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint) -> Tuple[DatasetGroup, Optional[DatasetGroup]]:
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
for dataset_blueprint in dataset_group_blueprint.datasets:
extra_dataset_params = {}
if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
# DreamBooth datasets support splitting training and validation datasets
extra_dataset_params = {"is_training_dataset": True}
else:
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params)
datasets.append(dataset)
# print info
info = ""
for i, dataset in enumerate(datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(
f"""\
[Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
network_multiplier: {dataset.network_multiplier}
"""
)
val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
for dataset_blueprint in dataset_group_blueprint.datasets:
if dataset_blueprint.params.validation_split < 0.0 or dataset_blueprint.params.validation_split > 1.0:
logging.warning(f"Dataset param `validation_split` ({dataset_blueprint.params.validation_split}) is not a valid number between 0.0 and 1.0, skipping validation split...")
continue
if dataset.enable_bucket:
info += indent(
dedent(
f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""
),
" ",
)
# if the dataset isn't setting a validation split, there is no current validation dataset
if dataset_blueprint.params.validation_split == 0.0:
continue
extra_dataset_params = {}
if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
# DreamBooth datasets support splitting training and validation datasets
extra_dataset_params = {"is_training_dataset": False}
else:
info += "\n"
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset
for j, subset in enumerate(dataset.subsets):
info += indent(
dedent(
f"""\
[Subset {j} of Dataset {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
keep_tokens_separator: {subset.keep_tokens_separator}
caption_separator: {subset.caption_separator}
secondary_separator: {subset.secondary_separator}
enable_wildcard: {subset.enable_wildcard}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
color_aug: {subset.color_aug}
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min}
token_warmup_step: {subset.token_warmup_step}
alpha_mask: {subset.alpha_mask}
custom_attributes: {subset.custom_attributes}
"""
),
" ",
)
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params)
val_datasets.append(dataset)
if is_dreambooth:
info += indent(
dedent(
f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""
),
" ",
)
elif not is_controlnet:
info += indent(
dedent(
f"""\
metadata_file: {subset.metadata_file}
\n"""
),
" ",
)
def print_info(_datasets, dataset_type: str):
info = ""
for i, dataset in enumerate(_datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(f"""\
[{dataset_type} {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
""")
logger.info(f"{info}")
if dataset.enable_bucket:
info += indent(dedent(f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""), " ")
else:
info += "\n"
for j, subset in enumerate(dataset.subsets):
info += indent(dedent(f"""\
[Subset {j} of {dataset_type} {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
color_aug: {subset.color_aug}
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
alpha_mask: {subset.alpha_mask}
custom_attributes: {subset.custom_attributes}
"""), " ")
if is_dreambooth:
info += indent(dedent(f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""), " ")
elif not is_controlnet:
info += indent(dedent(f"""\
metadata_file: {subset.metadata_file}
\n"""), " ")
logger.info(info)
print_info(datasets, "Dataset")
if len(val_datasets) > 0:
print_info(val_datasets, "Validation Dataset")
# make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
for i, dataset in enumerate(datasets):
logger.info(f"[Dataset {i}]")
logger.info(f"[Prepare dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)
return DatasetGroup(datasets)
for i, dataset in enumerate(val_datasets):
logger.info(f"[Prepare validation dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)
return (
DatasetGroup(datasets),
DatasetGroup(val_datasets) if val_datasets else None
)
def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):

View File

@@ -1,7 +1,9 @@
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
import torch
import argparse
import random
import re
from torch.types import Number
from typing import List, Optional, Union
from .utils import setup_logging
@@ -63,7 +65,7 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
noise_scheduler.alphas_cumprod = alphas_cumprod
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
if v_prediction:
@@ -74,13 +76,13 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
return loss
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
scale = get_snr_scale(timesteps, noise_scheduler)
loss = loss * scale
return loss
def get_snr_scale(timesteps, noise_scheduler):
def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
scale = snr_t / (snr_t + 1)
@@ -89,14 +91,14 @@ def get_snr_scale(timesteps, noise_scheduler):
return scale
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor):
scale = get_snr_scale(timesteps, noise_scheduler)
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
loss = loss + loss / scale * v_pred_like_loss
return loss
def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
if v_prediction:
@@ -453,7 +455,7 @@ def get_weighted_text_embeddings(
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
def pyramid_noise_like(noise, device, iterations=6, discount=0.4) -> torch.FloatTensor:
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
for i in range(iterations):
@@ -466,7 +468,7 @@ def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> torch.FloatTensor:
if noise_offset is None:
return noise
if adaptive_noise_scale is not None:
@@ -482,7 +484,7 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
return noise
def apply_masked_loss(loss, batch):
def apply_masked_loss(loss, batch) -> torch.FloatTensor:
if "conditioning_images" in batch:
# conditioning image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel

View File

@@ -40,7 +40,7 @@ class SdTokenizeStrategy(TokenizeStrategy):
text = [text] if isinstance(text, str) else text
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
text = [text] if isinstance(text, str) else text
tokens_list = []
weights_list = []

View File

@@ -12,6 +12,7 @@ import pathlib
import re
import shutil
import time
import typing
from typing import (
Any,
Callable,
@@ -145,6 +146,37 @@ IMAGE_TRANSFORMS = transforms.Compose(
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
def split_train_val(
paths: List[str],
is_training_dataset: bool,
validation_split: float,
validation_seed: int | None
) -> List[str]:
"""
Split the dataset into train and validation
Shuffle the dataset based on the validation_seed or the current random seed.
For example if the split of 0.2 of 100 images.
[0:80] = 80 training images
[80:] = 20 validation images
"""
if validation_seed is not None:
logging.info(f"Using validation seed: {validation_seed}")
prevstate = random.getstate()
random.seed(validation_seed)
random.shuffle(paths)
random.setstate(prevstate)
else:
random.shuffle(paths)
# Split the dataset between training and validation
if is_training_dataset:
# Training dataset we split to the first part
return paths[0:math.ceil(len(paths) * (1 - validation_split))]
else:
# Validation dataset we split to the second part
return paths[len(paths) - round(len(paths) * validation_split):]
class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
@@ -397,6 +429,8 @@ class BaseSubset:
token_warmup_min: int,
token_warmup_step: Union[float, int],
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
) -> None:
self.image_dir = image_dir
self.alpha_mask = alpha_mask if alpha_mask is not None else False
@@ -424,6 +458,9 @@ class BaseSubset:
self.img_count = 0
self.validation_seed = validation_seed
self.validation_split = validation_split
class DreamBoothSubset(BaseSubset):
def __init__(
@@ -453,6 +490,8 @@ class DreamBoothSubset(BaseSubset):
token_warmup_min,
token_warmup_step,
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -478,6 +517,8 @@ class DreamBoothSubset(BaseSubset):
token_warmup_min,
token_warmup_step,
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
)
self.is_reg = is_reg
@@ -518,6 +559,8 @@ class FineTuningSubset(BaseSubset):
token_warmup_min,
token_warmup_step,
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
@@ -543,6 +586,8 @@ class FineTuningSubset(BaseSubset):
token_warmup_min,
token_warmup_step,
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
)
self.metadata_file = metadata_file
@@ -579,6 +624,8 @@ class ControlNetSubset(BaseSubset):
token_warmup_min,
token_warmup_step,
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -604,6 +651,8 @@ class ControlNetSubset(BaseSubset):
token_warmup_min,
token_warmup_step,
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
)
self.conditioning_data_dir = conditioning_data_dir
@@ -1786,9 +1835,13 @@ class BaseDataset(torch.utils.data.Dataset):
class DreamBoothDataset(BaseDataset):
IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
# The is_training_dataset defines the type of dataset, training or validation
# if is_training_dataset is True -> training dataset
# if is_training_dataset is False -> validation dataset
def __init__(
self,
subsets: Sequence[DreamBoothSubset],
is_training_dataset: bool,
batch_size: int,
resolution,
network_multiplier: float,
@@ -1799,6 +1852,8 @@ class DreamBoothDataset(BaseDataset):
bucket_no_upscale: bool,
prior_loss_weight: float,
debug_dataset: bool,
validation_split: float,
validation_seed: Optional[int],
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
@@ -1808,6 +1863,9 @@ class DreamBoothDataset(BaseDataset):
self.size = min(self.width, self.height) # 短いほう
self.prior_loss_weight = prior_loss_weight
self.latents_cache = None
self.is_training_dataset = is_training_dataset
self.validation_seed = validation_seed
self.validation_split = validation_split
self.enable_bucket = enable_bucket
if self.enable_bucket:
@@ -1915,6 +1973,30 @@ class DreamBoothDataset(BaseDataset):
size_set_count += 1
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
# We want to create a training and validation split. This should be improved in the future
# to allow a clearer distinction between training and validation. This can be seen as a
# short-term solution to limit what is necessary to implement validation datasets
#
# We split the dataset for the subset based on if we are doing a validation split
# The self.is_training_dataset defines the type of dataset, training or validation
# if self.is_training_dataset is True -> training dataset
# if self.is_training_dataset is False -> validation dataset
if self.validation_split > 0.0:
# For regularization images we do not want to split this dataset.
if subset.is_reg is True:
# Skip any validation dataset for regularization images
if self.is_training_dataset is False:
img_paths = []
# Otherwise the img_paths remain as original img_paths and no split
# required for training images dataset of regularization images
else:
img_paths = split_train_val(
img_paths,
self.is_training_dataset,
self.validation_split,
self.validation_seed
)
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
if use_cached_info_for_subset:
@@ -1973,9 +2055,10 @@ class DreamBoothDataset(BaseDataset):
num_reg_images = 0
reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
for subset in subsets:
if subset.num_repeats < 1:
num_repeats = subset.num_repeats if self.is_training_dataset else 1
if num_repeats < 1:
logger.warning(
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {num_repeats}"
)
continue
@@ -1993,12 +2076,12 @@ class DreamBoothDataset(BaseDataset):
continue
if subset.is_reg:
num_reg_images += subset.num_repeats * len(img_paths)
num_reg_images += num_repeats * len(img_paths)
else:
num_train_images += subset.num_repeats * len(img_paths)
num_train_images += num_repeats * len(img_paths)
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
if size is not None:
info.image_size = size
if subset.is_reg:
@@ -2009,10 +2092,12 @@ class DreamBoothDataset(BaseDataset):
subset.img_count = len(img_paths)
self.subsets.append(subset)
logger.info(f"{num_train_images} train images with repeating.")
images_split_name = "train" if self.is_training_dataset else "validation"
logger.info(f"{num_train_images} {images_split_name} images with repeats.")
self.num_train_images = num_train_images
logger.info(f"{num_reg_images} reg images.")
logger.info(f"{num_reg_images} reg images with repeats.")
if num_train_images < num_reg_images:
logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
@@ -2050,6 +2135,8 @@ class FineTuningDataset(BaseDataset):
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset: bool,
validation_seed: int,
validation_split: float,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
@@ -2275,7 +2362,9 @@ class ControlNetDataset(BaseDataset):
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset: float,
debug_dataset: bool,
validation_split: float,
validation_seed: Optional[int],
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
@@ -2324,13 +2413,17 @@ class ControlNetDataset(BaseDataset):
bucket_no_upscale,
1.0,
debug_dataset,
validation_split,
validation_seed,
)
# config_util等から参照される値をいれておく若干微妙なのでなんとかしたい
self.image_data = self.dreambooth_dataset_delegate.image_data
self.batch_size = batch_size
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
self.validation_split = validation_split
self.validation_seed = validation_seed
# assert all conditioning data exists
missing_imgs = []
@@ -2800,6 +2893,9 @@ class MinimalDataset(BaseDataset):
"""
raise NotImplementedError
def get_resolutions(self) -> List[Tuple[int, int]]:
return []
def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset:
module = ".".join(args.dataset_class.split(".")[:-1])
@@ -4544,7 +4640,6 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
config_args = argparse.Namespace(**ignore_nesting_dict)
args = parser.parse_args(namespace=config_args)
args.config_file = os.path.splitext(args.config_file)[0]
logger.info(args.config_file)
return args
@@ -4887,7 +4982,7 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
import schedulefree as sf
except ImportError:
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
if optimizer_type == "RAdamScheduleFree".lower():
optimizer_class = sf.RAdamScheduleFree
logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}")
@@ -5838,13 +5933,13 @@ def save_sd_model_on_train_end_common(
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
def get_timesteps(min_timestep, max_timestep, b_size, device):
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
timesteps = timesteps.long().to(device)
return timesteps
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
@@ -5905,11 +6000,16 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler
def conditional_loss(
model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None
):
"""
NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
"""
if loss_type == "l2":
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
elif loss_type == "l1":
loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
elif loss_type == "huber":
if huber_c is None:
raise NotImplementedError("huber_c not implemented correctly")
huber_c = huber_c.view(-1, 1, 1, 1)
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
@@ -5917,6 +6017,8 @@ def conditional_loss(
elif reduction == "sum":
loss = torch.sum(loss)
elif loss_type == "smooth_l1":
if huber_c is None:
raise NotImplementedError("huber_c not implemented correctly")
huber_c = huber_c.view(-1, 1, 1, 1)
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
@@ -6329,6 +6431,30 @@ def sample_image_inference(
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
    """
    Initialize the accelerator's experiment trackers, with tracker specific behaviors.

    Nothing happens unless this is the main process. The tracker name is
    `args.log_tracker_name` when set, otherwise `default_tracker_name`. The init
    kwargs start with the wandb run name (if `args.wandb_run_name` is set) and are
    replaced wholesale by the contents of `args.log_tracker_config` when that file
    is given. When a wandb tracker is active, the "epoch" and "val_step" metrics
    are defined as hidden.
    """
    if not accelerator.is_main_process:
        return

    tracker_kwargs = {}
    if args.wandb_run_name:
        tracker_kwargs["wandb"] = {"name": args.wandb_run_name}
    if args.log_tracker_config is not None:
        # The config file replaces whatever was collected above, run name included.
        tracker_kwargs = toml.load(args.log_tracker_config)

    tracker_name = args.log_tracker_name if args.log_tracker_name is not None else default_tracker_name
    accelerator.init_trackers(
        tracker_name,
        config=get_sanitized_config_or_none(args),
        init_kwargs=tracker_kwargs,
    )

    active_tracker_names = [tracker.name for tracker in accelerator.trackers]
    if "wandb" not in active_tracker_names:
        return

    import wandb

    wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)

    # Define specific metrics to handle validation and epochs "steps"
    for metric_name in ("epoch", "val_step"):
        wandb_tracker.define_metric(metric_name, hidden=True)
# endregion
@@ -6397,4 +6523,7 @@ class LossRecorder:
@property
def moving_average(self) -> float:
return self.loss_total / len(self.loss_list)
losses = len(self.loss_list)
if losses == 0:
return 0
return self.loss_total / losses

View File

@@ -20,6 +20,7 @@ voluptuous==0.13.1
huggingface-hub==0.24.5
# for Image utils
imagesize==1.4.1
numpy<=2.0
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12

View File

@@ -149,9 +149,10 @@ def train(args):
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
current_epoch = Value("i", 0)
current_step = Value("i", 0)

View File

@@ -2,7 +2,7 @@ import argparse
import copy
import math
import random
from typing import Any, Optional
from typing import Any, Optional, Union
import torch
from accelerate import Accelerator
@@ -26,7 +26,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
super().__init__()
self.sample_prompts_te_outputs = None
def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup):
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
# super().assert_extra_args(args, train_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)
@@ -56,9 +56,14 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
# enumerate resolutions from dataset for positional embeddings
self.resolutions = train_dataset_group.get_resolutions()
resolutions = train_dataset_group.get_resolutions()
if val_dataset_group is not None:
resolutions = resolutions + val_dataset_group.get_resolutions()
self.resolutions = resolutions
def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models
@@ -312,6 +317,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
network,
weight_dtype,
train_unet,
is_train=True
):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
@@ -339,7 +345,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
t5_attn_mask = None
# call model
with accelerator.autocast():
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
# TODO support attention mask
model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)

View File

@@ -176,9 +176,10 @@ def train(args):
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
current_epoch = Value("i", 0)
current_step = Value("i", 0)

View File

@@ -114,7 +114,7 @@ def train(args):
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)

View File

@@ -123,7 +123,7 @@ def train(args):
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)

View File

@@ -103,7 +103,7 @@ def train(args):
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)

View File

@@ -1,5 +1,5 @@
import argparse
from typing import List, Optional
from typing import List, Optional, Union
import torch
from accelerate import Accelerator
@@ -23,8 +23,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args)
if args.cache_text_encoder_outputs:
@@ -37,6 +37,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
train_dataset_group.verify_bucket_reso_steps(32)
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32)
def load_target_model(self, args, weight_dtype, accelerator):
(

View File

@@ -18,11 +18,12 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
def assert_extra_args(
    self,
    args,
    train_dataset_group: "Union[train_util.DatasetGroup, train_util.MinimalDataset]",
    val_dataset_group: "Optional[train_util.DatasetGroup]",
):
    """Validate SDXL textual-inversion arguments and dataset bucket settings.

    Runs the parent trainer's checks, then the SDXL argument verification
    (with text-encoder caching disallowed), and finally requires bucket
    resolution steps of 32 on the training dataset and, when one is
    supplied, on the validation dataset.

    Args:
        args: Parsed training arguments; may be modified by the checks.
        train_dataset_group: Dataset group used for training.
        val_dataset_group: Optional validation dataset group; ``None`` when
            no validation data is configured.
    """
    # NOTE: the annotations are quoted so they are not evaluated at
    # definition time; the visible change to this file adds no `typing`
    # import, so unquoted names could raise NameError when the class loads.
    super().assert_extra_args(args, train_dataset_group, val_dataset_group)
    sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)

    train_dataset_group.verify_bucket_reso_steps(32)
    # The validation group is optional (callers pass None for arbitrary
    # datasets), so it must be guarded like in the other trainers.
    if val_dataset_group is not None:
        val_dataset_group.verify_bucket_reso_steps(32)
def load_target_model(self, args, weight_dtype, accelerator):
(

View File

@@ -116,10 +116,11 @@ def cache_to_disk(args: argparse.Namespace) -> None:
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
# acceleratorを準備する
logger.info("prepare accelerator")

View File

@@ -103,10 +103,11 @@ def cache_to_disk(args: argparse.Namespace) -> None:
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
# acceleratorを準備する
logger.info("prepare accelerator")

View File

@@ -100,7 +100,7 @@ def train(args):
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)

View File

@@ -89,9 +89,10 @@ def train(args):
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
current_epoch = Value("i", 0)
current_step = Value("i", 0)

View File

@@ -2,17 +2,19 @@ import importlib
import argparse
import math
import os
import typing
from typing import Any, List, Union, Optional
import sys
import random
import time
import json
from multiprocessing import Value
from typing import Any, List
import toml
from tqdm import tqdm
import torch
from torch.types import Number
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
@@ -20,6 +22,7 @@ init_ipex()
from accelerate.utils import set_seed
from accelerate import Accelerator
from diffusers import DDPMScheduler
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
import library.train_util as train_util
@@ -114,15 +117,17 @@ class NetworkTrainer:
)
if (
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
):
):
logs[f"lr/d*lr/group{i}"] = (
optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
)
return logs
def assert_extra_args(self, args, train_dataset_group):
def assert_extra_args(
    self,
    args,
    train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
    val_dataset_group: Optional[train_util.DatasetGroup],
):
    """Check that the datasets use bucket resolution steps of 64.

    The training group is always verified; the validation group is
    verified only when one was provided.
    """
    bucket_reso_steps = 64
    train_dataset_group.verify_bucket_reso_steps(bucket_reso_steps)
    if val_dataset_group is None:
        return
    val_dataset_group.verify_bucket_reso_steps(bucket_reso_steps)
def load_target_model(self, args, weight_dtype, accelerator):
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -196,10 +201,10 @@ class NetworkTrainer:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
return noise_scheduler
def encode_images_to_latents(self, args, accelerator, vae, images):
def encode_images_to_latents(self, args, vae: AutoencoderKL, images: torch.FloatTensor) -> torch.FloatTensor:
    """Encode ``images`` with ``vae`` and return a sample drawn from the latent distribution."""
    latent_dist = vae.encode(images).latent_dist
    return latent_dist.sample()
def shift_scale_latents(self, args, latents):
def shift_scale_latents(self, args, latents: torch.FloatTensor) -> torch.FloatTensor:
    """Return ``latents`` multiplied by this trainer's ``vae_scale_factor``."""
    scaled_latents = latents * self.vae_scale_factor
    return scaled_latents
def get_noise_pred_and_target(
@@ -214,6 +219,7 @@ class NetworkTrainer:
network,
weight_dtype,
train_unet,
is_train=True
):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
@@ -227,7 +233,7 @@ class NetworkTrainer:
t.requires_grad_(True)
# Predict the noise residual
with accelerator.autocast():
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
noise_pred = self.call_unet(
args,
accelerator,
@@ -271,7 +277,7 @@ class NetworkTrainer:
return noise_pred, target, timesteps, None
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
@@ -308,6 +314,107 @@ class NetworkTrainer:
# endregion
def process_batch(
self,
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy: strategy_base.TextEncodingStrategy,
tokenize_strategy: strategy_base.TokenizeStrategy,
is_train=True,
train_text_encoder=True,
train_unet=True
) -> torch.Tensor:
"""
Process a batch for the network and return its mean loss.

The same routine serves training and validation: ``is_train`` is combined
with ``train_text_encoder`` to decide whether gradients are enabled while
encoding text, and is forwarded to ``get_noise_pred_and_target``.

Batch keys read here: ``latents`` or ``images``, ``text_encoder_outputs_list``,
``captions`` (weighted captions) or ``input_ids_list``, ``alpha_masks`` and
``loss_weights``.

Returns the loss after per-sample weighting and ``post_process_loss``,
reduced with ``mean()``.
"""
# Latents are obtained without gradient tracking.
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
# Latents supplied by the batch: only move them to the accelerator device.
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
else:
# Convert the images to latents
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
# If the latents contain NaN, print a warning and replace the NaNs with zeros
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents))
latents = self.shift_scale_latents(args, latents)
# Start from the text encoder outputs carried by the batch, if any.
text_encoder_conds = []
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
# Encode the captions when nothing usable is cached or the text encoder is being trained.
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
# Gradients are enabled only when this is a training pass AND the text encoder is trained.
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
input_ids_list,
weights_list,
)
else:
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
input_ids,
)
if args.full_fp16:
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
# sample noise, call unet, get target
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet,
network,
weight_dtype,
train_unet,
is_train=is_train
)
# Loss is computed in float32 with reduction "none"; it is reduced further below.
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
# Average over dims 1-3, leaving one loss value per sample.
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # weight for each sample
loss = loss * loss_weights
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
return loss.mean()
def train(self, args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
@@ -373,10 +480,11 @@ class NetworkTrainer:
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None # placeholder until validation dataset supported for arbitrary
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -384,8 +492,12 @@ class NetworkTrainer:
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
if args.debug_dataset:
train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly
train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly
train_util.debug_dataset(train_dataset_group)
if val_dataset_group is not None:
val_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly
train_util.debug_dataset(val_dataset_group)
return
if len(train_dataset_group) == 0:
logger.error(
@@ -397,8 +509,12 @@ class NetworkTrainer:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
if val_dataset_group is not None:
assert (
val_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
self.assert_extra_args(args, train_dataset_group) # may change some args
self.assert_extra_args(args, train_dataset_group, val_dataset_group) # may change some args
# acceleratorを準備する
logger.info("preparing accelerator")
@@ -444,6 +560,8 @@ class NetworkTrainer:
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
if val_dataset_group is not None:
val_dataset_group.new_cache_latents(vae, accelerator)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -459,6 +577,8 @@ class NetworkTrainer:
if text_encoder_outputs_caching_strategy is not None:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy)
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype)
if val_dataset_group is not None:
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype)
# prepare network
net_kwargs = {}
@@ -567,6 +687,8 @@ class NetworkTrainer:
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
if val_dataset_group is not None:
val_dataset_group.set_current_strategies()
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
@@ -579,6 +701,15 @@ class NetworkTrainer:
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset_group if val_dataset_group is not None else [],
shuffle=False,
batch_size=1,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
@@ -654,8 +785,8 @@ class NetworkTrainer:
text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None,
network=network,
)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
ds_model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, val_dataloader, lr_scheduler
)
training_model = ds_model
else:
@@ -676,8 +807,8 @@ class NetworkTrainer:
else:
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
network, optimizer, train_dataloader, lr_scheduler
network, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
network, optimizer, train_dataloader, val_dataloader, lr_scheduler
)
training_model = network
@@ -769,6 +900,7 @@ class NetworkTrainer:
accelerator.print("running training / 学習開始")
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}")
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
@@ -788,6 +920,7 @@ class NetworkTrainer:
"ss_text_encoder_lr": text_encoder_lr,
"ss_unet_lr": args.unet_lr,
"ss_num_train_images": train_dataset_group.num_train_images,
"ss_num_validation_images": val_dataset_group.num_train_images if val_dataset_group is not None else 0,
"ss_num_reg_images": train_dataset_group.num_reg_images,
"ss_num_batches_per_epoch": len(train_dataloader),
"ss_num_epochs": num_train_epochs,
@@ -835,6 +968,11 @@ class NetworkTrainer:
"ss_huber_c": args.huber_c,
"ss_fp8_base": bool(args.fp8_base),
"ss_fp8_base_unet": bool(args.fp8_base_unet),
"ss_validation_seed": args.validation_seed,
"ss_validation_split": args.validation_split,
"ss_max_validation_steps": args.max_validation_steps,
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps,
}
self.update_metadata(metadata, args) # architecture specific metadata
@@ -1051,20 +1189,15 @@ class NetworkTrainer:
noise_scheduler = self.get_noise_scheduler(args, accelerator.device)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"network_train" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
train_util.init_trackers(accelerator, args, "network_train")
loss_recorder = train_util.LossRecorder()
val_step_loss_recorder = train_util.LossRecorder()
val_epoch_loss_recorder = train_util.LossRecorder()
del train_dataset_group
if val_dataset_group is not None:
del val_dataset_group
# callback for step start
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
@@ -1109,10 +1242,17 @@ class NetworkTrainer:
optimizer_eval_fn()
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
optimizer_train_fn()
if len(accelerator.trackers) > 0:
is_tracking = len(accelerator.trackers) > 0
if is_tracking:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
validation_steps = (
min(args.max_validation_steps, len(val_dataloader))
if args.max_validation_steps is not None
else len(val_dataloader)
)
# training loop
if initial_step > 0: # only if skip_until_initial_step is specified
for skip_epoch in range(epoch_to_start): # skip epochs
@@ -1132,13 +1272,14 @@ class NetworkTrainer:
clean_memory_on_device(accelerator.device)
for epoch in range(epoch_to_start, num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
current_epoch.value = epoch + 1
metadata["ss_epoch"] = str(epoch + 1)
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
# TRAINING
skipped_dataloader = None
if initial_step > 0:
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
@@ -1156,98 +1297,24 @@ class NetworkTrainer:
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
with torch.no_grad():
# latentに変換
latents = self.encode_images_to_latents(args, accelerator, vae, batch["images"].to(vae_dtype))
latents = latents.to(dtype=weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
latents = self.shift_scale_latents(args, latents)
# get multiplier for each sample
if network_has_multiplier:
multipliers = batch["network_multipliers"]
# if all multipliers are same, use single multiplier
if torch.all(multipliers == multipliers[0]):
multipliers = multipliers[0].item()
else:
raise NotImplementedError("multipliers for each sample is not supported yet")
# print(f"set multiplier: {multipliers}")
accelerator.unwrap_model(network).set_multiplier(multipliers)
text_encoder_conds = []
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
input_ids_list,
weights_list,
)
else:
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
input_ids,
)
if args.full_fp16:
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
# sample noise, call unet, get target
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet,
network,
weight_dtype,
train_unet,
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=True,
train_text_encoder=train_text_encoder,
train_unet=train_unet
)
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
# min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients:
self.all_reduce_network(accelerator, network) # sync DDP grad manually
@@ -1302,19 +1369,148 @@ class NetworkTrainer:
if args.scale_weight_norms:
progress_bar.set_postfix(**{**max_mean_logs, **logs})
if len(accelerator.trackers) > 0:
if is_tracking:
logs = self.generate_step_logs(
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
args,
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
optimizer,
keys_scaled,
mean_norm,
maximum_norm
)
accelerator.log(logs, step=global_step)
# VALIDATION PER STEP
should_validate_step = (
args.validate_every_n_steps is not None
and global_step != 0 # Skip first step
and global_step % args.validate_every_n_steps == 0
)
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
val_progress_bar = tqdm(
range(validation_steps), smoothing=0,
disable=not accelerator.is_local_main_process,
desc="validation steps"
)
for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps:
break
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False,
train_text_encoder=False,
train_unet=False
)
current_loss = loss.detach().item()
val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
val_progress_bar.update(1)
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
if is_tracking:
logs = {
"loss/validation/step_current": current_loss,
"val_step": (epoch * validation_steps) + val_step,
}
accelerator.log(logs, step=global_step)
if is_tracking:
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
logs = {
"loss/validation/step_average": val_step_loss_recorder.moving_average,
"loss/validation/step_divergence": loss_validation_divergence,
}
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
# EPOCH VALIDATION
should_validate_epoch = (
(epoch + 1) % args.validate_every_n_epochs == 0
if args.validate_every_n_epochs is not None
else True
)
if should_validate_epoch and len(val_dataloader) > 0:
val_progress_bar = tqdm(
range(validation_steps), smoothing=0,
disable=not accelerator.is_local_main_process,
desc="epoch validation steps"
)
for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps:
break
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False,
train_text_encoder=False,
train_unet=False
)
current_loss = loss.detach().item()
val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
val_progress_bar.update(1)
val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average })
if is_tracking:
logs = {
"loss/validation/epoch_current": current_loss,
"epoch": epoch + 1,
"val_step": (epoch * validation_steps) + val_step
}
accelerator.log(logs, step=global_step)
if is_tracking:
avr_loss: float = val_epoch_loss_recorder.moving_average
loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss
logs = {
"loss/validation/epoch_average": avr_loss,
"loss/validation/epoch_divergence": loss_validation_divergence,
"epoch": epoch + 1
}
accelerator.log(logs, step=global_step)
# END OF EPOCH
if is_tracking:
logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1}
accelerator.log(logs, step=global_step)
accelerator.wait_for_everyone()
# 指定エポックごとにモデルを保存
@@ -1496,9 +1692,36 @@ def setup_parser() -> argparse.ArgumentParser:
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ未指定時と同じ。initial_epochを上書きする",
)
# parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
# parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
# parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
parser.add_argument(
"--validation_seed",
type=int,
default=None,
help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する"
)
parser.add_argument(
"--validation_split",
type=float,
default=0.0,
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合"
)
parser.add_argument(
"--validate_every_n_steps",
type=int,
default=None,
help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます"
)
parser.add_argument(
"--validate_every_n_epochs",
type=int,
default=None,
help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます"
)
parser.add_argument(
"--max_validation_steps",
type=int,
default=None,
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します"
)
return parser

View File

@@ -2,7 +2,7 @@ import argparse
import math
import os
from multiprocessing import Value
from typing import Any, List
from typing import Any, List, Optional, Union
import toml
from tqdm import tqdm
@@ -99,9 +99,12 @@ class TextualInversionTrainer:
self.vae_scale_factor = 0.18215
self.is_sdxl = False
def assert_extra_args(self, args, train_dataset_group):
def assert_extra_args(
    self,
    args,
    train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
    val_dataset_group: Optional[train_util.DatasetGroup],
):
    """Require bucket resolution steps of 64 on the training dataset and, if present, the validation dataset."""
    required_steps = 64
    train_dataset_group.verify_bucket_reso_steps(required_steps)
    has_validation = val_dataset_group is not None
    if has_validation:
        val_dataset_group.verify_bucket_reso_steps(required_steps)
def load_target_model(self, args, weight_dtype, accelerator):
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet
@@ -320,11 +323,12 @@ class TextualInversionTrainer:
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
self.assert_extra_args(args, train_dataset_group)
self.assert_extra_args(args, train_dataset_group, val_dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)

View File

@@ -239,7 +239,7 @@ def train(args):
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
current_epoch = Value("i", 0)
current_step = Value("i", 0)