mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Fix training, validation split, revert to using upstream implemenation
This commit is contained in:
@@ -482,7 +482,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
dataset_klass = FineTuningDataset
|
||||
|
||||
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
|
||||
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
|
||||
dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params))
|
||||
datasets.append(dataset)
|
||||
|
||||
val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
|
||||
@@ -500,16 +500,16 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
dataset_klass = FineTuningDataset
|
||||
|
||||
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
|
||||
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
|
||||
dataset = dataset_klass(subsets=subsets, is_training_dataset=False, **asdict(dataset_blueprint.params))
|
||||
val_datasets.append(dataset)
|
||||
|
||||
def print_info(_datasets):
|
||||
def print_info(_datasets, dataset_type: str):
|
||||
info = ""
|
||||
for i, dataset in enumerate(_datasets):
|
||||
is_dreambooth = isinstance(dataset, DreamBoothDataset)
|
||||
is_controlnet = isinstance(dataset, ControlNetDataset)
|
||||
info += dedent(f"""\
|
||||
[Dataset {i}]
|
||||
[{dataset_type} {i}]
|
||||
batch_size: {dataset.batch_size}
|
||||
resolution: {(dataset.width, dataset.height)}
|
||||
enable_bucket: {dataset.enable_bucket}
|
||||
@@ -527,7 +527,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
|
||||
for j, subset in enumerate(dataset.subsets):
|
||||
info += indent(dedent(f"""\
|
||||
[Subset {j} of Dataset {i}]
|
||||
[Subset {j} of {dataset_type} {i}]
|
||||
image_dir: "{subset.image_dir}"
|
||||
image_count: {subset.img_count}
|
||||
num_repeats: {subset.num_repeats}
|
||||
@@ -544,8 +544,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
random_crop: {subset.random_crop}
|
||||
token_warmup_min: {subset.token_warmup_min},
|
||||
token_warmup_step: {subset.token_warmup_step},
|
||||
alpha_mask: {subset.alpha_mask}
|
||||
custom_attributes: {subset.custom_attributes}
|
||||
alpha_mask: {subset.alpha_mask}
|
||||
custom_attributes: {subset.custom_attributes}
|
||||
"""), " ")
|
||||
|
||||
if is_dreambooth:
|
||||
@@ -561,67 +561,22 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
|
||||
logger.info(info)
|
||||
|
||||
print_info(datasets)
|
||||
print_info(datasets, "Dataset")
|
||||
|
||||
if len(val_datasets) > 0:
|
||||
logger.info("Validation dataset")
|
||||
print_info(val_datasets)
|
||||
|
||||
if len(val_datasets) > 0:
|
||||
info = ""
|
||||
|
||||
for i, dataset in enumerate(val_datasets):
|
||||
info += dedent(
|
||||
f"""\
|
||||
[Validation Dataset {i}]
|
||||
batch_size: {dataset.batch_size}
|
||||
resolution: {(dataset.width, dataset.height)}
|
||||
enable_bucket: {dataset.enable_bucket}
|
||||
network_multiplier: {dataset.network_multiplier}
|
||||
"""
|
||||
)
|
||||
|
||||
if dataset.enable_bucket:
|
||||
info += indent(
|
||||
dedent(
|
||||
f"""\
|
||||
min_bucket_reso: {dataset.min_bucket_reso}
|
||||
max_bucket_reso: {dataset.max_bucket_reso}
|
||||
bucket_reso_steps: {dataset.bucket_reso_steps}
|
||||
bucket_no_upscale: {dataset.bucket_no_upscale}
|
||||
\n"""
|
||||
),
|
||||
" ",
|
||||
)
|
||||
else:
|
||||
info += "\n"
|
||||
|
||||
for j, subset in enumerate(dataset.subsets):
|
||||
info += indent(
|
||||
dedent(
|
||||
f"""\
|
||||
[Subset {j} of Validation Dataset {i}]
|
||||
image_dir: "{subset.image_dir}"
|
||||
image_count: {subset.img_count}
|
||||
num_repeats: {subset.num_repeats}
|
||||
"""
|
||||
),
|
||||
" ",
|
||||
)
|
||||
|
||||
logger.info(info)
|
||||
print_info(val_datasets, "Validation Dataset")
|
||||
|
||||
# make buckets first because it determines the length of dataset
|
||||
# and set the same seed for all datasets
|
||||
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
||||
|
||||
for i, dataset in enumerate(datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
logger.info(f"[Prepare dataset {i}]")
|
||||
dataset.make_buckets()
|
||||
dataset.set_seed(seed)
|
||||
|
||||
for i, dataset in enumerate(val_datasets):
|
||||
logger.info(f"[Validation Dataset {i}]")
|
||||
logger.info(f"[Prepare validation dataset {i}]")
|
||||
dataset.make_buckets()
|
||||
dataset.set_seed(seed)
|
||||
|
||||
|
||||
@@ -455,7 +455,7 @@ def get_weighted_text_embeddings(
|
||||
|
||||
|
||||
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
|
||||
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
|
||||
def pyramid_noise_like(noise, device, iterations=6, discount=0.4) -> torch.FloatTensor:
|
||||
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
|
||||
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
|
||||
for i in range(iterations):
|
||||
@@ -468,7 +468,7 @@ def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
|
||||
|
||||
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
||||
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> torch.FloatTensor:
|
||||
if noise_offset is None:
|
||||
return noise
|
||||
if adaptive_noise_scale is not None:
|
||||
@@ -484,7 +484,7 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
||||
return noise
|
||||
|
||||
|
||||
def apply_masked_loss(loss, batch):
|
||||
def apply_masked_loss(loss, batch) -> torch.FloatTensor:
|
||||
if "conditioning_images" in batch:
|
||||
# conditioning image is -1 to 1. we need to convert it to 0 to 1
|
||||
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
|
||||
|
||||
@@ -40,7 +40,7 @@ class SdTokenizeStrategy(TokenizeStrategy):
|
||||
text = [text] if isinstance(text, str) else text
|
||||
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
|
||||
|
||||
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
|
||||
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
text = [text] if isinstance(text, str) else text
|
||||
tokens_list = []
|
||||
weights_list = []
|
||||
|
||||
@@ -146,7 +146,15 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
|
||||
|
||||
def split_train_val(paths: List[str], is_train: bool, validation_split: float, validation_seed: int) -> List[str]:
|
||||
def split_train_val(paths: List[str], is_training_dataset: bool, validation_split: float, validation_seed: int) -> List[str]:
|
||||
"""
|
||||
Split the dataset into train and validation
|
||||
|
||||
Shuffle the dataset based on the validation_seed or the current random seed.
|
||||
For example if the split of 0.2 of 100 images.
|
||||
[0:79] = 80 training images
|
||||
[80:] = 20 validation images
|
||||
"""
|
||||
if validation_seed is not None:
|
||||
print(f"Using validation seed: {validation_seed}")
|
||||
prevstate = random.getstate()
|
||||
@@ -156,9 +164,12 @@ def split_train_val(paths: List[str], is_train: bool, validation_split: float, v
|
||||
else:
|
||||
random.shuffle(paths)
|
||||
|
||||
if is_train:
|
||||
# Split the dataset between training and validation
|
||||
if is_training_dataset:
|
||||
# Training dataset we split to the first part
|
||||
return paths[0:math.ceil(len(paths) * (1 - validation_split))]
|
||||
else:
|
||||
# Validation dataset we split to the second part
|
||||
return paths[len(paths) - round(len(paths) * validation_split):]
|
||||
|
||||
|
||||
@@ -1822,6 +1833,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
def __init__(
|
||||
self,
|
||||
subsets: Sequence[DreamBoothSubset],
|
||||
is_training_dataset: bool,
|
||||
batch_size: int,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
@@ -1843,6 +1855,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.size = min(self.width, self.height) # 短いほう
|
||||
self.prior_loss_weight = prior_loss_weight
|
||||
self.latents_cache = None
|
||||
self.is_training_dataset = is_training_dataset
|
||||
self.validation_seed = validation_seed
|
||||
self.validation_split = validation_split
|
||||
|
||||
@@ -1952,6 +1965,9 @@ class DreamBoothDataset(BaseDataset):
|
||||
size_set_count += 1
|
||||
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
|
||||
|
||||
if self.validation_split > 0.0:
|
||||
img_paths = split_train_val(img_paths, self.is_training_dataset, self.validation_split, self.validation_seed)
|
||||
|
||||
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
|
||||
if use_cached_info_for_subset:
|
||||
@@ -2046,7 +2062,8 @@ class DreamBoothDataset(BaseDataset):
|
||||
subset.img_count = len(img_paths)
|
||||
self.subsets.append(subset)
|
||||
|
||||
logger.info(f"{num_train_images} train images with repeats.")
|
||||
images_split_name = "train" if self.is_training_dataset else "validation"
|
||||
logger.info(f"{num_train_images} {images_split_name} images with repeats.")
|
||||
|
||||
self.num_train_images = num_train_images
|
||||
|
||||
@@ -2411,8 +2428,12 @@ class ControlNetDataset(BaseDataset):
|
||||
conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path
|
||||
extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair])
|
||||
|
||||
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
|
||||
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
|
||||
assert (
|
||||
len(missing_imgs) == 0
|
||||
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
|
||||
assert (
|
||||
len(extra_imgs) == 0
|
||||
), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
|
||||
|
||||
self.conditioning_image_transforms = IMAGE_TRANSFORMS
|
||||
|
||||
@@ -4586,7 +4607,6 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
|
||||
config_args = argparse.Namespace(**ignore_nesting_dict)
|
||||
args = parser.parse_args(namespace=config_args)
|
||||
args.config_file = os.path.splitext(args.config_file)[0]
|
||||
logger.info(args.config_file)
|
||||
|
||||
return args
|
||||
|
||||
@@ -5880,55 +5900,35 @@ def save_sd_model_on_train_end_common(
|
||||
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
||||
|
||||
|
||||
def get_random_timesteps(args, min_timestep: int, max_timestep: int, batch_size: int, device: torch.device) -> torch.IntTensor:
|
||||
"""
|
||||
Get a random timestep between the min and max timesteps
|
||||
Can error (NotImplementedError) if the loss type is not supported
|
||||
"""
|
||||
# TODO: if a huber loss is selected, it will use constant timesteps for each batch
|
||||
# as. In the future there may be a smarter way
|
||||
if args.loss_type == "huber" or args.loss_type == "smooth_l1":
|
||||
timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu")
|
||||
timesteps = timesteps.repeat(batch_size).to(device)
|
||||
elif args.loss_type == "l2":
|
||||
timesteps = torch.randint(min_timestep, max_timestep, (batch_size,), device=device)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown loss type {args.loss_type}")
|
||||
|
||||
return typing.cast(torch.IntTensor, timesteps)
|
||||
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.IntTensor:
|
||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
|
||||
return timesteps
|
||||
|
||||
|
||||
def get_huber_c(args, noise_scheduler: DDPMScheduler, timesteps: torch.IntTensor) -> Optional[float]:
|
||||
"""
|
||||
Calculate the Huber convolution (huber_c) value
|
||||
Huber loss is a loss function used in robust regression, that is less sensitive
|
||||
to outliers in data than the squared error loss.
|
||||
https://en.wikipedia.org/wiki/Huber_loss
|
||||
"""
|
||||
if args.loss_type == "huber" or args.loss_type == "smooth_l1":
|
||||
if args.huber_schedule == "exponential":
|
||||
alpha = -math.log(args.huber_c) / noise_scheduler.config.get('num_train_timesteps', 1000)
|
||||
huber_c = math.exp(-alpha * timesteps.item())
|
||||
elif args.huber_schedule == "snr":
|
||||
if not hasattr(noise_scheduler, "alphas_cumprod"):
|
||||
raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.")
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod.index_select(0, timesteps)
|
||||
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
|
||||
elif args.huber_schedule == "constant":
|
||||
huber_c = args.huber_c
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
|
||||
|
||||
elif args.loss_type == "l2":
|
||||
def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]:
|
||||
if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"):
|
||||
return None
|
||||
|
||||
b_size = timesteps.shape[0]
|
||||
if args.huber_schedule == "exponential":
|
||||
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
|
||||
result = torch.exp(-alpha * timesteps) * args.huber_scale
|
||||
elif args.huber_schedule == "snr":
|
||||
if not hasattr(noise_scheduler, "alphas_cumprod"):
|
||||
raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.")
|
||||
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
|
||||
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
|
||||
result = result.to(timesteps.device)
|
||||
elif args.huber_schedule == "constant":
|
||||
result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown loss type {args.loss_type}")
|
||||
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
|
||||
|
||||
return huber_c
|
||||
return result
|
||||
|
||||
|
||||
def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor):
|
||||
def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor) -> torch.FloatTensor:
|
||||
"""
|
||||
Apply noise modifications like noise offset and multires noise
|
||||
"""
|
||||
@@ -5964,27 +5964,44 @@ def make_random_timesteps(args, noise_scheduler: DDPMScheduler, batch_size: int,
|
||||
max_timestep = noise_scheduler.config.get('num_train_timesteps', 1000) if args.max_timestep is None else args.max_timestep
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = get_random_timesteps(args, min_timestep, max_timestep, batch_size, device)
|
||||
timesteps = get_timesteps(min_timestep, max_timestep, batch_size, device)
|
||||
|
||||
return timesteps
|
||||
|
||||
|
||||
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor, Optional[float]]:
|
||||
"""
|
||||
Unified noise, noisy_latents, timesteps and huber loss convolution calculations
|
||||
"""
|
||||
batch_size = latents.shape[0]
|
||||
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
if args.noise_offset_random_strength:
|
||||
noise_offset = torch.rand(1, device=latents.device) * args.noise_offset
|
||||
else:
|
||||
noise_offset = args.noise_offset
|
||||
noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale)
|
||||
if args.multires_noise_iterations:
|
||||
noise = custom_train_functions.pyramid_noise_like(
|
||||
noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount
|
||||
)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
b_size = latents.shape[0]
|
||||
min_timestep = 0 if args.min_timestep is None else args.min_timestep
|
||||
max_timestep = noise_scheduler.config.get("num_train_timesteps", 1000) if args.max_timestep is None else args.max_timestep
|
||||
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
|
||||
|
||||
# A random timestep for each image in the batch
|
||||
timesteps = get_random_timesteps(args, min_timestep, max_timestep, batch_size, latents.device)
|
||||
huber_c = get_huber_c(args, noise_scheduler, timesteps)
|
||||
timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)
|
||||
|
||||
noise = make_noise(args, latents)
|
||||
noisy_latents = get_noisy_latents(args, noise, noise_scheduler, latents, timesteps)
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
if args.ip_noise_gamma:
|
||||
if args.ip_noise_gamma_random_strength:
|
||||
strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma
|
||||
else:
|
||||
strength = args.ip_noise_gamma
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps)
|
||||
else:
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
return noise, noisy_latents, timesteps, huber_c
|
||||
return noise, noisy_latents, timesteps
|
||||
|
||||
|
||||
def get_noisy_latents(args, noise: torch.FloatTensor, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor:
|
||||
@@ -6015,6 +6032,8 @@ def conditional_loss(
|
||||
elif loss_type == "l1":
|
||||
loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
|
||||
elif loss_type == "huber":
|
||||
if huber_c is None:
|
||||
raise NotImplementedError("huber_c not implemented correctly")
|
||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
||||
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
if reduction == "mean":
|
||||
@@ -6022,6 +6041,8 @@ def conditional_loss(
|
||||
elif reduction == "sum":
|
||||
loss = torch.sum(loss)
|
||||
elif loss_type == "smooth_l1":
|
||||
if huber_c is None:
|
||||
raise NotImplementedError("huber_c not implemented correctly")
|
||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
||||
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
if reduction == "mean":
|
||||
|
||||
@@ -205,10 +205,10 @@ class NetworkTrainer:
|
||||
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
||||
return noise_scheduler
|
||||
|
||||
def encode_images_to_latents(self, args, accelerator, vae, images):
|
||||
def encode_images_to_latents(self, args, vae: AutoencoderKL, images: torch.FloatTensor) -> torch.FloatTensor:
|
||||
return vae.encode(images).latent_dist.sample()
|
||||
|
||||
def shift_scale_latents(self, args, latents):
|
||||
def shift_scale_latents(self, args, latents: torch.FloatTensor) -> torch.FloatTensor:
|
||||
return latents * self.vae_scale_factor
|
||||
|
||||
def get_noise_pred_and_target(
|
||||
@@ -280,7 +280,7 @@ class NetworkTrainer:
|
||||
|
||||
return noise_pred, target, timesteps, None
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
@@ -317,20 +317,21 @@ class NetworkTrainer:
|
||||
|
||||
# endregion
|
||||
|
||||
def process_batch(self, batch, tokenizers, text_encoders, unet, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor:
|
||||
def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor:
|
||||
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents: torch.Tensor = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
|
||||
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
|
||||
else:
|
||||
# latentに変換
|
||||
latents: torch.Tensor = typing.cast(torch.FloatTensor, typing.cast(AutoencoderKLOutput, vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype))).latent_dist.sample())
|
||||
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = typing.cast(torch.FloatTensor, torch.where(torch.isnan(latents), torch.zeros_like(latents), latents))
|
||||
latents = typing.cast(torch.FloatTensor, latents * self.vae_scale_factor)
|
||||
latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents))
|
||||
|
||||
latents = self.shift_scale_latents(args, latents)
|
||||
|
||||
|
||||
text_encoder_conds = []
|
||||
@@ -384,22 +385,36 @@ class NetworkTrainer:
|
||||
total_loss = torch.zeros((batch_size, 1)).to(latents.device)
|
||||
|
||||
# Use input timesteps_list or use described timesteps above
|
||||
for fixed_timestep in chosen_timesteps_list:
|
||||
fixed_timestep = typing.cast(torch.IntTensor, fixed_timestep)
|
||||
for fixed_timesteps in chosen_timesteps_list:
|
||||
fixed_timesteps = typing.cast(torch.IntTensor, fixed_timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
# and add noise to the latents
|
||||
# with noise offset and/or multires noise if specified
|
||||
noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timestep)
|
||||
noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timesteps)
|
||||
|
||||
# ensure the hidden state will require grad
|
||||
if args.gradient_checkpointing:
|
||||
for x in noisy_latents:
|
||||
x.requires_grad_(True)
|
||||
for t in text_encoder_conds:
|
||||
t.requires_grad_(True)
|
||||
|
||||
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
args, accelerator, unet, noisy_latents.requires_grad_(train_unet), fixed_timestep, text_encoder_conds, batch, weight_dtype
|
||||
args,
|
||||
accelerator,
|
||||
unet,
|
||||
noisy_latents.requires_grad_(train_unet),
|
||||
fixed_timesteps,
|
||||
text_encoder_conds,
|
||||
batch,
|
||||
weight_dtype,
|
||||
)
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, fixed_timestep)
|
||||
target = noise_scheduler.get_velocity(latents, noise, fixed_timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
@@ -418,7 +433,7 @@ class NetworkTrainer:
|
||||
accelerator,
|
||||
unet,
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
fixed_timesteps,
|
||||
text_encoder_conds,
|
||||
batch,
|
||||
weight_dtype,
|
||||
@@ -427,7 +442,8 @@ class NetworkTrainer:
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, fixed_timesteps, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
loss = loss.mean([1, 2, 3]) # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
@@ -436,14 +452,7 @@ class NetworkTrainer:
|
||||
loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, fixed_timestep, noise_scheduler, args.min_snr_gamma)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, fixed_timestep, noise_scheduler)
|
||||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, fixed_timestep, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, fixed_timestep, noise_scheduler)
|
||||
loss = self.post_process_loss(loss, args, fixed_timesteps, noise_scheduler)
|
||||
|
||||
total_loss += loss
|
||||
|
||||
@@ -526,8 +535,12 @@ class NetworkTrainer:
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly
|
||||
train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly
|
||||
train_util.debug_dataset(val_dataset_group)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
logger.error(
|
||||
@@ -753,10 +766,6 @@ class NetworkTrainer:
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# Not for sure here.
|
||||
# if val_dataset_group is not None:
|
||||
# val_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
@@ -1304,7 +1313,7 @@ class NetworkTrainer:
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
for epoch in range(epoch_to_start, num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
metadata["ss_epoch"] = str(epoch + 1)
|
||||
@@ -1324,7 +1333,7 @@ class NetworkTrainer:
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(training_model):
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet)
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
||||
@@ -1384,7 +1393,8 @@ class NetworkTrainer:
|
||||
logs = self.generate_step_logs(
|
||||
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
|
||||
)
|
||||
accelerator.log(logs, step=global_step)
|
||||
# accelerator.log(logs, step=global_step)
|
||||
accelerator.log(logs)
|
||||
|
||||
# VALIDATION PER STEP
|
||||
should_validate = (args.validation_every_n_step is not None
|
||||
@@ -1401,7 +1411,7 @@ class NetworkTrainer:
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
|
||||
|
||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item())
|
||||
val_progress_bar.update(1)
|
||||
@@ -1409,10 +1419,12 @@ class NetworkTrainer:
|
||||
|
||||
if is_tracking:
|
||||
logs = {"loss/current_val_loss": loss.detach().item()}
|
||||
accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
||||
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
||||
accelerator.log(logs)
|
||||
|
||||
logs = {"loss/average_val_loss": val_loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=global_step)
|
||||
# accelerator.log(logs, step=global_step)
|
||||
accelerator.log(logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
@@ -1427,7 +1439,7 @@ class NetworkTrainer:
|
||||
)
|
||||
|
||||
for val_step, batch in enumerate(val_dataloader):
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
@@ -1437,22 +1449,26 @@ class NetworkTrainer:
|
||||
if is_tracking:
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/validation_current": current_loss}
|
||||
accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
||||
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
||||
accelerator.log(logs)
|
||||
|
||||
if is_tracking:
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/validation_average": avr_loss}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
# accelerator.log(logs, step=epoch + 1)
|
||||
accelerator.log(logs)
|
||||
|
||||
# END OF EPOCH
|
||||
if is_tracking:
|
||||
logs = {"loss/epoch_average": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
# accelerator.log(logs, step=epoch + 1)
|
||||
accelerator.log(logs)
|
||||
|
||||
if len(val_dataloader) > 0 and is_tracking:
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/validation_epoch_average": avr_loss}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
# accelerator.log(logs, step=epoch + 1)
|
||||
accelerator.log(logs)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user