diff --git a/library/config_util.py b/library/config_util.py index 5a4d3aa2..63d28c96 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -482,7 +482,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params)) datasets.append(dataset) val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] @@ -500,16 +500,16 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_training_dataset=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) - def print_info(_datasets): + def print_info(_datasets, dataset_type: str): info = "" for i, dataset in enumerate(_datasets): is_dreambooth = isinstance(dataset, DreamBoothDataset) is_controlnet = isinstance(dataset, ControlNetDataset) info += dedent(f"""\ - [Dataset {i}] + [{dataset_type} {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} enable_bucket: {dataset.enable_bucket} @@ -527,7 +527,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu for j, subset in enumerate(dataset.subsets): info += indent(dedent(f"""\ - [Subset {j} of Dataset {i}] + [Subset {j} of {dataset_type} {i}] image_dir: "{subset.image_dir}" image_count: {subset.img_count} num_repeats: {subset.num_repeats} @@ -544,8 +544,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu random_crop: {subset.random_crop} token_warmup_min: {subset.token_warmup_min}, token_warmup_step: {subset.token_warmup_step}, - alpha_mask: {subset.alpha_mask} - custom_attributes: {subset.custom_attributes} + alpha_mask: {subset.alpha_mask} + custom_attributes: {subset.custom_attributes} """), " ") if is_dreambooth: @@ -561,67 +561,22 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu logger.info(info) - print_info(datasets) + print_info(datasets, "Dataset") if len(val_datasets) > 0: - logger.info("Validation dataset") - print_info(val_datasets) - - if len(val_datasets) > 0: - info = "" - - for i, dataset in enumerate(val_datasets): - info += dedent( - f"""\ - [Validation Dataset {i}] - batch_size: {dataset.batch_size} - resolution: {(dataset.width, dataset.height)} - enable_bucket: {dataset.enable_bucket} - network_multiplier: {dataset.network_multiplier} - """ - ) - - if dataset.enable_bucket: - info += indent( - dedent( - f"""\ - min_bucket_reso: {dataset.min_bucket_reso} - max_bucket_reso: {dataset.max_bucket_reso} - bucket_reso_steps: {dataset.bucket_reso_steps} - bucket_no_upscale: {dataset.bucket_no_upscale} - \n""" - ), - " ", - ) - else: - info += "\n" - - for j, subset in enumerate(dataset.subsets): - info += indent( - dedent( - f"""\ - [Subset {j} of Validation Dataset {i}] - image_dir: "{subset.image_dir}" - image_count: {subset.img_count} - num_repeats: {subset.num_repeats} - """ - ), - " ", - ) - - logger.info(info) + print_info(val_datasets, "Validation Dataset") # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): - logger.info(f"[Dataset {i}]") + logger.info(f"[Prepare dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) for i, dataset in enumerate(val_datasets): - logger.info(f"[Validation Dataset {i}]") + logger.info(f"[Prepare validation dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 9a7c21a3..ad3e69ff 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -455,7 +455,7 @@ def get_weighted_text_embeddings( # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2 -def pyramid_noise_like(noise, device, iterations=6, discount=0.4): +def pyramid_noise_like(noise, device, iterations=6, discount=0.4) -> torch.FloatTensor: b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant! u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device) for i in range(iterations): @@ -468,7 +468,7 @@ def pyramid_noise_like(noise, device, iterations=6, discount=0.4): # https://www.crosslabs.org//blog/diffusion-with-offset-noise -def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): +def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> torch.FloatTensor: if noise_offset is None: return noise if adaptive_noise_scale is not None: @@ -484,7 +484,7 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): return noise -def apply_masked_loss(loss, batch): +def apply_masked_loss(loss, batch) -> torch.FloatTensor: if "conditioning_images" in batch: # conditioning image is -1 to 1. we need to convert it to 0 to 1 mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel diff --git a/library/strategy_sd.py b/library/strategy_sd.py index d0a3a68b..a44fc409 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -40,7 +40,7 @@ class SdTokenizeStrategy(TokenizeStrategy): text = [text] if isinstance(text, str) else text return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] - def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: text = [text] if isinstance(text, str) else text tokens_list = [] weights_list = [] diff --git a/library/train_util.py b/library/train_util.py index 3710c865..0f16a4f3 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -146,7 +146,15 @@ IMAGE_TRANSFORMS = transforms.Compose( TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" -def split_train_val(paths: List[str], is_train: bool, validation_split: float, validation_seed: int) -> List[str]: +def split_train_val(paths: List[str], is_training_dataset: bool, validation_split: float, validation_seed: int) -> List[str]: + """ + Split the dataset into train and validation + + Shuffle the dataset based on the validation_seed or the current random seed. + For example if the split of 0.2 of 100 images. + [0:79] = 80 training images + [80:] = 20 validation images + """ if validation_seed is not None: print(f"Using validation seed: {validation_seed}") prevstate = random.getstate() @@ -156,9 +164,12 @@ def split_train_val(paths: List[str], is_train: bool, validation_split: float, v else: random.shuffle(paths) - if is_train: + # Split the dataset between training and validation + if is_training_dataset: + # Training dataset we split to the first part return paths[0:math.ceil(len(paths) * (1 - validation_split))] else: + # Validation dataset we split to the second part return paths[len(paths) - round(len(paths) * validation_split):] @@ -1822,6 +1833,7 @@ class DreamBoothDataset(BaseDataset): def __init__( self, subsets: Sequence[DreamBoothSubset], + is_training_dataset: bool, batch_size: int, resolution, network_multiplier: float, @@ -1843,6 +1855,7 @@ class DreamBoothDataset(BaseDataset): self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight self.latents_cache = None + self.is_training_dataset = is_training_dataset self.validation_seed = validation_seed self.validation_split = validation_split @@ -1952,6 +1965,9 @@ class DreamBoothDataset(BaseDataset): size_set_count += 1 logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + if self.validation_split > 0.0: + img_paths = split_train_val(img_paths, self.is_training_dataset, self.validation_split, self.validation_seed) + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") if use_cached_info_for_subset: @@ -2046,7 +2062,8 @@ class DreamBoothDataset(BaseDataset): subset.img_count = len(img_paths) self.subsets.append(subset) - logger.info(f"{num_train_images} train images with repeats.") + images_split_name = "train" if self.is_training_dataset else "validation" + logger.info(f"{num_train_images} {images_split_name} images with repeats.") self.num_train_images = num_train_images @@ -2411,8 +2428,12 @@ class ControlNetDataset(BaseDataset): conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair]) - assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" - assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" + assert ( + len(missing_imgs) == 0 + ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}" + assert ( + len(extra_imgs) == 0 + ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}" self.conditioning_image_transforms = IMAGE_TRANSFORMS @@ -4586,7 +4607,6 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar config_args = argparse.Namespace(**ignore_nesting_dict) args = parser.parse_args(namespace=config_args) args.config_file = os.path.splitext(args.config_file)[0] - logger.info(args.config_file) return args @@ -5880,55 +5900,35 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_random_timesteps(args, min_timestep: int, max_timestep: int, batch_size: int, device: torch.device) -> torch.IntTensor: - """ - Get a random timestep between the min and max timesteps - Can error (NotImplementedError) if the loss type is not supported - """ - # TODO: if a huber loss is selected, it will use constant timesteps for each batch - # as. In the future there may be a smarter way - if args.loss_type == "huber" or args.loss_type == "smooth_l1": - timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu") - timesteps = timesteps.repeat(batch_size).to(device) - elif args.loss_type == "l2": - timesteps = torch.randint(min_timestep, max_timestep, (batch_size,), device=device) - else: - raise NotImplementedError(f"Unknown loss type {args.loss_type}") - - return typing.cast(torch.IntTensor, timesteps) +def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.IntTensor: + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) + return timesteps -def get_huber_c(args, noise_scheduler: DDPMScheduler, timesteps: torch.IntTensor) -> Optional[float]: - """ - Calculate the Huber convolution (huber_c) value - Huber loss is a loss function used in robust regression, that is less sensitive - to outliers in data than the squared error loss. - https://en.wikipedia.org/wiki/Huber_loss - """ - if args.loss_type == "huber" or args.loss_type == "smooth_l1": - if args.huber_schedule == "exponential": - alpha = -math.log(args.huber_c) / noise_scheduler.config.get('num_train_timesteps', 1000) - huber_c = math.exp(-alpha * timesteps.item()) - elif args.huber_schedule == "snr": - if not hasattr(noise_scheduler, "alphas_cumprod"): - raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") - alphas_cumprod = noise_scheduler.alphas_cumprod.index_select(0, timesteps) - sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 - huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c - elif args.huber_schedule == "constant": - huber_c = args.huber_c - else: - raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - - elif args.loss_type == "l2": +def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: + if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): return None + + b_size = timesteps.shape[0] + if args.huber_schedule == "exponential": + alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps + result = torch.exp(-alpha * timesteps) * args.huber_scale + elif args.huber_schedule == "snr": + if not hasattr(noise_scheduler, "alphas_cumprod"): + raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) + sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 + result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c + result = result.to(timesteps.device) + elif args.huber_schedule == "constant": + result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device) else: - raise NotImplementedError(f"Unknown loss type {args.loss_type}") + raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - return huber_c + return result -def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor): +def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor) -> torch.FloatTensor: """ Apply noise modifications like noise offset and multires noise """ @@ -5964,27 +5964,44 @@ def make_random_timesteps(args, noise_scheduler: DDPMScheduler, batch_size: int, max_timestep = noise_scheduler.config.get('num_train_timesteps', 1000) if args.max_timestep is None else args.max_timestep # Sample a random timestep for each image - timesteps = get_random_timesteps(args, min_timestep, max_timestep, batch_size, device) + timesteps = get_timesteps(min_timestep, max_timestep, batch_size, device) return timesteps -def get_noise_noisy_latents_and_timesteps(args, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor, Optional[float]]: - """ - Unified noise, noisy_latents, timesteps and huber loss convolution calculations - """ - batch_size = latents.shape[0] +def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + if args.noise_offset_random_strength: + noise_offset = torch.rand(1, device=latents.device) * args.noise_offset + else: + noise_offset = args.noise_offset + noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale) + if args.multires_noise_iterations: + noise = custom_train_functions.pyramid_noise_like( + noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount + ) + + # Sample a random timestep for each image + b_size = latents.shape[0] min_timestep = 0 if args.min_timestep is None else args.min_timestep - max_timestep = noise_scheduler.config.get("num_train_timesteps", 1000) if args.max_timestep is None else args.max_timestep + max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep - # A random timestep for each image in the batch - timesteps = get_random_timesteps(args, min_timestep, max_timestep, batch_size, latents.device) - huber_c = get_huber_c(args, noise_scheduler, timesteps) + timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device) - noise = make_noise(args, latents) - noisy_latents = get_noisy_latents(args, noise, noise_scheduler, latents, timesteps) + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + if args.ip_noise_gamma_random_strength: + strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma + else: + strength = args.ip_noise_gamma + noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps) + else: + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - return noise, noisy_latents, timesteps, huber_c + return noise, noisy_latents, timesteps def get_noisy_latents(args, noise: torch.FloatTensor, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor: @@ -6015,6 +6032,8 @@ def conditional_loss( elif loss_type == "l1": loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) elif loss_type == "huber": + if huber_c is None: + raise NotImplementedError("huber_c not implemented correctly") huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": @@ -6022,6 +6041,8 @@ def conditional_loss( elif reduction == "sum": loss = torch.sum(loss) elif loss_type == "smooth_l1": + if huber_c is None: + raise NotImplementedError("huber_c not implemented correctly") huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": diff --git a/train_network.py b/train_network.py index 7d064d21..f870734f 100644 --- a/train_network.py +++ b/train_network.py @@ -205,10 +205,10 @@ class NetworkTrainer: custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae: AutoencoderKL, images: torch.FloatTensor) -> torch.FloatTensor: return vae.encode(images).latent_dist.sample() - def shift_scale_latents(self, args, latents): + def shift_scale_latents(self, args, latents: torch.FloatTensor) -> torch.FloatTensor: return latents * self.vae_scale_factor def get_noise_pred_and_target( @@ -280,7 +280,7 @@ class NetworkTrainer: return noise_pred, target, timesteps, None - def post_process_loss(self, loss, args, timesteps, noise_scheduler): + def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor: if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: @@ -317,20 +317,21 @@ class NetworkTrainer: # endregion - def process_batch(self, batch, tokenizers, text_encoders, unet, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor: + def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor: with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents: torch.Tensor = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) + latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) else: # latentに変換 - latents: torch.Tensor = typing.cast(torch.FloatTensor, typing.cast(AutoencoderKLOutput, vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype))).latent_dist.sample()) + latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype)) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = typing.cast(torch.FloatTensor, torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)) - latents = typing.cast(torch.FloatTensor, latents * self.vae_scale_factor) + latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents)) + + latents = self.shift_scale_latents(args, latents) text_encoder_conds = [] @@ -384,22 +385,36 @@ class NetworkTrainer: total_loss = torch.zeros((batch_size, 1)).to(latents.device) # Use input timesteps_list or use described timesteps above - for fixed_timestep in chosen_timesteps_list: - fixed_timestep = typing.cast(torch.IntTensor, fixed_timestep) + for fixed_timesteps in chosen_timesteps_list: + fixed_timesteps = typing.cast(torch.IntTensor, fixed_timesteps) # Predict the noise residual # and add noise to the latents # with noise offset and/or multires noise if specified - noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timestep) + noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timesteps) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + for x in noisy_latents: + x.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents.requires_grad_(train_unet), fixed_timestep, text_encoder_conds, batch, weight_dtype + args, + accelerator, + unet, + noisy_latents.requires_grad_(train_unet), + fixed_timesteps, + text_encoder_conds, + batch, + weight_dtype, ) if args.v_parameterization: # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, fixed_timestep) + target = noise_scheduler.get_velocity(latents, noise, fixed_timesteps) else: target = noise @@ -418,7 +433,7 @@ class NetworkTrainer: accelerator, unet, noisy_latents, - timesteps, + fixed_timesteps, text_encoder_conds, batch, weight_dtype, @@ -427,7 +442,8 @@ class NetworkTrainer: network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + huber_c = train_util.get_huber_threshold_if_needed(args, fixed_timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) # 平均なのでbatch_sizeで割る必要なし if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): @@ -436,14 +452,7 @@ class NetworkTrainer: loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight loss = loss * loss_weights - if args.min_snr_gamma: - loss = apply_snr_weight(loss, fixed_timestep, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, fixed_timestep, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, fixed_timestep, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, fixed_timestep, noise_scheduler) + loss = self.post_process_loss(loss, args, fixed_timesteps, noise_scheduler) total_loss += loss @@ -526,8 +535,12 @@ class NetworkTrainer: collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) if args.debug_dataset: - train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly + train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) + + if val_dataset_group is not None: + val_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly + train_util.debug_dataset(val_dataset_group) return if len(train_dataset_group) == 0: logger.error( @@ -753,10 +766,6 @@ class NetworkTrainer: # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) - # Not for sure here. - # if val_dataset_group is not None: - # val_dataset_group.set_max_train_steps(args.max_train_steps) - # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -1304,7 +1313,7 @@ class NetworkTrainer: clean_memory_on_device(accelerator.device) for epoch in range(epoch_to_start, num_train_epochs): - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 metadata["ss_epoch"] = str(epoch + 1) @@ -1324,7 +1333,7 @@ class NetworkTrainer: continue with accelerator.accumulate(training_model): - loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet) accelerator.backward(loss) if accelerator.sync_gradients: self.all_reduce_network(accelerator, network) # sync DDP grad manually @@ -1384,7 +1393,8 @@ class NetworkTrainer: logs = self.generate_step_logs( args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm ) - accelerator.log(logs, step=global_step) + # accelerator.log(logs, step=global_step) + accelerator.log(logs) # VALIDATION PER STEP should_validate = (args.validation_every_n_step is not None @@ -1401,7 +1411,7 @@ class NetworkTrainer: if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item()) val_progress_bar.update(1) @@ -1409,10 +1419,12 @@ class NetworkTrainer: if is_tracking: logs = {"loss/current_val_loss": loss.detach().item()} - accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + accelerator.log(logs) logs = {"loss/average_val_loss": val_loss_recorder.moving_average} - accelerator.log(logs, step=global_step) + # accelerator.log(logs, step=global_step) + accelerator.log(logs) if global_step >= args.max_train_steps: break @@ -1427,7 +1439,7 @@ class NetworkTrainer: ) for val_step, batch in enumerate(val_dataloader): - loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) @@ -1437,22 +1449,26 @@ class NetworkTrainer: if is_tracking: avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation_current": current_loss} - accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + accelerator.log(logs) if is_tracking: avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation_average": avr_loss} - accelerator.log(logs, step=epoch + 1) + # accelerator.log(logs, step=epoch + 1) + accelerator.log(logs) # END OF EPOCH if is_tracking: logs = {"loss/epoch_average": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) + # accelerator.log(logs, step=epoch + 1) + accelerator.log(logs) if len(val_dataloader) > 0 and is_tracking: avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation_epoch_average": avr_loss} - accelerator.log(logs, step=epoch + 1) + # accelerator.log(logs, step=epoch + 1) + accelerator.log(logs) accelerator.wait_for_everyone()