From 5b19bda85c2ce01e4a1c7f324b7ef14bffed3315 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 12:35:46 -0500 Subject: [PATCH 01/87] Add validation loss --- library/train_util.py | 4 ++ train_network.py | 117 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 120 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index cc9ac455..e26f3979 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4736,6 +4736,10 @@ class collator_class: else: dataset = self.dataset + # If we split a dataset we will get a Subset + if type(dataset) is torch.utils.data.Subset: + dataset = dataset.dataset + # set epoch and step dataset.set_current_epoch(self.current_epoch.value) dataset.set_current_step(self.current_step.value) diff --git a/train_network.py b/train_network.py index d50916b7..58767b6f 100644 --- a/train_network.py +++ b/train_network.py @@ -345,8 +345,21 @@ class NetworkTrainer: # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + if args.validation_ratio > 0.0: + train_ratio = 1 - args.validation_ratio + validation_ratio = args.validation_ratio + train, val = torch.utils.data.random_split( + train_dataset_group, + [train_ratio, validation_ratio] + ) + print(f"split dataset by ratio: train {train_ratio}, validation {validation_ratio}") + print(f"train images: {len(train)}, validation images: {len(val)}") + else: + train = train_dataset_group + val = [] + train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, + train, batch_size=1, shuffle=True, collate_fn=collator, @@ -354,6 +367,15 @@ class NetworkTrainer: persistent_workers=args.persistent_data_loader_workers, ) + val_dataloader = torch.utils.data.DataLoader( + val, + shuffle=False, + batch_size=1, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * math.ceil( @@ -711,6 +733,8 @@ class NetworkTrainer: ) loss_recorder = train_util.LossRecorder() + val_loss_recorder = train_util.LossRecorder() + del train_dataset_group # callback for step start @@ -752,6 +776,8 @@ class NetworkTrainer: network.on_epoch_start(text_encoder, unet) + # TRAINING + for step, batch in enumerate(train_dataloader): current_step.value = global_step with accelerator.accumulate(network): @@ -877,6 +903,87 @@ class NetworkTrainer: if global_step >= args.max_train_steps: break + # VALIDATION + + if len(val_dataloader) > 0: + print("Validating バリデーション処理...") + + with torch.no_grad(): + for val_step, batch in enumerate(val_dataloader): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(device=accelerator.device, dtype=vae_dtype)).latent_dist.sample() + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = latents * self.vae_scale_factor + b_size = latents.shape[0] + + # Get the text embedding for conditioning + if args.weighted_captions: + text_encoder_conds = get_weighted_text_embeddings( + tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + 
text_encoder_conds = self.get_text_cond( + args, accelerator, batch, tokenizers, text_encoders, weight_dtype + ) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = self.call_unet( + args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight + + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + current_loss = loss.detach().item() + + val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) + + if len(val_dataloader) > 0: + avr_loss: float = val_loss_recorder.moving_average + + if args.logging_dir is not None: + logs = {"loss/validation": avr_loss} + accelerator.log(logs, step=epoch + 1) + + if args.logging_dir is not None: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) @@ -999,6 +1106,14 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + + parser.add_argument( + "--validation_ratio", + type=float, + default=0.0, + help="Ratio for validation images out of the training dataset" + ) + return parser From 33c311ed19821c9be7094ba89371777d7478b028 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 12:37:37 -0500 Subject: [PATCH 02/87] new ratio code --- train_network.py | 48 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/train_network.py b/train_network.py index 58767b6f..967c95fb 100644 --- a/train_network.py +++ b/train_network.py @@ -345,10 +345,48 @@ class NetworkTrainer: # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + def get_indices_without_reg(dataset: torch.utils.data.Dataset): + return [id for id, (key, item) in enumerate(dataset.image_data.items()) if item.is_reg is False] + + from typing import Sequence, Union + from torch._utils import _accumulate + import warnings + from torch.utils.data.dataset import Subset + + def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int, float]]): + indices = get_indices_without_reg(dataset) + random.shuffle(indices) + + subset_lengths = [] + + for i, frac in enumerate(lengths): + if frac < 0 or frac > 1: + raise ValueError(f"Fraction at index {i} is not between 0 and 1") + n_items_in_split = int(math.floor(len(indices) * frac)) + subset_lengths.append(n_items_in_split) + + remainder = len(indices) - 
sum(subset_lengths) + + for i in range(remainder): + idx_to_add_at = i % len(subset_lengths) + subset_lengths[idx_to_add_at] += 1 + + lengths = subset_lengths + for i, length in enumerate(lengths): + if length == 0: + warnings.warn(f"Length of split at index {i} is 0. " + f"This might result in an empty dataset.") + + if sum(lengths) != len(indices): + raise ValueError("Sum of input lengths does not equal the length of the input dataset!") + + return [Subset(dataset, indices[offset - length: offset]) for offset, length in zip(_accumulate(lengths), lengths)] + + if args.validation_ratio > 0.0: train_ratio = 1 - args.validation_ratio validation_ratio = args.validation_ratio - train, val = torch.utils.data.random_split( + train, val = random_split( train_dataset_group, [train_ratio, validation_ratio] ) @@ -358,6 +396,8 @@ class NetworkTrainer: train = train_dataset_group val = [] + + train_dataloader = torch.utils.data.DataLoader( train, batch_size=1, @@ -898,7 +938,7 @@ class NetworkTrainer: if args.logging_dir is not None: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) - accelerator.log(logs, step=global_step) + accelerator.log(logs) if global_step >= args.max_train_steps: break @@ -973,13 +1013,11 @@ class NetworkTrainer: loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし current_loss = loss.detach().item() - val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) if len(val_dataloader) > 0: - avr_loss: float = val_loss_recorder.moving_average - if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation": avr_loss} accelerator.log(logs, step=epoch + 1) From 3de9e6c443037abf99832d1be60f4fc9c0d67b8c Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 01:45:23 -0500 Subject: [PATCH 03/87] Add validation split of datasets --- library/config_util.py | 141 ++++++++++++++++++++++++++--------------- library/train_util.py | 26 ++++++++ train_network.py | 67 ++++---------------- 3 files changed, 126 insertions(+), 108 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index e8e0fda7..1bf7ed95 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -85,6 +85,8 @@ class BaseDatasetParams: max_token_length: int = None resolution: Optional[Tuple[int, int]] = None debug_dataset: bool = False + validation_seed: Optional[int] = None + validation_split: float = 0.0 @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -200,6 +202,8 @@ class ConfigSanitizer: "enable_bucket": bool, "max_bucket_reso": int, "min_bucket_reso": int, + "validation_seed": int, + "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), } @@ -427,64 +431,89 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params)) datasets.append(dataset) - # print info - info = "" - for i, dataset in enumerate(datasets): - is_dreambooth = isinstance(dataset, DreamBoothDataset) - is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent(f"""\ - [Dataset {i}] - batch_size: {dataset.batch_size} - resolution: {(dataset.width, dataset.height)} - enable_bucket: 
{dataset.enable_bucket} - """) - - if dataset.enable_bucket: - info += indent(dedent(f"""\ - min_bucket_reso: {dataset.min_bucket_reso} - max_bucket_reso: {dataset.max_bucket_reso} - bucket_reso_steps: {dataset.bucket_reso_steps} - bucket_no_upscale: {dataset.bucket_no_upscale} - \n"""), " ") + val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.params.validation_split <= 0.0: + continue + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset else: - info += "\n" + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) + val_datasets.append(dataset) - for j, subset in enumerate(dataset.subsets): - info += indent(dedent(f"""\ - [Subset {j} of Dataset {i}] - image_dir: "{subset.image_dir}" - image_count: {subset.img_count} - num_repeats: {subset.num_repeats} - shuffle_caption: {subset.shuffle_caption} - keep_tokens: {subset.keep_tokens} - caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} - caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} - caption_prefix: {subset.caption_prefix} - caption_suffix: {subset.caption_suffix} - color_aug: {subset.color_aug} - flip_aug: {subset.flip_aug} - face_crop_aug_range: {subset.face_crop_aug_range} - random_crop: {subset.random_crop} - token_warmup_min: {subset.token_warmup_min}, - token_warmup_step: {subset.token_warmup_step}, - """), " ") - - if is_dreambooth: + # print info + def print_info(_datasets): + info = "" + for i, dataset in enumerate(_datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent(f"""\ + [Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + """) + + if dataset.enable_bucket: info += indent(dedent(f"""\ - is_reg: {subset.is_reg} - class_tokens: {subset.class_tokens} - caption_extension: {subset.caption_extension} - \n"""), " ") - elif not is_controlnet: + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n"""), " ") + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): info += indent(dedent(f"""\ - metadata_file: {subset.metadata_file} - \n"""), " ") + [Subset {j} of Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + 
token_warmup_step: {subset.token_warmup_step}, + """), " ") + + if is_dreambooth: + info += indent(dedent(f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n"""), " ") + elif not is_controlnet: + info += indent(dedent(f"""\ + metadata_file: {subset.metadata_file} + \n"""), " ") - print(info) + print(info) + + print_info(datasets) + + if len(val_datasets) > 0: + print("Validation dataset") + print_info(val_datasets) # make buckets first because it determines the length of dataset # and set the same seed for all datasets @@ -494,7 +523,15 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset.make_buckets() dataset.set_seed(seed) - return DatasetGroup(datasets) + for i, dataset in enumerate(val_datasets): + print(f"[Validation Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) + + return ( + DatasetGroup(datasets), + DatasetGroup(val_datasets) if val_datasets else None + ) def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): diff --git a/library/train_util.py b/library/train_util.py index e26f3979..ba37ec13 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -123,6 +123,22 @@ IMAGE_TRANSFORMS = transforms.Compose( TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" +def split_train_val(paths, is_train, validation_split, validation_seed): + if validation_seed is not None: + print(f"Using validation seed: {validation_seed}") + prevstate = random.getstate() + random.seed(validation_seed) + random.shuffle(paths) + random.setstate(prevstate) + else: + random.shuffle(paths) + + if is_train: + return paths[0:math.ceil(len(paths) * (1 - validation_split))] + else: + return paths[len(paths) - round(len(paths) * validation_split):] + + class ImageInfo: def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: self.image_key: str = image_key @@ -1314,6 +1330,7 @@ class DreamBoothDataset(BaseDataset): def __init__( self, subsets: Sequence[DreamBoothSubset], + is_train: bool, batch_size: int, tokenizer, max_token_length, @@ -1324,12 +1341,18 @@ class DreamBoothDataset(BaseDataset): bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, + validation_split: float, + validation_seed: Optional[int], debug_dataset, ) -> None: super().__init__(tokenizer, max_token_length, resolution, debug_dataset) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" + self.is_train = is_train + self.validation_split = validation_split + self.validation_seed = validation_seed + self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight @@ -1382,6 +1405,9 @@ class DreamBoothDataset(BaseDataset): return [], [] img_paths = glob_images(subset.image_dir, "*") + + if self.validation_split > 0.0: + img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う diff --git a/train_network.py b/train_network.py index 967c95fb..97ecfe7b 100644 --- a/train_network.py +++ b/train_network.py @@ -189,10 +189,11 @@ class NetworkTrainer: } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + 
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + val_dataset_group = None # placeholder until validation dataset supported for arbitrary current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -212,6 +213,10 @@ class NetworkTrainer: assert ( train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + if val_dataset_group is not None: + assert ( + val_dataset_group.is_latent_cacheable() + ), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" self.assert_extra_args(args, train_dataset_group) @@ -264,6 +269,9 @@ class NetworkTrainer: vae.eval() with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + if val_dataset_group is not None: + print("Cache validation latents...") + val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -345,61 +353,8 @@ class NetworkTrainer: # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - def get_indices_without_reg(dataset: torch.utils.data.Dataset): - return [id for id, (key, item) in enumerate(dataset.image_data.items()) if item.is_reg is False] - - from typing import Sequence, Union - from torch._utils import _accumulate - import warnings - from torch.utils.data.dataset import Subset - - def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int, float]]): - indices = get_indices_without_reg(dataset) - random.shuffle(indices) - - subset_lengths = [] - - for i, frac in enumerate(lengths): - if frac < 0 or frac > 1: - raise ValueError(f"Fraction at index {i} is not between 0 and 1") - n_items_in_split = int(math.floor(len(indices) * frac)) - subset_lengths.append(n_items_in_split) - - remainder = len(indices) - sum(subset_lengths) - - for i in range(remainder): - idx_to_add_at = i % len(subset_lengths) - subset_lengths[idx_to_add_at] += 1 - - lengths = subset_lengths - for i, length in enumerate(lengths): - if length == 0: - warnings.warn(f"Length of split at index {i} is 0. 
" - f"This might result in an empty dataset.") - - if sum(lengths) != len(indices): - raise ValueError("Sum of input lengths does not equal the length of the input dataset!") - - return [Subset(dataset, indices[offset - length: offset]) for offset, length in zip(_accumulate(lengths), lengths)] - - - if args.validation_ratio > 0.0: - train_ratio = 1 - args.validation_ratio - validation_ratio = args.validation_ratio - train, val = random_split( - train_dataset_group, - [train_ratio, validation_ratio] - ) - print(f"split dataset by ratio: train {train_ratio}, validation {validation_ratio}") - print(f"train images: {len(train)}, validation images: {len(val)}") - else: - train = train_dataset_group - val = [] - - - train_dataloader = torch.utils.data.DataLoader( - train, + train_dataset_group, batch_size=1, shuffle=True, collate_fn=collator, @@ -408,7 +363,7 @@ class NetworkTrainer: ) val_dataloader = torch.utils.data.DataLoader( - val, + val_dataset_group if val_dataset_group is not None else [], shuffle=False, batch_size=1, collate_fn=collator, From a93c524b3a0e5c80a58c1317211dec93b6c137a7 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 02:07:39 -0500 Subject: [PATCH 04/87] Update args to validation_seed and validation_split --- train_network.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 97ecfe7b..f9e5debd 100644 --- a/train_network.py +++ b/train_network.py @@ -1099,12 +1099,17 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) - parser.add_argument( - "--validation_ratio", + "--validation_seed", + type=int, + default=None, + help="Validation seed" + ) + parser.add_argument( + "--validation_split", type=float, default=0.0, - help="Ratio for validation images out of the training dataset" + help="Split for validation images out of the training dataset" ) return parser From c89252101e8e8bd74cb3ab09ae33b548fd828e15 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 16:27:36 -0500 Subject: [PATCH 05/87] Add process_batch for train_network --- train_network.py | 211 ++++++++++++++++++----------------------------- 1 file changed, 82 insertions(+), 129 deletions(-) diff --git a/train_network.py b/train_network.py index f9e5debd..387b94b1 100644 --- a/train_network.py +++ b/train_network.py @@ -130,6 +130,75 @@ class NetworkTrainer: def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype)).latent_dist.sample() + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = latents * self.vae_scale_factor + b_size = latents.shape[0] + + with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): + # Get the text embedding for 
conditioning + if args.weighted_captions: + text_encoder_conds = get_weighted_text_embeddings( + tokenizers[0], + text_encoders[0], + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + text_encoder_conds = self.get_text_cond( + args, accelerator, batch, tokenizers, text_encoders, weight_dtype + ) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + + # Predict the noise residual + with torch.set_grad_enabled(is_train), accelerator.autocast(): + noise_pred = self.call_unet( + args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + return loss + + def train(self, args): session_id = random.randint(0, 2**32) training_started_at = time.time() @@ -777,71 +846,8 @@ class NetworkTrainer: current_step.value = global_step with accelerator.accumulate(network): on_step_start(text_encoder, unet) - - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() - - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) - latents = latents * self.vae_scale_factor - b_size = latents.shape[0] - - with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): - # Get the text embedding for conditioning - if args.weighted_captions: - text_encoder_conds = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - text_encoder_conds = self.get_text_cond( - args, accelerator, batch, tokenizers, text_encoders, weight_dtype - ) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype - ) - - if 
args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + is_train = True + loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=train_text_encoder) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: @@ -893,7 +899,7 @@ class NetworkTrainer: if args.logging_dir is not None: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) - accelerator.log(logs) + accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break @@ -905,80 +911,27 @@ class NetworkTrainer: with torch.no_grad(): for val_step, batch in enumerate(val_dataloader): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(device=accelerator.device, dtype=vae_dtype)).latent_dist.sample() - - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) - latents = latents * self.vae_scale_factor - b_size = latents.shape[0] - - # Get the text embedding for conditioning - if args.weighted_captions: - text_encoder_conds = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - text_encoder_conds = self.get_text_cond( - args, accelerator, batch, tokenizers, text_encoders, weight_dtype - ) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype - ) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight - - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = 
scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + is_train = False + loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/validation_current": current_loss} + accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + if len(val_dataloader) > 0: if args.logging_dir is not None: avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation": avr_loss} + logs = {"loss/validation_average": avr_loss} accelerator.log(logs, step=epoch + 1) if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} + # logs = {"loss/epoch": loss_recorder.moving_average} + logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone() From e545fdfd9affabff83f8bd2e7680369bb34dd301 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 16:56:36 -0500 Subject: [PATCH 06/87] Removed/cleanup a line --- train_network.py | 1 - 1 file changed, 1 deletion(-) diff --git a/train_network.py b/train_network.py index 387b94b1..a4125e9f 100644 --- a/train_network.py +++ b/train_network.py @@ -930,7 +930,6 @@ class NetworkTrainer: if args.logging_dir is not None: - # logs = {"loss/epoch": loss_recorder.moving_average} logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) From 9c591bdb12ce663b3fe9e91c0963d2cf71461bad Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 16:58:20 -0500 Subject: [PATCH 07/87] Remove unnecessary subset line from collate --- library/train_util.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index ba37ec13..1979207b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4762,10 +4762,6 @@ class collator_class: else: dataset = self.dataset - # If we split a dataset we will get a Subset - if type(dataset) is torch.utils.data.Subset: - dataset = dataset.dataset - # set epoch and step dataset.set_current_epoch(self.current_epoch.value) dataset.set_current_step(self.current_step.value) From 569ca72fc4cda2f4ce30e43b1c62989e79e3c3b3 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 7 Nov 2023 11:59:30 -0500 Subject: [PATCH 08/87] Set grad enabled if is_train and train_text_encoder We only want to be enabling grad if we are training. 
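As a minimal sketch (illustrative only, not part of this patch), the gating pattern the change relies on looks roughly like the following; the names encode_captions, is_train and train_text_encoder are assumed here to mirror the flags used in train_network.py:

    import torch

    def encode_captions(text_encoder, input_ids, is_train: bool, train_text_encoder: bool):
        # Track gradients through the text encoder only while it is being trained;
        # for validation batches (is_train=False) this behaves like torch.no_grad().
        with torch.set_grad_enabled(is_train and train_text_encoder):
            return text_encoder(input_ids)[0]

The same boolean is threaded through process_batch, so the validation path never builds a graph for either the text encoder or the U-Net.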
--- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index a4125e9f..edd3ff94 100644 --- a/train_network.py +++ b/train_network.py @@ -145,7 +145,7 @@ class NetworkTrainer: latents = latents * self.vae_scale_factor b_size = latents.shape[0] - with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: text_encoder_conds = get_weighted_text_embeddings( From b558a5b73d07a7e15ad90d9d15c2b55c5d2b3d61 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 04:37:16 +0800 Subject: [PATCH 09/87] val --- library/config_util.py | 170 ++++++++++++++++++++++------------------- library/train_util.py | 22 ++++++ train_network.py | 133 ++++++++++++++++++++++++++++++-- 3 files changed, 237 insertions(+), 88 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index fc4b3617..17fc1781 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -98,7 +98,8 @@ class BaseDatasetParams: resolution: Optional[Tuple[int, int]] = None network_multiplier: float = 1.0 debug_dataset: bool = False - + validation_seed: Optional[int] = None + validation_split: float = 0.0 @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -109,8 +110,7 @@ class DreamBoothDatasetParams(BaseDatasetParams): bucket_reso_steps: int = 64 bucket_no_upscale: bool = False prior_loss_weight: float = 1.0 - - + @dataclass class FineTuningDatasetParams(BaseDatasetParams): batch_size: int = 1 @@ -222,8 +222,11 @@ class ConfigSanitizer: "enable_bucket": bool, "max_bucket_reso": int, "min_bucket_reso": int, + "validation_seed": int, + "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, + } # options handled by argparse but not handled by user config @@ -460,100 +463,107 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params)) datasets.append(dataset) - # print info - info = "" - for i, dataset in enumerate(datasets): - is_dreambooth = isinstance(dataset, DreamBoothDataset) - is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent( - f"""\ - [Dataset {i}] - batch_size: {dataset.batch_size} - resolution: {(dataset.width, dataset.height)} - enable_bucket: {dataset.enable_bucket} - network_multiplier: {dataset.network_multiplier} - """ - ) + val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.params.validation_split <= 0.0: + continue + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = 
dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) + val_datasets.append(dataset) + + def print_info(_datasets): + info = "" + for i, dataset in enumerate(_datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent(f"""\ + [Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + """) if dataset.enable_bucket: - info += indent( - dedent( - f"""\ - min_bucket_reso: {dataset.min_bucket_reso} - max_bucket_reso: {dataset.max_bucket_reso} - bucket_reso_steps: {dataset.bucket_reso_steps} - bucket_no_upscale: {dataset.bucket_no_upscale} - \n""" - ), - " ", - ) + info += indent(dedent(f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n"""), " ") else: info += "\n" - for j, subset in enumerate(dataset.subsets): - info += indent( - dedent( - f"""\ - [Subset {j} of Dataset {i}] - image_dir: "{subset.image_dir}" - image_count: {subset.img_count} - num_repeats: {subset.num_repeats} - shuffle_caption: {subset.shuffle_caption} - keep_tokens: {subset.keep_tokens} - keep_tokens_separator: {subset.keep_tokens_separator} - caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} - caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} - caption_prefix: {subset.caption_prefix} - caption_suffix: {subset.caption_suffix} - color_aug: {subset.color_aug} - flip_aug: {subset.flip_aug} - face_crop_aug_range: {subset.face_crop_aug_range} - random_crop: {subset.random_crop} - token_warmup_min: {subset.token_warmup_min}, - token_warmup_step: {subset.token_warmup_step}, - """ - ), - " ", - ) + info += indent(dedent(f"""\ + [Subset {j} of Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, + """), " ") - if is_dreambooth: - info += indent( - dedent( - f"""\ - is_reg: {subset.is_reg} - class_tokens: {subset.class_tokens} - caption_extension: {subset.caption_extension} - \n""" - ), - " ", - ) - elif not is_controlnet: - info += indent( - dedent( - f"""\ - metadata_file: {subset.metadata_file} - \n""" - ), - " ", - ) + if is_dreambooth: + info += indent(dedent(f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n"""), " ") + elif not is_controlnet: + info += indent(dedent(f"""\ + metadata_file: {subset.metadata_file} + \n"""), " ") - logger.info(f'{info}') + print(info) + print_info(datasets) + + if len(val_datasets) > 0: + print("Validation dataset") + print_info(val_datasets) + # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = 
random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): - logger.info(f"[Dataset {i}]") + print(f"[Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) + + for i, dataset in enumerate(val_datasets): + print(f"[Validation Dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) - return DatasetGroup(datasets) - - + return ( + DatasetGroup(datasets), + DatasetGroup(val_datasets) if val_datasets else None + ) + def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): def extract_dreambooth_params(name: str) -> Tuple[int, str]: tokens = name.split("_") diff --git a/library/train_util.py b/library/train_util.py index d2b69edb..753539e0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -134,6 +134,20 @@ IMAGE_TRANSFORMS = transforms.Compose( TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" +def split_train_val(paths, is_train, validation_split, validation_seed): + if validation_seed is not None: + print(f"Using validation seed: {validation_seed}") + prevstate = random.getstate() + random.seed(validation_seed) + random.shuffle(paths) + random.setstate(prevstate) + else: + random.shuffle(paths) + + if is_train: + return paths[0:math.ceil(len(paths) * (1 - validation_split))] + else: + return paths[len(paths) - round(len(paths) * validation_split):] class ImageInfo: def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: @@ -1360,6 +1374,7 @@ class DreamBoothDataset(BaseDataset): def __init__( self, subsets: Sequence[DreamBoothSubset], + is_train: bool, batch_size: int, tokenizer, max_token_length, @@ -1371,12 +1386,17 @@ class DreamBoothDataset(BaseDataset): bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, + validation_split: float, + validation_seed: Optional[int], debug_dataset: bool, ) -> None: super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" + self.is_train = is_train + self.validation_split = validation_split + self.validation_seed = validation_seed self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight @@ -1429,6 +1449,8 @@ class DreamBoothDataset(BaseDataset): return [], [] img_paths = glob_images(subset.image_dir, "*") + if self.validation_split > 0.0: + img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う diff --git a/train_network.py b/train_network.py index e0fa6945..db7000e8 100644 --- a/train_network.py +++ b/train_network.py @@ -136,6 +136,67 @@ class NetworkTrainer: def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True): + + total_loss = 0.0 + timesteps_list = [10, 350, 500, 650, 990] + + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = 
vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype)).latent_dist.sample() + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = latents * self.vae_scale_factor + b_size = latents.shape[0] + + with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): + # Get the text embedding for conditioning + if args.weighted_captions: + text_encoder_conds = get_weighted_text_embeddings( + tokenizers[0], + text_encoders[0], + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + text_encoder_conds = self.get_text_cond( + args, accelerator, batch, tokenizers, text_encoders, weight_dtype + ) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, _ = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + for timesteps in timesteps_list: + # Predict the noise residual + with torch.set_grad_enabled(is_train), accelerator.autocast(): + noise_pred = self.call_unet( + args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + total_loss += loss + + average_loss = total_loss / len(timesteps_list) + return average_loss + def train(self, args): session_id = random.randint(0, 2**32) training_started_at = time.time() @@ -196,11 +257,12 @@ class NetworkTrainer: } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) - + val_dataset_group = None # placeholder until validation dataset supported for arbitrary + current_epoch = Value("i", 0) current_step = Value("i", 0) ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None @@ -219,7 +281,11 @@ class NetworkTrainer: assert ( train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - + if val_dataset_group is not None: + assert ( + val_dataset_group.is_latent_cacheable() + ), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + self.assert_extra_args(args, train_dataset_group) # acceleratorを準備する @@ -271,6 +337,9 @@ class NetworkTrainer: vae.eval() with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + if val_dataset_group is not None: + print("Cache validation latents...") + val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") 
clean_memory_on_device(accelerator.device) @@ -360,6 +429,15 @@ class NetworkTrainer: num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) + + val_dataloader = torch.utils.data.DataLoader( + val_dataset_group if val_dataset_group is not None else [], + shuffle=False, + batch_size=1, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -707,6 +785,8 @@ class NetworkTrainer: ) loss_recorder = train_util.LossRecorder() + val_loss_recorder = train_util.LossRecorder() + del train_dataset_group # callback for step start @@ -755,7 +835,8 @@ class NetworkTrainer: current_step.value = global_step with accelerator.accumulate(network): on_step_start(text_encoder, unet) - + + is_train = True with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) @@ -780,7 +861,7 @@ class NetworkTrainer: # print(f"set multiplier: {multipliers}") accelerator.unwrap_model(network).set_multiplier(multipliers) - with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: text_encoder_conds = get_weighted_text_embeddings( @@ -810,7 +891,7 @@ class NetworkTrainer: t.requires_grad_(True) # Predict the noise residual - with accelerator.autocast(): + with torch.set_grad_enabled(is_train), accelerator.autocast(): noise_pred = self.call_unet( args, accelerator, @@ -844,7 +925,7 @@ class NetworkTrainer: loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - + accelerator.backward(loss) if accelerator.sync_gradients: self.all_reduce_network(accelerator, network) # sync DDP grad manually @@ -898,14 +979,38 @@ class NetworkTrainer: if args.logging_dir is not None: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) + + if global_step % 25 == 0: + if len(val_dataloader) > 0: + print("Validating バリデーション処理...") + with torch.no_grad(): + val_dataloader_iter = iter(val_dataloader) + batch = next(val_dataloader_iter) + is_train = False + loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + + current_loss = loss.detach().item() + val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/validation_current": current_loss} + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} + logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) + if len(val_dataloader) > 0: + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/validation_epoch_average": avr_loss} + accelerator.log(logs, step=epoch + 1) + accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 @@ -1045,6 +1150,18 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + parser.add_argument( + 
"--validation_seed", + type=int, + default=None, + help="Validation seed" + ) + parser.add_argument( + "--validation_split", + type=float, + default=0.0, + help="Split for validation images out of the training dataset" + ) return parser From 78cfb01922ff97bbc62ff12a4d69eaaa2d89d7c1 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 18:55:48 +0800 Subject: [PATCH 10/87] improve --- library/config_util.py | 252 ++++++++++++++++++++++++++++++----------- train_network.py | 67 +++++++---- 2 files changed, 230 insertions(+), 89 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 17fc1781..d198cee3 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -41,12 +41,17 @@ from .train_util import ( DatasetGroup, ) from .utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + def add_config_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") + parser.add_argument( + "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル" + ) # TODO: inherit Params class in Subset, Dataset @@ -60,6 +65,8 @@ class BaseSubsetParams: caption_separator: str = (",",) keep_tokens: int = 0 keep_tokens_separator: str = (None,) + secondary_separator: Optional[str] = None + enable_wildcard: bool = False color_aug: bool = False flip_aug: bool = False face_crop_aug_range: Optional[Tuple[float, float]] = None @@ -181,6 +188,8 @@ class ConfigSanitizer: "shuffle_caption": bool, "keep_tokens": int, "keep_tokens_separator": str, + "secondary_separator": str, + "enable_wildcard": bool, "token_warmup_min": int, "token_warmup_step": Any(float, int), "caption_prefix": str, @@ -247,9 +256,10 @@ class ConfigSanitizer: } def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: - assert ( - support_dreambooth or support_finetuning or support_controlnet - ), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" + assert support_dreambooth or support_finetuning or support_controlnet, ( + "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more." + + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。" + ) self.db_subset_schema = self.__merge_dict( self.SUBSET_ASCENDABLE_SCHEMA, @@ -361,7 +371,9 @@ class ConfigSanitizer: return self.argparse_config_validator(argparse_namespace) except MultipleInvalid: # XXX: this should be a bug - logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") + logger.error( + "Invalid cmdline parsed arguments. This should be a bug. 
/ コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。" + ) raise # NOTE: value would be overwritten by latter dict if there is already the same key @@ -447,7 +459,6 @@ class BlueprintGenerator: return default_value - def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] @@ -467,7 +478,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu datasets.append(dataset) val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] - + for dataset_blueprint in dataset_group_blueprint.datasets: if dataset_blueprint.params.validation_split <= 0.0: continue @@ -485,75 +496,174 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) - def print_info(_datasets): - info = "" - for i, dataset in enumerate(_datasets): - is_dreambooth = isinstance(dataset, DreamBoothDataset) - is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent(f"""\ - [Dataset {i}] - batch_size: {dataset.batch_size} - resolution: {(dataset.width, dataset.height)} - enable_bucket: {dataset.enable_bucket} - """) + # print info + info = "" + for i, dataset in enumerate(datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent( + f"""\ + [Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + network_multiplier: {dataset.network_multiplier} + """ + ) if dataset.enable_bucket: - info += indent(dedent(f"""\ - min_bucket_reso: {dataset.min_bucket_reso} - max_bucket_reso: {dataset.max_bucket_reso} - bucket_reso_steps: {dataset.bucket_reso_steps} - bucket_no_upscale: {dataset.bucket_no_upscale} - \n"""), " ") + info += indent( + dedent( + f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n""" + ), + " ", + ) else: info += "\n" + for j, subset in enumerate(dataset.subsets): - info += indent(dedent(f"""\ - [Subset {j} of Dataset {i}] - image_dir: "{subset.image_dir}" - image_count: {subset.img_count} - num_repeats: {subset.num_repeats} - shuffle_caption: {subset.shuffle_caption} - keep_tokens: {subset.keep_tokens} - caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} - caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} - caption_prefix: {subset.caption_prefix} - caption_suffix: {subset.caption_suffix} - color_aug: {subset.color_aug} - flip_aug: {subset.flip_aug} - face_crop_aug_range: {subset.face_crop_aug_range} - random_crop: {subset.random_crop} - token_warmup_min: {subset.token_warmup_min}, - token_warmup_step: {subset.token_warmup_step}, - """), " ") + info += indent( + dedent( + f"""\ + [Subset {j} of Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + keep_tokens_separator: {subset.keep_tokens_separator} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: 
{subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, + """ + ), + " ", + ) - if is_dreambooth: - info += indent(dedent(f"""\ - is_reg: {subset.is_reg} - class_tokens: {subset.class_tokens} - caption_extension: {subset.caption_extension} - \n"""), " ") - elif not is_controlnet: - info += indent(dedent(f"""\ - metadata_file: {subset.metadata_file} - \n"""), " ") + if is_dreambooth: + info += indent( + dedent( + f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n""" + ), + " ", + ) + elif not is_controlnet: + info += indent( + dedent( + f"""\ + metadata_file: {subset.metadata_file} + \n""" + ), + " ", + ) - print(info) + logger.info(f'{info}') - print_info(datasets) + # print validation info + info = "" + for i, dataset in enumerate(val_datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent( + f"""\ + [Validation Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + network_multiplier: {dataset.network_multiplier} + """ + ) + + if dataset.enable_bucket: + info += indent( + dedent( + f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n""" + ), + " ", + ) + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): + info += indent( + dedent( + f"""\ + [Subset {j} of Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + keep_tokens_separator: {subset.keep_tokens_separator} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, + """ + ), + " ", + ) + + if is_dreambooth: + info += indent( + dedent( + f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n""" + ), + " ", + ) + elif not is_controlnet: + info += indent( + dedent( + f"""\ + metadata_file: {subset.metadata_file} + \n""" + ), + " ", + ) + + logger.info(f'{info}') - if len(val_datasets) > 0: - print("Validation dataset") - print_info(val_datasets) - # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): - print(f"[Dataset {i}]") + logger.info(f"[Dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) - + for i, dataset in enumerate(val_datasets): print(f"[Validation Dataset {i}]") dataset.make_buckets() @@ -562,8 
+672,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu return ( DatasetGroup(datasets), DatasetGroup(val_datasets) if val_datasets else None - ) - + ) + def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): def extract_dreambooth_params(name: str) -> Tuple[int, str]: tokens = name.split("_") @@ -642,13 +752,17 @@ def load_user_config(file: str) -> dict: with open(file, "r") as f: config = json.load(f) except Exception: - logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + logger.error( + f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) raise elif file.name.lower().endswith(".toml"): try: config = toml.load(file) except Exception: - logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + logger.error( + f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) raise else: raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") @@ -675,13 +789,13 @@ if __name__ == "__main__": train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) logger.info("[argparse_namespace]") - logger.info(f'{vars(argparse_namespace)}') + logger.info(f"{vars(argparse_namespace)}") user_config = load_user_config(config_args.dataset_config) logger.info("") logger.info("[user_config]") - logger.info(f'{user_config}') + logger.info(f"{user_config}") sanitizer = ConfigSanitizer( config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout @@ -690,10 +804,10 @@ if __name__ == "__main__": logger.info("") logger.info("[sanitized_user_config]") - logger.info(f'{sanitized_user_config}') + logger.info(f"{sanitized_user_config}") blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) logger.info("") logger.info("[blueprint]") - logger.info(f'{blueprint}') + logger.info(f"{blueprint}") diff --git a/train_network.py b/train_network.py index db7000e8..d3e34eb7 100644 --- a/train_network.py +++ b/train_network.py @@ -44,6 +44,7 @@ from library.utils import setup_logging, add_logging_arguments setup_logging() import logging +import itertools logger = logging.getLogger(__name__) @@ -438,6 +439,7 @@ class NetworkTrainer: num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) + cyclic_val_dataloader = itertools.cycle(val_dataloader) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -979,23 +981,24 @@ class NetworkTrainer: if args.logging_dir is not None: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - - if global_step % 25 == 0: - if len(val_dataloader) > 0: - print("Validating バリデーション処理...") - - with torch.no_grad(): - val_dataloader_iter = iter(val_dataloader) - batch = next(val_dataloader_iter) - is_train = False - loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - - current_loss = loss.detach().item() - val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + + if args.validation_every_n_step is not None: + if global_step % 
(args.validation_every_n_step) == 0: + if len(val_dataloader) > 0: + print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + for val_step in min(len(val_dataloader), args.validation_batches): + is_train = False + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / args.validation_batches + val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) if args.logging_dir is not None: avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation_current": current_loss} + logs = {"loss/avr_val_loss": avr_loss} accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: @@ -1005,12 +1008,24 @@ class NetworkTrainer: logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) - if len(val_dataloader) > 0: - if args.logging_dir is not None: - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation_epoch_average": avr_loss} - accelerator.log(logs, step=epoch + 1) - + if args.validation_every_n_step is None: + if len(val_dataloader) > 0: + print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + for val_step in min(len(val_dataloader), args.validation_batches): + is_train = False + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / args.validation_batches + val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/val_epoch_average": avr_loss} + accelerator.log(logs, step=epoch + 1) + accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 @@ -1162,6 +1177,18 @@ def setup_parser() -> argparse.ArgumentParser: default=0.0, help="Split for validation images out of the training dataset" ) + parser.add_argument( + "--validation_every_n_step", + type=int, + default=None, + help="Number of steps for counting validation loss. By default, validation per epoch is performed" + ) + parser.add_argument( + "--validation_batches", + type=int, + default=1, + help="Number of val steps for counting validation loss. 
By default, validation one batch is performed" + ) return parser From 923b761ce3622a3132bf0db7768e6b97df21c607 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 20:01:40 +0800 Subject: [PATCH 11/87] Update train_network.py --- train_network.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index d3e34eb7..82110066 100644 --- a/train_network.py +++ b/train_network.py @@ -988,6 +988,7 @@ class NetworkTrainer: print("Validating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): + validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) for val_step in min(len(val_dataloader), args.validation_batches): is_train = False batch = next(cyclic_val_dataloader) @@ -1013,6 +1014,7 @@ class NetworkTrainer: print("Validating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): + validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) for val_step in min(len(val_dataloader), args.validation_batches): is_train = False batch = next(cyclic_val_dataloader) @@ -1186,8 +1188,8 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--validation_batches", type=int, - default=1, - help="Number of val steps for counting validation loss. By default, validation one batch is performed" + default=None, + help="Number of val steps for counting validation loss. By default, validation for all val_dataset is performed" ) return parser From 47359b8fac9602415f56b1f7e3f25a00255a1d78 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 20:17:40 +0800 Subject: [PATCH 12/87] Update train_network.py --- train_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index 82110066..d549378c 100644 --- a/train_network.py +++ b/train_network.py @@ -989,7 +989,7 @@ class NetworkTrainer: total_loss = 0.0 with torch.no_grad(): validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) - for val_step in min(len(val_dataloader), args.validation_batches): + for val_step in range(validation_steps): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) @@ -1015,7 +1015,7 @@ class NetworkTrainer: total_loss = 0.0 with torch.no_grad(): validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) - for val_step in min(len(val_dataloader), args.validation_batches): + for val_step in range(validation_steps): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) From a51723cc2a3dd50b45e60945f97bc5adfe753d1f Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Mon, 11 Mar 2024 09:42:58 +0800 Subject: [PATCH 13/87] fix timesteps --- train_network.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/train_network.py b/train_network.py index d549378c..f0f27ea7 100644 --- a/train_network.py +++ b/train_network.py @@ -141,7 +141,6 @@ class NetworkTrainer: total_loss = 0.0 timesteps_list = [10, 350, 500, 650, 990] - with torch.no_grad(): 
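# A minimal sketch of the cyclic, capped validation loop assembled in the hunks above:
# itertools.cycle wraps the validation DataLoader and the number of batches per run is
# capped by the CLI flag (--validation_batches here, later renamed --max_validation_steps).
# run_validation and val_loss_fn are placeholder names, not identifiers from this repository.
import itertools

import torch

def run_validation(val_dataloader, val_loss_fn, max_validation_steps=None):
    cyclic_val_dataloader = itertools.cycle(val_dataloader)
    steps = len(val_dataloader) if max_validation_steps is None else min(max_validation_steps, len(val_dataloader))
    total = 0.0
    with torch.no_grad():
        for _ in range(steps):
            batch = next(cyclic_val_dataloader)
            total += float(val_loss_fn(batch))  # val_loss_fn plays the role of process_val_batch
    return total / max(steps, 1)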
if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) @@ -174,16 +173,17 @@ class NetworkTrainer: # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, _ = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) - for timesteps in timesteps_list: - # Predict the noise residual + + for fixed_timesteps in timesteps_list: with torch.set_grad_enabled(is_train), accelerator.autocast(): + noise = torch.randn_like(latents, device=latents.device) + b_size = latents.shape[0] + timesteps = torch.randint(fixed_timesteps, fixed_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noise_pred = self.call_unet( args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype ) - if args.v_parameterization: # v-parameterization training target = noise_scheduler.get_velocity(latents, noise, timesteps) @@ -988,7 +988,7 @@ class NetworkTrainer: print("Validating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): - validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) + validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) for val_step in range(validation_steps): is_train = False batch = next(cyclic_val_dataloader) @@ -999,7 +999,7 @@ class NetworkTrainer: if args.logging_dir is not None: avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/avr_val_loss": avr_loss} + logs = {"loss/average_val_loss": avr_loss} accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: @@ -1014,7 +1014,7 @@ class NetworkTrainer: print("Validating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): - validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) + validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) for val_step in range(validation_steps): is_train = False batch = next(cyclic_val_dataloader) From 7d84ac2177a603e9aa6834fd1c0ee19a463eb5a0 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Mon, 11 Mar 2024 14:41:51 +0800 Subject: [PATCH 14/87] only use train subset to val --- library/config_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/config_util.py b/library/config_util.py index d198cee3..1a6cef97 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -492,7 +492,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu subset_klass = FineTuningSubset dataset_klass = FineTuningDataset - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets if subset_blueprint.params.is_reg is False] dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) From befbec5335ed1f8018d22b65993b376571ea2989 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Mon, 11 Mar 2024 18:47:04 +0800 Subject: [PATCH 15/87] Update train_network.py --- 
train_network.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/train_network.py b/train_network.py index f0f27ea7..cbc107b6 100644 --- a/train_network.py +++ b/train_network.py @@ -174,7 +174,7 @@ class NetworkTrainer: # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - for fixed_timesteps in timesteps_list: + for fixed_timesteps in tqdm(timesteps_list, desc='Training Progress'): with torch.set_grad_enabled(is_train), accelerator.autocast(): noise = torch.randn_like(latents, device=latents.device) b_size = latents.shape[0] @@ -184,16 +184,16 @@ class NetworkTrainer: noise_pred = self.call_unet( args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype ) - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - total_loss += loss + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + total_loss += loss average_loss = total_loss / len(timesteps_list) return average_loss @@ -985,7 +985,7 @@ class NetworkTrainer: if args.validation_every_n_step is not None: if global_step % (args.validation_every_n_step) == 0: if len(val_dataloader) > 0: - print("Validating バリデーション処理...") + print(f"\nValidating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) @@ -994,10 +994,12 @@ class NetworkTrainer: batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) total_loss += loss.detach().item() - current_loss = total_loss / args.validation_batches - val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + current_loss = total_loss / args.validation_batches + val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) avr_loss: float = val_loss_recorder.moving_average logs = {"loss/average_val_loss": avr_loss} accelerator.log(logs, step=global_step) @@ -1011,7 +1013,7 @@ class NetworkTrainer: if args.validation_every_n_step is None: if len(val_dataloader) > 0: - print("Validating バリデーション処理...") + print(f"\nValidating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) @@ -1025,7 +1027,7 @@ class NetworkTrainer: if args.logging_dir is not None: avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/val_epoch_average": avr_loss} + logs = {"loss/epoch_val_average": avr_loss} accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone() From 63e58f78e3df7608045071cdc247bb26bd19a333 Mon Sep 17 00:00:00 2001 From: gesen2egee 
<79357052+gesen2egee@users.noreply.github.com> Date: Mon, 11 Mar 2024 19:15:55 +0800 Subject: [PATCH 16/87] Update train_network.py --- train_network.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index cbc107b6..82d72df2 100644 --- a/train_network.py +++ b/train_network.py @@ -178,8 +178,7 @@ class NetworkTrainer: with torch.set_grad_enabled(is_train), accelerator.autocast(): noise = torch.randn_like(latents, device=latents.device) b_size = latents.shape[0] - timesteps = torch.randint(fixed_timesteps, fixed_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() + timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noise_pred = self.call_unet( args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype From a6c41c6bea0465112c7bd472dff68b7e8ecea46e Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Mon, 11 Mar 2024 19:23:48 +0800 Subject: [PATCH 17/87] Update train_network.py --- train_network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 82d72df2..6eefdb2b 100644 --- a/train_network.py +++ b/train_network.py @@ -174,7 +174,7 @@ class NetworkTrainer: # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - for fixed_timesteps in tqdm(timesteps_list, desc='Training Progress'): + for fixed_timesteps in timesteps_list: with torch.set_grad_enabled(is_train), accelerator.autocast(): noise = torch.randn_like(latents, device=latents.device) b_size = latents.shape[0] @@ -988,7 +988,7 @@ class NetworkTrainer: total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) - for val_step in range(validation_steps): + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) @@ -1016,7 +1016,7 @@ class NetworkTrainer: total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) - for val_step in range(validation_steps): + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) From bd7e2295b7c4d1444a9e844309e1685cb29c6961 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Wed, 13 Mar 2024 17:54:21 +0800 Subject: [PATCH 18/87] fix --- train_network.py | 38 +++++++++----------------------------- 1 file changed, 9 insertions(+), 29 deletions(-) diff --git a/train_network.py b/train_network.py index 6eefdb2b..128690fb 100644 --- a/train_network.py +++ b/train_network.py @@ -981,20 +981,19 @@ class NetworkTrainer: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if args.validation_every_n_step is not None: - if global_step % (args.validation_every_n_step) 
== 0: - if len(val_dataloader) > 0: + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or step == len(train_dataloader) - 1 or global_step >= args.max_train_steps: print(f"\nValidating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): - validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) for val_step in tqdm(range(validation_steps), desc='Validation Steps'): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) total_loss += loss.detach().item() - current_loss = total_loss / args.validation_batches - val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=epoch, step=step, loss=current_loss) if args.logging_dir is not None: logs = {"loss/current_val_loss": current_loss} @@ -1009,25 +1008,6 @@ class NetworkTrainer: if args.logging_dir is not None: logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) - - if args.validation_every_n_step is None: - if len(val_dataloader) > 0: - print(f"\nValidating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - is_train = False - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / args.validation_batches - val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/epoch_val_average": avr_loss} - accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone() @@ -1184,14 +1164,14 @@ def setup_parser() -> argparse.ArgumentParser: "--validation_every_n_step", type=int, default=None, - help="Number of steps for counting validation loss. By default, validation per epoch is performed" + help="Number of train steps for counting validation loss. By default, validation per train epoch is performed" ) parser.add_argument( - "--validation_batches", + "--max_validation_steps", type=int, default=None, - help="Number of val steps for counting validation loss. By default, validation for all val_dataset is performed" - ) + help="Number of max validation steps for counting validation loss. 
By default, validation will run entire validation dataset" + ) return parser From d05965dbadf430dab6a05f171292f6d2077ec946 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Wed, 13 Mar 2024 18:33:51 +0800 Subject: [PATCH 19/87] Update train_network.py --- train_network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 864bfd70..cc9fcbbe 100644 --- a/train_network.py +++ b/train_network.py @@ -987,8 +987,8 @@ class NetworkTrainer: accelerator.log(logs, step=global_step) if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or step == len(train_dataloader) - 1 or global_step >= args.max_train_steps: - print(f"\nValidating バリデーション処理...") + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) @@ -998,7 +998,7 @@ class NetworkTrainer: loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) total_loss += loss.detach().item() current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) if args.logging_dir is not None: logs = {"loss/current_val_loss": current_loss} From b5e8045df40ed4a437492ed2b6ea6d5be7282080 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Sat, 16 Mar 2024 11:51:11 +0800 Subject: [PATCH 20/87] fix control net --- library/config_util.py | 6 ++++-- library/train_util.py | 15 ++++++++++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index ec6ef4b2..0da0b143 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -491,8 +491,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu else: subset_klass = FineTuningSubset dataset_klass = FineTuningDataset - - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets if subset_blueprint.params.is_reg is False] + if subset_klass == DreamBoothSubset: + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets if subset_blueprint.params.is_reg is False] + else: + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) diff --git a/library/train_util.py b/library/train_util.py index 89297962..ae7968d7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1816,6 +1816,7 @@ class ControlNetDataset(BaseDataset): def __init__( self, subsets: Sequence[ControlNetSubset], + is_train: bool, batch_size: int, tokenizer, max_token_length, @@ -1826,6 +1827,8 @@ class ControlNetDataset(BaseDataset): max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, + validation_split: float, + validation_seed: Optional[int], debug_dataset: float, ) -> None: super().__init__(tokenizer, max_token_length, resolution, network_multiplier, 
debug_dataset) @@ -1860,6 +1863,7 @@ class ControlNetDataset(BaseDataset): self.dreambooth_dataset_delegate = DreamBoothDataset( db_subsets, + is_train, batch_size, tokenizer, max_token_length, @@ -1871,6 +1875,8 @@ class ControlNetDataset(BaseDataset): bucket_reso_steps, bucket_no_upscale, 1.0, + validation_split, + validation_seed, debug_dataset, ) @@ -1878,7 +1884,10 @@ class ControlNetDataset(BaseDataset): self.image_data = self.dreambooth_dataset_delegate.image_data self.batch_size = batch_size self.num_train_images = self.dreambooth_dataset_delegate.num_train_images - self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images + self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images + self.is_train = is_train + self.validation_split = validation_split + self.validation_seed = validation_seed # assert all conditioning data exists missing_imgs = [] @@ -1911,8 +1920,8 @@ class ControlNetDataset(BaseDataset): [cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img] ) - assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" - assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" + #assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" + #assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" self.conditioning_image_transforms = IMAGE_TRANSFORMS From 36d4023431d10718b00673d5ba34f426690c62de Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Thu, 11 Apr 2024 01:39:17 +0800 Subject: [PATCH 21/87] Update config_util.py --- library/config_util.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index a7e0024e..c6667690 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -498,10 +498,21 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu else: subset_klass = FineTuningSubset dataset_klass = FineTuningDataset - if subset_klass == DreamBoothSubset: - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets if subset_blueprint.params.is_reg is False] - else: - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + + subsets = [] + for subset_blueprint in dataset_blueprint.subsets: + subset_blueprint.params.num_repeats = 1 + subset_blueprint.params.color_aug = False + subset_blueprint.params.flip_aug = False + subset_blueprint.params.random_crop = False + subset_blueprint.params.random_crop = None + subset_blueprint.params.caption_dropout_rate = 0.0 + subset_blueprint.params.caption_dropout_every_n_epochs = 0 + subset_blueprint.params.caption_tag_dropout_rate = 0.0 + subset_blueprint.params.token_warmup_step = 0 + if subset_klass != DreamBoothSubset or not subset_blueprint.params.is_reg: + subsets.append(subset_klass(**asdict(subset_blueprint.params))) + dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) From 229c5a38ef4e93e2023d748b4fa1588d490340ad Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Thu, 11 Apr 2024 01:45:49 +0800 Subject: [PATCH 22/87] Update train_util.py --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 832be75d..b143e85a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3123,7 +3123,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: ) parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument( - "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする" + "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする" ) parser.add_argument( "--gradient_accumulation_steps", From 3b251b758dae6e4f11e0bbc7e544dc9542c836ff Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Thu, 11 Apr 2024 01:50:32 +0800 Subject: [PATCH 23/87] Update config_util.py --- library/config_util.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index c6667690..8f01e1f6 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -510,8 +510,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu subset_blueprint.params.caption_dropout_every_n_epochs = 0 subset_blueprint.params.caption_tag_dropout_rate = 0.0 subset_blueprint.params.token_warmup_step = 0 - if subset_klass != DreamBoothSubset or not subset_blueprint.params.is_reg: - subsets.append(subset_klass(**asdict(subset_blueprint.params))) + + if subset_klass != DreamBoothSubset or (subset_klass == DreamBoothSubset and not subset_blueprint.params.is_reg): + subset = subset_klass(**asdict(subset_blueprint.params)) + subsets.append(subset) dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) From 459b12539b0ae1a92da98e38568ea0a61db1e89f Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Thu, 11 Apr 2024 01:52:14 +0800 Subject: [PATCH 24/87] Update config_util.py --- library/config_util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 8f01e1f6..6f243aac 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -512,8 +512,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu subset_blueprint.params.token_warmup_step = 0 if subset_klass != DreamBoothSubset or (subset_klass == DreamBoothSubset and not subset_blueprint.params.is_reg): - subset = subset_klass(**asdict(subset_blueprint.params)) - subsets.append(subset) + subsets.append(subset_klass(**asdict(subset_blueprint.params))) dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) From 89ad69b6a0d35791627cb58630a711befc6bb3b5 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Thu, 11 Apr 2024 08:42:31 +0800 Subject: [PATCH 25/87] Update train_util.py --- library/train_util.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index b143e85a..8bf6823b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1511,17 +1511,6 @@ class DreamBoothDataset(BaseDataset): logger.warning(f"not directory: {subset.image_dir}") return [], [] - img_paths = glob_images(subset.image_dir, "*") - if self.validation_split > 0.0: - img_paths = 
split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) - logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") - - # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う - captions = [] - missing_captions = [] - for img_path in img_paths: - cap_for_img = read_caption(img_path, subset.caption_extension) - if cap_for_img is None and subset.class_tokens is None: info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE) use_cached_info_for_subset = subset.cache_info if use_cached_info_for_subset: @@ -1545,6 +1534,8 @@ class DreamBoothDataset(BaseDataset): # we may need to check image size and existence of image files, but it takes time, so user should check it before training else: img_paths = glob_images(subset.image_dir, "*") + if self.validation_split > 0.0: + img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) sizes = [None] * len(img_paths) logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") From fde8026c2d92fe4991927eed6fa1ff373e8d38d2 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Thu, 11 Apr 2024 11:29:26 +0800 Subject: [PATCH 26/87] Update config_util.py --- library/config_util.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 6f243aac..a1b02bd1 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -636,19 +636,11 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu [Subset {j} of Dataset {i}] image_dir: "{subset.image_dir}" image_count: {subset.img_count} - num_repeats: {subset.num_repeats} shuffle_caption: {subset.shuffle_caption} keep_tokens: {subset.keep_tokens} keep_tokens_separator: {subset.keep_tokens_separator} - caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} - caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} caption_prefix: {subset.caption_prefix} caption_suffix: {subset.caption_suffix} - color_aug: {subset.color_aug} - flip_aug: {subset.flip_aug} - face_crop_aug_range: {subset.face_crop_aug_range} - random_crop: {subset.random_crop} token_warmup_min: {subset.token_warmup_min}, token_warmup_step: {subset.token_warmup_step}, """ @@ -688,7 +680,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset.set_seed(seed) for i, dataset in enumerate(val_datasets): - print(f"[Validation Dataset {i}]") + logger.info(f"[Validation Dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) From 31507b9901d1d9ab65ba79ebd747b7f35c7e0fc1 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Fri, 2 Aug 2024 13:15:21 +0800 Subject: [PATCH 27/87] Remove unnecessary is_train changes and use apply_debiased_estimation to calculate validation loss. 
Balances the influence of different time steps on training performance (without affecting actual training results) --- train_network.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/train_network.py b/train_network.py index 2a3a4482..4a5940cd 100644 --- a/train_network.py +++ b/train_network.py @@ -135,7 +135,7 @@ class NetworkTrainer: def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) - def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True): + def process_val_batch(self, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True): total_loss = 0.0 timesteps_list = [10, 350, 500, 650, 990] @@ -153,7 +153,7 @@ class NetworkTrainer: latents = latents * self.vae_scale_factor b_size = latents.shape[0] - with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): + with torch.set_grad_enabled(False), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: text_encoder_conds = get_weighted_text_embeddings( @@ -173,7 +173,7 @@ class NetworkTrainer: # with noise offset and/or multires noise if specified for fixed_timesteps in timesteps_list: - with torch.set_grad_enabled(is_train), accelerator.autocast(): + with torch.set_grad_enabled(False), accelerator.autocast(): noise = torch.randn_like(latents, device=latents.device) b_size = latents.shape[0] timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) @@ -189,6 +189,7 @@ class NetworkTrainer: loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし total_loss += loss @@ -885,8 +886,7 @@ class NetworkTrainer: for step, batch in enumerate(train_dataloader): current_step.value = global_step with accelerator.accumulate(training_model): - on_step_start(text_encoder, unet) - is_train = True + on_step_start(text_encoder, unet) if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: @@ -911,7 +911,7 @@ class NetworkTrainer: # print(f"set multiplier: {multipliers}") accelerator.unwrap_model(network).set_multiplier(multipliers) - with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): + with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: text_encoder_conds = get_weighted_text_embeddings( @@ -941,7 +941,7 @@ class NetworkTrainer: t.requires_grad_(True) # Predict the noise residual - with torch.set_grad_enabled(is_train), accelerator.autocast(): + with accelerator.autocast(): noise_pred = self.call_unet( args, accelerator, @@ -1040,10 +1040,9 @@ class NetworkTrainer: total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - is_train = False + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): batch = 
next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) total_loss += loss.detach().item() current_loss = total_loss / validation_steps val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) From 1db495127f25c1b17694780f635a4760b4e345d0 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 14:53:46 +0800 Subject: [PATCH 28/87] Update train_db.py --- train_db.py | 128 ++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 124 insertions(+), 4 deletions(-) diff --git a/train_db.py b/train_db.py index 1de504ed..9f8ec777 100644 --- a/train_db.py +++ b/train_db.py @@ -2,7 +2,6 @@ # XXX dropped option: fine_tune import argparse -import itertools import math import os from multiprocessing import Value @@ -41,11 +40,73 @@ from library.utils import setup_logging, add_logging_arguments setup_logging() import logging +import itertools logger = logging.getLogger(__name__) # perlin_noise, +def process_val_batch(*training_models, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args): + total_loss = 0.0 + timesteps_list = [10, 350, 500, 650, 990] + + with accelerator.accumulate(*training_models): + with torch.no_grad(): + # latentに変換 + if cache_latents: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + with torch.set_grad_enabled(False), accelerator.autocast(): + if args.weighted_captions: + encoder_hidden_states = get_weighted_text_embeddings( + tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype + ) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + + for fixed_timesteps in timesteps_list: + with torch.set_grad_enabled(False), accelerator.autocast(): + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise = torch.randn_like(latents, device=latents.device) + b_size = latents.shape[0] + timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + if args.masked_loss: + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + 
total_loss += loss + + average_loss = total_loss / len(timesteps_list) + return average_loss def train(args): train_util.verify_training_args(args) @@ -81,9 +142,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -148,6 +210,9 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") + if val_dataset_group is not None: + print("Cache validation latents...") + val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -195,6 +260,15 @@ def train(args): num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) + val_dataloader = torch.utils.data.DataLoader( + val_dataset_group if val_dataset_group is not None else [], + shuffle=False, + batch_size=1, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + cyclic_val_dataloader = itertools.cycle(val_dataloader) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -296,6 +370,8 @@ def train(args): train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) loss_recorder = train_util.LossRecorder() + val_loss_recorder = train_util.LossRecorder() + for epoch in range(num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -427,12 +503,33 @@ def train(args): avr_loss: float = loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) + + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: break if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} + logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) 
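# A minimal sketch of the weighting that apply_debiased_estimation performs above, assuming
# the usual debiased-estimation scheme w(t) = 1 / sqrt(SNR(t)) with the SNR clamped near t = 0;
# the helper's exact implementation lives in the library and may differ in detail.
# Down-weighting low-noise timesteps keeps them from dominating the averaged validation loss.
import torch

def debiased_weighting(per_sample_loss, timesteps, all_snr, max_snr=1000.0):
    # per_sample_loss: (batch,) MSE already reduced over C/H/W
    # timesteps: (batch,) long tensor; all_snr: (num_train_timesteps,) SNR lookup table
    snr_t = all_snr[timesteps].clamp(max=max_snr)  # SNR explodes as t -> 0, so clamp it
    return per_sample_loss / torch.sqrt(snr_t)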
accelerator.wait_for_everyone() @@ -515,7 +612,30 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) - + parser.add_argument( + "--validation_seed", + type=int, + default=None, + help="Validation seed" + ) + parser.add_argument( + "--validation_split", + type=float, + default=0.0, + help="Split for validation images out of the training dataset" + ) + parser.add_argument( + "--validation_every_n_step", + type=int, + default=None, + help="Number of train steps for counting validation loss. By default, validation per train epoch is performed" + ) + parser.add_argument( + "--max_validation_steps", + type=int, + default=None, + help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset" + ) return parser From 68162172ebf9afa21ad526fc833fcc04f74aeb5f Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 15:03:56 +0800 Subject: [PATCH 29/87] Update train_db.py --- train_db.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_db.py b/train_db.py index 9f8ec777..e98434db 100644 --- a/train_db.py +++ b/train_db.py @@ -209,10 +209,10 @@ def train(args): vae.eval() with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") if val_dataset_group is not None: print("Cache validation latents...") - val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() From 96eb74f0cba3253ba29c8e87d7479c355916cca5 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 15:06:05 +0800 Subject: [PATCH 30/87] Update train_db.py --- train_db.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_db.py b/train_db.py index e98434db..80fdff3e 100644 --- a/train_db.py +++ b/train_db.py @@ -210,8 +210,8 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) if val_dataset_group is not None: - print("Cache validation latents...") - val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + print("Cache validation latents...") + val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") clean_memory_on_device(accelerator.device) From b9bdd101296b8dc3c60b25e31d04d39b57eaee71 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 15:11:26 +0800 Subject: [PATCH 31/87] Update train_network.py --- train_network.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/train_network.py b/train_network.py index 4a5940cd..d7b24dae 100644 --- a/train_network.py +++ b/train_network.py @@ -1034,25 +1034,25 @@ class NetworkTrainer: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if 
len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break From 3d68754defde57b10f96d9c934dd78bf25c39235 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 15:15:42 +0800 Subject: [PATCH 32/87] Update train_db.py --- train_db.py | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/train_db.py b/train_db.py index 80fdff3e..800a157b 100644 --- a/train_db.py +++ b/train_db.py @@ -503,28 +503,26 @@ def train(args): avr_loss: float = loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if len(val_dataloader) > 0: if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = 
self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) - if global_step >= args.max_train_steps: break From a593e837f36b6299101dc85a367c0986501ecc0a Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 15:17:30 +0800 Subject: [PATCH 33/87] Update train_network.py --- train_network.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/train_network.py b/train_network.py index d7b24dae..7d913463 100644 --- a/train_network.py +++ b/train_network.py @@ -1034,26 +1034,26 @@ class NetworkTrainer: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) - + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= 
args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break From f6dbf7c419bbcf2e51c82a6bffa8d30cad2e3512 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 15:18:53 +0800 Subject: [PATCH 34/87] Update train_network.py --- train_network.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/train_network.py b/train_network.py index 7d913463..fa6407ee 100644 --- a/train_network.py +++ b/train_network.py @@ -1034,26 +1034,26 @@ class NetworkTrainer: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = 
val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) - + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break From aa850aa531b0e396b6f2fbd68cd1e6f1319d1d0b Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 17:34:20 +0800 Subject: [PATCH 35/87] Update train_network.py --- train_network.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/train_network.py b/train_network.py index fa6407ee..938e4193 100644 --- a/train_network.py +++ b/train_network.py @@ -1034,25 +1034,25 @@ class NetworkTrainer: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break From cdb2d9c516fbffe0faa9788b8174e5d418fb766b Mon Sep 17 00:00:00 2001 From: gesen2egee 
<79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 17:36:34 +0800 Subject: [PATCH 36/87] Update train_network.py --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 938e4193..e10c17c0 100644 --- a/train_network.py +++ b/train_network.py @@ -192,7 +192,7 @@ class NetworkTrainer: loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし total_loss += loss - + average_loss = total_loss / len(timesteps_list) return average_loss From 3028027e074c891f33d45fff27068b490a408329 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Fri, 4 Oct 2024 16:41:41 +0800 Subject: [PATCH 37/87] Update train_network.py --- train_network.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/train_network.py b/train_network.py index e10c17c0..c0239a6d 100644 --- a/train_network.py +++ b/train_network.py @@ -1034,26 +1034,26 @@ class NetworkTrainer: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) - + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break From 
dece2c388f1c39e7baca201b4bf4e61d9f67a219 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Fri, 4 Oct 2024 16:43:07 +0800 Subject: [PATCH 38/87] Update train_db.py --- train_db.py | 148 ++++++++++++++++++++++++++-------------------------- 1 file changed, 74 insertions(+), 74 deletions(-) diff --git a/train_db.py b/train_db.py index 800a157b..2c17e521 100644 --- a/train_db.py +++ b/train_db.py @@ -46,67 +46,67 @@ logger = logging.getLogger(__name__) # perlin_noise, def process_val_batch(*training_models, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args): - total_loss = 0.0 - timesteps_list = [10, 350, 500, 650, 990] - - with accelerator.accumulate(*training_models): - with torch.no_grad(): - # latentに変換 - if cache_latents: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) - else: - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] + total_loss = 0.0 + timesteps_list = [10, 350, 500, 650, 990] + + with accelerator.accumulate(*training_models): + with torch.no_grad(): + # latentに変換 + if cache_latents: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] - with torch.set_grad_enabled(False), accelerator.autocast(): - if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) + with torch.set_grad_enabled(False), accelerator.autocast(): + if args.weighted_captions: + encoder_hidden_states = get_weighted_text_embeddings( + tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype + ) - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified - for fixed_timesteps in timesteps_list: - with torch.set_grad_enabled(False), accelerator.autocast(): - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise = torch.randn_like(latents, device=latents.device) - b_size = latents.shape[0] - timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + for fixed_timesteps in timesteps_list: + with torch.set_grad_enabled(False), accelerator.autocast(): + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise = torch.randn_like(latents, 
device=latents.device) + b_size = latents.shape[0] + timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - # Predict the noise residual - with accelerator.autocast(): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + # Predict the noise residual + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - if args.masked_loss: - loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - total_loss += loss + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + if args.masked_loss: + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + total_loss += loss - average_loss = total_loss / len(timesteps_list) - return average_loss + average_loss = total_loss / len(timesteps_list) + return average_loss def train(args): train_util.verify_training_args(args) @@ -210,8 +210,8 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) if val_dataset_group is not None: - print("Cache validation latents...") - val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + print("Cache validation latents...") + val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -503,25 +503,25 @@ def train(args): avr_loss: float = loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + 
accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break From 3cb8cb2d4fd697a49135193ac0873204e0139e62 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 9 Dec 2024 15:20:04 -0500 Subject: [PATCH 39/87] Prevent git credentials from leaking into other actions --- .github/workflows/tests.yml | 4 ++++ .github/workflows/typos.yml | 3 +++ 2 files changed, 7 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 672a657b..2eddedc7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,6 +23,10 @@ jobs: steps: - uses: actions/checkout@v4 + with: + # https://woodruffw.github.io/zizmor/audits/#artipacked + persist-credentials: false + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index 87ebdf89..f53cda21 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -18,6 +18,9 @@ jobs: steps: - uses: actions/checkout@v4 + with: + # https://woodruffw.github.io/zizmor/audits/#artipacked + persist-credentials: false - name: typos-action uses: crate-ci/typos@v1.28.1 From 8e378cf03df645cef897a342559dc5fa7f66a35d Mon Sep 17 00:00:00 2001 From: nhamanasu Date: Wed, 11 Dec 2024 19:43:44 +0900 Subject: [PATCH 40/87] add RAdamScheduleFree support --- library/train_util.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index a35388fe..72b5b24d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4887,7 +4887,11 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]: import schedulefree as sf except ImportError: raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") - if optimizer_type == "AdamWScheduleFree".lower(): + + if optimizer_type == "RAdamScheduleFree".lower(): + optimizer_class = sf.RAdamScheduleFree + logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}") + elif optimizer_type == "AdamWScheduleFree".lower(): optimizer_class = sf.AdamWScheduleFree logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}") elif optimizer_type == "SGDScheduleFree".lower(): From e89653975ddf429cdf0c0fd268da0a5a3e8dba1f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 15 Dec 2024 19:39:47 +0900 Subject: [PATCH 41/87] update requirements.txt and README to include 
RAdamScheduleFree optimizer support --- README.md | 6 ++++++ requirements.txt | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 67836ddf..bfb22bcf 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,12 @@ The command to install PyTorch is as follows: ### Recent Updates +Dec 15, 2024: + +- RAdamScheduleFree optimizer is supported. PR [#1830](https://github.com/kohya-ss/sd-scripts/pull/1830) Thanks to nhamanasu! + - Update to `schedulefree==1.4` is required. Please update individually or with `pip install --use-pep517 --upgrade -r requirements.txt`. + - Available with `--optimizer_type=RAdamScheduleFree`. No need to specify warm up steps as well as learning rate scheduler. + Dec 7, 2024: - The option to specify the model name during ControlNet training was different in each script. It has been unified. Please specify `--controlnet_model_name_or_path`. PR [#1821](https://github.com/kohya-ss/sd-scripts/pull/1821) Thanks to sdbds! diff --git a/requirements.txt b/requirements.txt index 0dd1c69c..e0091749 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ pytorch-lightning==1.9.0 bitsandbytes==0.44.0 prodigyopt==1.0 lion-pytorch==0.0.6 -schedulefree==1.2.7 +schedulefree==1.4 tensorboard safetensors==0.4.4 # gradio==3.16.2 From 05bb9183fae18c62a1730fe5060f80c0b99a21f3 Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Fri, 27 Dec 2024 16:47:59 +0800 Subject: [PATCH 42/87] Add Validation loss for LoRA training --- library/config_util.py | 78 +++++++++++++++++++++++- library/train_util.py | 54 ++++++++++++++++- train_network.py | 131 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 257 insertions(+), 6 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 12d0be17..a57cd36f 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -73,6 +73,8 @@ class BaseSubsetParams: token_warmup_min: int = 1 token_warmup_step: float = 0 custom_attributes: Optional[Dict[str, Any]] = None + validation_seed: int = 0 + validation_split: float = 0.0 @dataclass @@ -102,6 +104,8 @@ class BaseDatasetParams: resolution: Optional[Tuple[int, int]] = None network_multiplier: float = 1.0 debug_dataset: bool = False + validation_seed: Optional[int] = None + validation_split: float = 0.0 @dataclass @@ -478,9 +482,27 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params)) datasets.append(dataset) + val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.params.validation_split <= 0.0: + continue + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) + val_datasets.append(dataset) + # print info info = "" for i, dataset in 
enumerate(datasets): @@ -566,6 +588,50 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu logger.info(f"{info}") + if len(val_datasets) > 0: + info = "" + + for i, dataset in enumerate(val_datasets): + info += dedent( + f"""\ + [Validation Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + network_multiplier: {dataset.network_multiplier} + """ + ) + + if dataset.enable_bucket: + info += indent( + dedent( + f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n""" + ), + " ", + ) + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): + info += indent( + dedent( + f"""\ + [Subset {j} of Validation Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + """ + ), + " ", + ) + + logger.info(f"{info}") + # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no @@ -574,7 +640,15 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset.make_buckets() dataset.set_seed(seed) - return DatasetGroup(datasets) + for i, dataset in enumerate(val_datasets): + logger.info(f"[Validation Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) + + return ( + DatasetGroup(datasets), + DatasetGroup(val_datasets) if val_datasets else None + ) def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): diff --git a/library/train_util.py b/library/train_util.py index 72b5b24d..a3fa98e9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -145,6 +145,17 @@ IMAGE_TRANSFORMS = transforms.Compose( TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" +def split_train_val(paths: List[str], validation_split: float, validation_seed: int) -> List[str]: + if validation_seed is not None: + print(f"Using validation seed: {validation_seed}") + prevstate = random.getstate() + random.seed(validation_seed) + random.shuffle(paths) + random.setstate(prevstate) + else: + random.shuffle(paths) + + return paths[len(paths) - round(len(paths) * validation_split):] class ImageInfo: def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: @@ -397,6 +408,8 @@ class BaseSubset: token_warmup_min: int, token_warmup_step: Union[float, int], custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: self.image_dir = image_dir self.alpha_mask = alpha_mask if alpha_mask is not None else False @@ -424,6 +437,9 @@ class BaseSubset: self.img_count = 0 + self.validation_seed = validation_seed + self.validation_split = validation_split + class DreamBoothSubset(BaseSubset): def __init__( @@ -453,6 +469,8 @@ class DreamBoothSubset(BaseSubset): token_warmup_min, token_warmup_step, custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -478,6 +496,8 @@ class DreamBoothSubset(BaseSubset): token_warmup_min, 
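
As a quick illustration of the `split_train_val` helper introduced in this patch: when `validation_seed` is given the shuffle is reproducible, and the last `round(len(paths) * validation_split)` shuffled paths become the validation set. The sketch below copies the helper from the diff above; the example paths and numbers are made up for illustration only.

```python
# Standalone sketch of the split behaviour; split_train_val is copied from the diff
# above, the example paths are invented for illustration.
import random
from typing import List


def split_train_val(paths: List[str], validation_split: float, validation_seed: int) -> List[str]:
    if validation_seed is not None:
        print(f"Using validation seed: {validation_seed}")
        prevstate = random.getstate()
        random.seed(validation_seed)
        random.shuffle(paths)
        random.setstate(prevstate)
    else:
        random.shuffle(paths)

    return paths[len(paths) - round(len(paths) * validation_split):]


paths = [f"img_{i:02d}.png" for i in range(10)]
val_paths = split_train_val(paths, validation_split=0.2, validation_seed=42)
print(len(val_paths))  # 2 -> round(10 * 0.2) images reserved for validation
print(val_paths)       # the same two paths on every run because the seed is fixed
```
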
token_warmup_step, custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.is_reg = is_reg @@ -518,6 +538,8 @@ class FineTuningSubset(BaseSubset): token_warmup_min, token_warmup_step, custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -543,6 +565,8 @@ class FineTuningSubset(BaseSubset): token_warmup_min, token_warmup_step, custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.metadata_file = metadata_file @@ -579,6 +603,8 @@ class ControlNetSubset(BaseSubset): token_warmup_min, token_warmup_step, custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -604,6 +630,8 @@ class ControlNetSubset(BaseSubset): token_warmup_min, token_warmup_step, custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.conditioning_data_dir = conditioning_data_dir @@ -1799,6 +1827,9 @@ class DreamBoothDataset(BaseDataset): bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset: bool, + is_train: bool, + validation_seed: int, + validation_split: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -1808,6 +1839,9 @@ class DreamBoothDataset(BaseDataset): self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight self.latents_cache = None + self.is_train = is_train + self.validation_seed = validation_seed + self.validation_split = validation_split self.enable_bucket = enable_bucket if self.enable_bucket: @@ -1992,6 +2026,9 @@ class DreamBoothDataset(BaseDataset): ) continue + if self.is_train == False: + img_paths = split_train_val(img_paths, self.validation_split, self.validation_seed) + if subset.is_reg: num_reg_images += subset.num_repeats * len(img_paths) else: @@ -2009,7 +2046,11 @@ class DreamBoothDataset(BaseDataset): subset.img_count = len(img_paths) self.subsets.append(subset) - logger.info(f"{num_train_images} train images with repeating.") + if self.is_train: + logger.info(f"{num_train_images} train images with repeating.") + else: + logger.info(f"{num_train_images} validation images with repeating.") + self.num_train_images = num_train_images logger.info(f"{num_reg_images} reg images.") @@ -2050,6 +2091,9 @@ class FineTuningDataset(BaseDataset): bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset: bool, + is_train: bool, + validation_seed: int, + validation_split: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2276,6 +2320,9 @@ class ControlNetDataset(BaseDataset): bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset: float, + is_train: bool, + validation_seed: int, + validation_split: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2324,6 +2371,9 @@ class ControlNetDataset(BaseDataset): bucket_no_upscale, 1.0, debug_dataset, + is_train, + validation_seed, + validation_split, ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) @@ -4887,7 +4937,7 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]: import schedulefree as sf except ImportError: raise ImportError("No schedulefree / 
schedulefreeがインストールされていないようです") - + if optimizer_type == "RAdamScheduleFree".lower(): optimizer_class = sf.RAdamScheduleFree logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}") diff --git a/train_network.py b/train_network.py index 5e82b307..776feaf7 100644 --- a/train_network.py +++ b/train_network.py @@ -9,6 +9,7 @@ import json from multiprocessing import Value from typing import Any, List import toml +import itertools from tqdm import tqdm @@ -114,7 +115,7 @@ class NetworkTrainer: ) if ( args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None - ): + ): logs[f"lr/d*lr/group{i}"] = ( optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] ) @@ -373,10 +374,11 @@ class NetworkTrainer: } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -398,6 +400,11 @@ class NetworkTrainer: train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + if val_dataset_group is not None: + assert ( + val_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + self.assert_extra_args(args, train_dataset_group) # may change some args # acceleratorを準備する @@ -444,6 +451,8 @@ class NetworkTrainer: vae.eval() train_dataset_group.new_cache_latents(vae, accelerator) + if val_dataset_group is not None: + val_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -459,6 +468,8 @@ class NetworkTrainer: if text_encoder_outputs_caching_strategy is not None: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype) + if val_dataset_group is not None: + self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype) # prepare network net_kwargs = {} @@ -567,6 +578,8 @@ class NetworkTrainer: # strategies are set here because they cannot be referenced in another process. 
Copy them with the dataset # some strategies can be None train_dataset_group.set_current_strategies() + if val_dataset_group is not None: + val_dataset_group.set_current_strategies() # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -580,6 +593,17 @@ class NetworkTrainer: persistent_workers=args.persistent_data_loader_workers, ) + val_dataloader = torch.utils.data.DataLoader( + val_dataset_group if val_dataset_group is not None else [], + batch_size=1, + shuffle=False, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + cyclic_val_dataloader = itertools.cycle(val_dataloader) + # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * math.ceil( @@ -592,6 +616,10 @@ class NetworkTrainer: # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) + # Not for sure here. + # if val_dataset_group is not None: + # val_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -1064,7 +1092,11 @@ class NetworkTrainer: ) loss_recorder = train_util.LossRecorder() + # val_loss_recorder = train_util.LossRecorder() + del train_dataset_group + if val_dataset_group is not None: + del val_dataset_group # callback for step start if hasattr(accelerator.unwrap_model(network), "on_step_start"): @@ -1308,6 +1340,77 @@ class NetworkTrainer: ) accelerator.log(logs, step=global_step) + if len(val_dataloader) > 0: + if ((args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps): + accelerator.print("\nValidating バリデーション処理...") + + total_loss = 0.0 + + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションテップ"): + batch = next(cyclic_val_dataloader) + + timesteps_list = [10, 350, 500, 650, 990] + + val_loss = 0.0 + + for fixed_timesteps in timesteps_list: + with torch.set_grad_enabled(False), accelerator.autocast(): + noise = torch.randn_like(latents, device=latents.device) + b_size = latents.shape[0] + + timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu") + timesteps = timesteps.long().to(latents.device) + + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + with accelerator.autocast(): + noise_pred = self.call_unet( + args, + accelerator, + unet, + noisy_latents.requires_grad_(False), + timesteps, + text_encoder_conds, + batch, + weight_dtype, + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased 
estimation etc. + loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + val_loss += loss / len(timesteps_list) + + total_loss += val_loss.detach().item() + + current_val_loss = total_loss / validation_steps + # val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_val_loss) + + if len(accelerator.trackers) > 0: + logs = {"loss/current_val_loss": current_val_loss} + accelerator.log(logs, step=global_step) + + # avr_loss: float = val_loss_recorder.moving_average + # logs = {"loss/average_val_loss": avr_loss} + # accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break @@ -1496,6 +1599,30 @@ def setup_parser() -> argparse.ArgumentParser: help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch." + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする", ) + parser.add_argument( + "--validation_seed", + type=int, + default=None, + help="Validation seed / 検証シード" + ) + parser.add_argument( + "--validation_split", + type=float, + default=0.0, + help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" + ) + parser.add_argument( + "--validation_every_n_step", + type=int, + default=None, + help="Number of train steps for counting validation loss. By default, validation per train epoch is performed / 学習エポックごとに検証を行う場合はNoneを指定する" + ) + parser.add_argument( + "--max_validation_steps", + type=int, + default=None, + help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset / 検証データセット全体を検証する場合はNoneを指定する" + ) # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") From 62164e57925125ed6268983ffa441f1ffecc0e6d Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Fri, 27 Dec 2024 17:28:05 +0800 Subject: [PATCH 43/87] Change val loss calculate method --- train_network.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/train_network.py b/train_network.py index 776feaf7..5fd1b212 100644 --- a/train_network.py +++ b/train_network.py @@ -1383,16 +1383,20 @@ class NetworkTrainer: else: target = noise - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) - loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) - if weighting is not None: - loss = loss * weighting - if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) + # huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + # loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + # if weighting is not None: + # loss = loss * weighting + # if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + # loss = apply_masked_loss(loss, batch) + # loss = loss.mean([1, 2, 3]) # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. 
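
The loss change in this patch boils down to: evaluate plain MSE at a small fixed set of timesteps and average the results, instead of reusing the training-time weighting. Below is a minimal sketch of that idea, not the repository's exact code: `unet_predict` stands in for the trainer's wrapped `call_unet`, `cond` for the text-encoder conditioning, and `noise_scheduler` is assumed to be a diffusers DDPM-style scheduler exposing `add_noise()` / `get_velocity()` as used elsewhere in these patches; the debiased-estimation weighting the patch additionally applies is omitted for brevity.

```python
# Minimal sketch of a fixed-timestep validation loss (assumptions noted above).
import torch


@torch.no_grad()
def fixed_timestep_val_loss(unet_predict, noise_scheduler, latents, cond,
                            timesteps_list=(10, 350, 500, 650, 990),
                            v_parameterization=False):
    total = 0.0
    for t in timesteps_list:
        noise = torch.randn_like(latents)
        timesteps = torch.full((latents.shape[0],), t, dtype=torch.long, device=latents.device)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        pred = unet_predict(noisy_latents, timesteps, cond)
        target = noise_scheduler.get_velocity(latents, noise, timesteps) if v_parameterization else noise

        # per-sample MSE, averaged over channels/height/width, then over the batch
        loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
        total = total + loss.mean([1, 2, 3]).mean()

    return total / len(timesteps_list)
```
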
- loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + # loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし From 64bd5317dc9cb39d69ab7728f36b03157c9b341f Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Sat, 28 Dec 2024 11:42:15 +0800 Subject: [PATCH 44/87] Split val latents/batch and pick up val latents shape size which equal to training batch. --- train_network.py | 45 +++++++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/train_network.py b/train_network.py index 5fd1b212..6bce9e96 100644 --- a/train_network.py +++ b/train_network.py @@ -1349,7 +1349,27 @@ class NetworkTrainer: with torch.no_grad(): validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションテップ"): - batch = next(cyclic_val_dataloader) + + while True: + val_batch = next(cyclic_val_dataloader) + + if "latents" in val_batch and val_batch["latents"] is not None: + val_latents = val_batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): + # latentに変換 + val_latents = self.encode_images_to_latents(args, accelerator, vae, val_batch["images"].to(vae_dtype)) + val_latents = val_latents.to(dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(val_latents)): + accelerator.print("NaN found in validation latents, replacing with zeros") + val_latents = torch.nan_to_num(val_latents, 0, out=val_latents) + + val_latents = self.shift_scale_latents(args, val_latents) + + if val_latents.shape == latents.shape: + break timesteps_list = [10, 350, 500, 650, 990] @@ -1357,13 +1377,13 @@ class NetworkTrainer: for fixed_timesteps in timesteps_list: with torch.set_grad_enabled(False), accelerator.autocast(): - noise = torch.randn_like(latents, device=latents.device) - b_size = latents.shape[0] + noise = torch.randn_like(val_latents, device=val_latents.device) + b_size = val_latents.shape[0] timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu") - timesteps = timesteps.long().to(latents.device) + timesteps = timesteps.long().to(val_latents.device) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(val_latents, noise, timesteps) with accelerator.autocast(): noise_pred = self.call_unet( @@ -1373,27 +1393,16 @@ class NetworkTrainer: noisy_latents.requires_grad_(False), timesteps, text_encoder_conds, - batch, + val_batch, weight_dtype, ) if args.v_parameterization: # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) + target = noise_scheduler.get_velocity(val_latents, noise, timesteps) else: target = noise - # huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) - # loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) - # if weighting is not None: - # loss = loss * weighting - # if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - # loss = apply_masked_loss(loss, batch) - # loss = loss.mean([1, 2, 3]) - - # min snr gamma, scale v pred loss 
like noise pred, v pred like loss, debiased estimation etc. - # loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) From cb89e0284e1a25b41401861107159e6b943ee387 Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Sat, 28 Dec 2024 11:57:04 +0800 Subject: [PATCH 45/87] Change val latent loss compare --- train_network.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/train_network.py b/train_network.py index 6bce9e96..7276d5dc 100644 --- a/train_network.py +++ b/train_network.py @@ -1350,6 +1350,8 @@ class NetworkTrainer: validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションテップ"): + val_latents = None + while True: val_batch = next(cyclic_val_dataloader) @@ -1371,19 +1373,22 @@ class NetworkTrainer: if val_latents.shape == latents.shape: break + if val_latents is not None: + del val_latents + timesteps_list = [10, 350, 500, 650, 990] val_loss = 0.0 for fixed_timesteps in timesteps_list: with torch.set_grad_enabled(False), accelerator.autocast(): - noise = torch.randn_like(val_latents, device=val_latents.device) - b_size = val_latents.shape[0] + noise = torch.randn_like(latents, device=latents.device) + b_size = latents.shape[0] timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu") - timesteps = timesteps.long().to(val_latents.device) + timesteps = timesteps.long().to(latents.device) - noisy_latents = noise_scheduler.add_noise(val_latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) with accelerator.autocast(): noise_pred = self.call_unet( @@ -1399,7 +1404,7 @@ class NetworkTrainer: if args.v_parameterization: # v-parameterization training - target = noise_scheduler.get_velocity(val_latents, noise, timesteps) + target = noise_scheduler.get_velocity(latents, noise, timesteps) else: target = noise From 874353296304c753b452511a412472f8a3e4ba09 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 04:37:16 +0800 Subject: [PATCH 46/87] val --- library/config_util.py | 32 +++++++------ library/train_util.py | 20 ++++++-- train_network.py | 102 ++++++++++++++++++++++++++++------------- 3 files changed, 102 insertions(+), 52 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 1bf7ed95..cb2c5b68 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -81,23 +81,24 @@ class ControlNetSubsetParams(BaseSubsetParams): @dataclass class BaseDatasetParams: - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None - max_token_length: int = None - resolution: Optional[Tuple[int, int]] = None - debug_dataset: bool = False - validation_seed: Optional[int] = None - validation_split: float = 0.0 + tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None + max_token_length: int = None + resolution: Optional[Tuple[int, int]] = None + network_multiplier: float = 1.0 + debug_dataset: bool = False + validation_seed: Optional[int] = None + validation_split: float = 0.0 @dataclass class DreamBoothDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int 
= 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False - prior_loss_weight: float = 1.0 - + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + prior_loss_weight: float = 1.0 + @dataclass class FineTuningDatasetParams(BaseDatasetParams): batch_size: int = 1 @@ -203,8 +204,9 @@ class ConfigSanitizer: "max_bucket_reso": int, "min_bucket_reso": int, "validation_seed": int, - "validation_split": float, + "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), + "network_multiplier": float, } # options handled by argparse but not handled by user config diff --git a/library/train_util.py b/library/train_util.py index 1979207b..2364d62b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -122,6 +122,20 @@ IMAGE_TRANSFORMS = transforms.Compose( TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" +def split_train_val(paths, is_train, validation_split, validation_seed): + if validation_seed is not None: + print(f"Using validation seed: {validation_seed}") + prevstate = random.getstate() + random.seed(validation_seed) + random.shuffle(paths) + random.setstate(prevstate) + else: + random.shuffle(paths) + + if is_train: + return paths[0:math.ceil(len(paths) * (1 - validation_split))] + else: + return paths[len(paths) - round(len(paths) * validation_split):] def split_train_val(paths, is_train, validation_split, validation_seed): if validation_seed is not None: @@ -1352,7 +1366,6 @@ class DreamBoothDataset(BaseDataset): self.is_train = is_train self.validation_split = validation_split self.validation_seed = validation_seed - self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight @@ -1405,10 +1418,9 @@ class DreamBoothDataset(BaseDataset): return [], [] img_paths = glob_images(subset.image_dir, "*") - if self.validation_split > 0.0: - img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) - print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") + img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] diff --git a/train_network.py b/train_network.py index edd3ff94..48885503 100644 --- a/train_network.py +++ b/train_network.py @@ -130,7 +130,9 @@ class NetworkTrainer: def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) - def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True): + def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True, timesteps_list=None): + total_loss = 0.0 + with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) @@ -167,37 +169,40 @@ class NetworkTrainer: args, noise_scheduler, latents ) - # Predict the noise residual - with torch.set_grad_enabled(is_train), accelerator.autocast(): - noise_pred = self.call_unet( - 
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype - ) + # Use input timesteps_list or use described timesteps above + timesteps_list = timesteps_list or [timesteps] + for timesteps in timesteps_list: + # Predict the noise residual + with torch.set_grad_enabled(is_train), accelerator.autocast(): + noise_pred = self.call_unet( + args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + ) - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) - loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight - loss = loss * loss_weights + loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight + loss = loss * loss_weights - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - - return loss + total_loss += loss.mean() # 平均なのでbatch_sizeで割る必要なし + average_loss = total_loss / len(timesteps_list) + return average_loss def train(self, args): session_id = random.randint(0, 2**32) @@ -283,10 +288,10 @@ class NetworkTrainer: train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" if val_dataset_group is not None: - assert ( - val_dataset_group.is_latent_cacheable() - ), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - + assert ( + val_dataset_group.is_latent_cacheable() + ), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + self.assert_extra_args(args, train_dataset_group) # acceleratorを準備する @@ -430,6 +435,15 @@ class NetworkTrainer: num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) + + val_dataloader = torch.utils.data.DataLoader( + val_dataset_group if val_dataset_group is not None else [], + shuffle=False, + batch_size=1, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) val_dataloader = torch.utils.data.DataLoader( val_dataset_group if 
val_dataset_group is not None else [], @@ -798,7 +812,6 @@ class NetworkTrainer: loss_recorder = train_util.LossRecorder() val_loss_recorder = train_util.LossRecorder() - del train_dataset_group # callback for step start @@ -848,7 +861,6 @@ class NetworkTrainer: on_step_start(text_encoder, unet) is_train = True loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=train_text_encoder) - accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: params_to_clip = network.get_trainable_params() @@ -900,7 +912,25 @@ class NetworkTrainer: if args.logging_dir is not None: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) + + if global_step % 25 == 0: + if len(val_dataloader) > 0: + print("Validating バリデーション処理...") + with torch.no_grad(): + val_dataloader_iter = iter(val_dataloader) + batch = next(val_dataloader_iter) + is_train = False + loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, timesteps_list=[10, 350, 500, 650, 990]) + + current_loss = loss.detach().item() + val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/validation_current": current_loss} + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break @@ -912,7 +942,7 @@ class NetworkTrainer: with torch.no_grad(): for val_step, batch in enumerate(val_dataloader): is_train = False - loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, timesteps_list=[10, 350, 500, 650, 990]) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) @@ -933,6 +963,12 @@ class NetworkTrainer: logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) + if len(val_dataloader) > 0: + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/validation_epoch_average": avr_loss} + accelerator.log(logs, step=epoch + 1) + accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 From 449c1c5c502375713e609ad9e00e747b4013063a Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 2 Jan 2025 15:59:20 -0500 Subject: [PATCH 47/87] Adding modified train_util and config_util --- library/config_util.py | 1 - library/train_util.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index cb2c5b68..727e1a40 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -84,7 +84,6 @@ class BaseDatasetParams: tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None max_token_length: int = None resolution: Optional[Tuple[int, int]] = None - network_multiplier: float = 1.0 debug_dataset: bool = False validation_seed: Optional[int] = None validation_split: float = 0.0 diff --git a/library/train_util.py b/library/train_util.py index 2364d62b..39433739 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1420,7 +1420,7 @@ class DreamBoothDataset(BaseDataset): img_paths 
= glob_images(subset.image_dir, "*") if self.validation_split > 0.0: img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) - logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") + print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] From 7470173044ca5b700bc4723709bd9c012e2216f3 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 01:13:57 -0500 Subject: [PATCH 48/87] Remove defunct code for train_controlnet.py --- train_controlnet.py | 569 -------------------------------------------- 1 file changed, 569 deletions(-) diff --git a/train_controlnet.py b/train_controlnet.py index 09a911a0..365e35c8 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -6,577 +6,8 @@ import logging logger = logging.getLogger(__name__) -<<<<<<< HEAD -# TODO 他のスクリプトと共通化する -def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): - logs = { - "loss/current": current_loss, - "loss/average": avr_loss, - "lr": lr_scheduler.get_last_lr()[0], - } - - if args.optimizer_type.lower().startswith("DAdapt".lower()): - logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - - return logs - - -def train(args): - # session_id = random.randint(0, 2**32) - # training_started_at = time.time() - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) - setup_logging(args, reset=True) - - cache_latents = args.cache_latents - use_user_config = args.dataset_config is not None - - if args.seed is None: - args.seed = random.randint(0, 2**32) - set_seed(args.seed) - - tokenizer = train_util.load_tokenizer(args) - - # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) - if use_user_config: - logger.info(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "conditioning_data_dir"] - if any(getattr(args, attr) is not None for attr in ignored): - logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - user_config = { - "datasets": [ - { - "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( - args.train_data_dir, - args.conditioning_data_dir, - args.caption_extension, - ) - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return - if len(train_dataset_group) == 0: - logger.error( - "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" - ) - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # acceleratorを準備する - logger.info("prepare accelerator") - accelerator = train_util.prepare_accelerator(args) - is_main_process = accelerator.is_main_process - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - - # モデルを読み込む - text_encoder, vae, unet, _ = train_util.load_target_model( - args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True - ) - - # DiffusersのControlNetが使用するデータを準備する - if args.v2: - unet.config = { - "act_fn": "silu", - "attention_head_dim": [5, 10, 20, 20], - "block_out_channels": [320, 640, 1280, 1280], - "center_input_sample": False, - "cross_attention_dim": 1024, - "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], - "downsample_padding": 1, - "dual_cross_attention": False, - "flip_sin_to_cos": True, - "freq_shift": 0, - "in_channels": 4, - "layers_per_block": 2, - "mid_block_scale_factor": 1, - "norm_eps": 1e-05, - "norm_num_groups": 32, - "num_class_embeds": None, - "only_cross_attention": False, - "out_channels": 4, - "sample_size": 96, - "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], - "use_linear_projection": True, - "upcast_attention": True, - "only_cross_attention": False, - "downsample_padding": 1, - "use_linear_projection": True, - "class_embed_type": None, - "num_class_embeds": None, - "resnet_time_scale_shift": "default", - "projection_class_embeddings_input_dim": None, - } - else: - unet.config = { - "act_fn": "silu", - "attention_head_dim": 8, - "block_out_channels": [320, 640, 1280, 1280], - "center_input_sample": False, - "cross_attention_dim": 768, - "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], - "downsample_padding": 1, - "flip_sin_to_cos": True, - "freq_shift": 0, - "in_channels": 4, - "layers_per_block": 2, - "mid_block_scale_factor": 1, - "norm_eps": 1e-05, - "norm_num_groups": 32, - "out_channels": 4, - "sample_size": 64, - "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], - "only_cross_attention": False, - "downsample_padding": 1, - "use_linear_projection": False, - "class_embed_type": None, - "num_class_embeds": None, - "upcast_attention": False, - "resnet_time_scale_shift": "default", - "projection_class_embeddings_input_dim": None, - } - unet.config = SimpleNamespace(**unet.config) - - controlnet = ControlNetModel.from_unet(unet) - - if args.controlnet_model_name_or_path: - filename = args.controlnet_model_name_or_path - if os.path.isfile(filename): - if os.path.splitext(filename)[1] == ".safetensors": - state_dict = load_file(filename) - else: - state_dict = torch.load(filename) - state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict) - controlnet.load_state_dict(state_dict) - elif os.path.isdir(filename): - controlnet = ControlNetModel.from_pretrained(filename) - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - - # 学習を準備する - if cache_latents: - 
vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) - vae.to("cpu") - clean_memory_on_device(accelerator.device) - - accelerator.wait_for_everyone() - - if args.gradient_checkpointing: - controlnet.enable_gradient_checkpointing() - - # 学習に必要なクラスを準備する - accelerator.print("prepare optimizer, data loader etc.") - - trainable_params = controlnet.parameters() - - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps - ) - accelerator.print( - f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" - ) - - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - assert ( - args.mixed_precision == "fp16" - ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - accelerator.print("enable full fp16 training.") - controlnet.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - controlnet, optimizer, train_dataloader, lr_scheduler - ) - - unet.requires_grad_(False) - text_encoder.requires_grad_(False) - unet.to(accelerator.device) - text_encoder.to(accelerator.device) - - # transform DDP after prepare - controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet - - controlnet.train() - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - # TODO: find a way to handle total batch size when there are multiple datasets - accelerator.print("running training / 学習開始") - accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print( - f" batch size per device / 
バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" - ) - # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm( - range(args.max_train_steps), - smoothing=0, - disable=not accelerator.is_local_main_process, - desc="steps", - ) - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - clip_sample=False, - ) - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": args.wandb_run_name} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs - ) - - loss_recorder = train_util.LossRecorder() - del train_dataset_group - - # function for saving/removing - def save_model(ckpt_name, model, force_sync_upload=False): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - accelerator.print(f"\nsaving checkpoint: {ckpt_file}") - - state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) - - if save_dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v - - if os.path.splitext(ckpt_file)[1] == ".safetensors": - from safetensors.torch import save_file - - save_file(state_dict, ckpt_file) - else: - torch.save(state_dict, ckpt_file) - - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) - - def remove_model(old_ckpt_name): - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - accelerator.print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - # For --sample_at_first - train_util.sample_images( - accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet - ) - - # training loop - for epoch in range(num_train_epochs): - if is_main_process: - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - with accelerator.accumulate(controlnet): - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) - elif args.multires_noise_iterations: - noise = pyramid_noise_like( - noise, - latents.device, - args.multires_noise_iterations, - 
args.multires_noise_discount, - ) - - # Sample a random timestep for each image - timesteps = train_util.get_timesteps(args, 0, noise_scheduler.config.num_train_timesteps, b_size) - huber_c = train_util.get_huber_c(args, noise_scheduler, timesteps.item(), latents.device) - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) - - with accelerator.autocast(): - down_block_res_samples, mid_block_res_sample = controlnet( - noisy_latents, - timesteps, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=controlnet_image, - return_dict=False, - ) - - # Predict the noise residual - noise_pred = unet( - noisy_latents, - timesteps, - encoder_hidden_states, - down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples], - mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), - ).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = controlnet.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - train_util.sample_images( - accelerator, - args, - None, - global_step, - accelerator.device, - vae, - tokenizer, - text_encoder, - unet, - controlnet=controlnet, - ) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model( - ckpt_name, - accelerator.unwrap_model(controlnet), - ) - - if args.save_state: - train_util.save_and_remove_state_stepwise(args, accelerator, global_step) - - remove_step_no = train_util.get_remove_step_no(args, global_step) - if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name(args, "." 
+ args.save_model_as, remove_step_no) - remove_model(remove_ckpt_name) - - current_loss = loss.detach().item() - loss_recorder.add(epoch=epoch, step=step, loss=current_loss) - avr_loss: float = loss_recorder.moving_average - logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if args.logging_dir is not None: - logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) - accelerator.log(logs, step=global_step) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) - - accelerator.wait_for_everyone() - - # 指定エポックごとにモデルを保存 - if args.save_every_n_epochs is not None: - saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs - if is_main_process and saving: - ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name, accelerator.unwrap_model(controlnet)) - - remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) - if remove_epoch_no is not None: - remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) - remove_model(remove_ckpt_name) - - if args.save_state: - train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - - train_util.sample_images( - accelerator, - args, - epoch + 1, - global_step, - accelerator.device, - vae, - tokenizer, - text_encoder, - unet, - controlnet=controlnet, - ) - - # end of epoch - if is_main_process: - controlnet = accelerator.unwrap_model(controlnet) - - accelerator.end_training() - - if is_main_process and (args.save_state or args.save_state_on_train_end): - train_util.save_state_on_train_end(args, accelerator) - - # del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく - - if is_main_process: - ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) - save_model(ckpt_name, controlnet, force_sync_upload=True) - - logger.info("model saved.") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - add_logging_arguments(parser) - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, False, True, True) - train_util.add_training_arguments(parser, False) - deepspeed_utils.add_deepspeed_arguments(parser) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) - - parser.add_argument( - "--save_model_as", - type=str, - default="safetensors", - choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", - ) - parser.add_argument( - "--controlnet_model_name_or_path", - type=str, - default=None, - help="controlnet model name or path / controlnetのモデル名またはパス", - ) - parser.add_argument( - "--conditioning_data_dir", - type=str, - default=None, - help="conditioning data directory / 条件付けデータのディレクトリ", - ) - - return parser - -======= from library import train_util from train_control_net import setup_parser, train ->>>>>>> hina/feature/val-loss if __name__ == "__main__": logger.warning( From 534059dea517d44de387e7d467d64209f9dcfba2 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 01:18:15 -0500 Subject: [PATCH 49/87] Typos and lingering is_train --- library/config_util.py | 2 +- library/train_util.py | 4 ---- train_network.py | 6 +++--- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index a09d2c7c..418c179d 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -535,7 +535,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu shuffle_caption: {subset.shuffle_caption} keep_tokens: {subset.keep_tokens} caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs} caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} caption_prefix: {subset.caption_prefix} caption_suffix: {subset.caption_suffix} diff --git a/library/train_util.py b/library/train_util.py index bf1b6731..220d4702 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2092,7 +2092,6 @@ class FineTuningDataset(BaseDataset): bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset: bool, - is_train: bool, validation_seed: int, validation_split: float, ) -> None: @@ -2312,7 +2311,6 @@ class ControlNetDataset(BaseDataset): def __init__( self, subsets: Sequence[ControlNetSubset], - is_train: bool, batch_size: int, resolution, network_multiplier: float, @@ -2362,7 +2360,6 @@ class ControlNetDataset(BaseDataset): self.dreambooth_dataset_delegate = DreamBoothDataset( db_subsets, - is_train, batch_size, resolution, network_multiplier, @@ -2382,7 +2379,6 @@ class ControlNetDataset(BaseDataset): self.batch_size = batch_size self.num_train_images = self.dreambooth_dataset_delegate.num_train_images self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images - self.is_train = is_train self.validation_split = validation_split self.validation_seed = validation_seed diff --git a/train_network.py b/train_network.py index 99b9717a..4bcfc0ac 100644 --- a/train_network.py +++ b/train_network.py @@ -380,11 +380,11 @@ class NetworkTrainer: else: return 
typing.cast(torch.IntTensor, torch.tensor(timesteps_list).unsqueeze(1).repeat(1, batch_size).to(latents.device)) - choosen_timesteps_list = pick_timesteps_list() + chosen_timesteps_list = pick_timesteps_list() total_loss = torch.zeros((batch_size, 1)).to(latents.device) # Use input timesteps_list or use described timesteps above - for fixed_timestep in choosen_timesteps_list: + for fixed_timestep in chosen_timesteps_list: fixed_timestep = typing.cast(torch.IntTensor, fixed_timestep) # Predict the noise residual @@ -447,7 +447,7 @@ class NetworkTrainer: total_loss += loss - return total_loss / len(choosen_timesteps_list) + return total_loss / len(chosen_timesteps_list) def train(self, args): session_id = random.randint(0, 2**32) From c8c3569df292109fe3be4d209c9f6131afe2ba5f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 01:26:45 -0500 Subject: [PATCH 50/87] Cleanup order, types, print to logger --- library/config_util.py | 7 +++---- library/train_util.py | 6 +++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 418c179d..5a4d3aa2 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -485,7 +485,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) datasets.append(dataset) - val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: if dataset_blueprint.params.validation_split <= 0.0: continue @@ -503,7 +503,6 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) - # print info def print_info(_datasets): info = "" for i, dataset in enumerate(_datasets): @@ -565,7 +564,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu print_info(datasets) if len(val_datasets) > 0: - print("Validation dataset") + logger.info("Validation dataset") print_info(val_datasets) if len(val_datasets) > 0: @@ -610,7 +609,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu " ", ) - logger.info(f"{info}") + logger.info(info) # make buckets first because it determines the length of dataset # and set the same seed for all datasets diff --git a/library/train_util.py b/library/train_util.py index 220d4702..782f57e8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1833,9 +1833,9 @@ class DreamBoothDataset(BaseDataset): bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, + debug_dataset: bool, validation_split: float, validation_seed: Optional[int], - debug_dataset, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2319,9 +2319,9 @@ class ControlNetDataset(BaseDataset): max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, + debug_dataset: bool, validation_split: float, validation_seed: Optional[int], - debug_dataset: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2369,9 +2369,9 @@ class ControlNetDataset(BaseDataset): bucket_reso_steps, bucket_no_upscale, 1.0, + debug_dataset, validation_split, validation_seed, - debug_dataset ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) From fbfc2753eb7fa57724eb525ee65d851b5e80b8ea Mon 
Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 01:53:12 -0500 Subject: [PATCH 51/87] Update text for train/reg with repeats --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 782f57e8..77a6a9f9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2050,11 +2050,11 @@ class DreamBoothDataset(BaseDataset): subset.img_count = len(img_paths) self.subsets.append(subset) - logger.info(f"{num_train_images} images with repeating.") + logger.info(f"{num_train_images} train images with repeats.") self.num_train_images = num_train_images - logger.info(f"{num_reg_images} reg images.") + logger.info(f"{num_reg_images} reg images with repeats.") if num_train_images < num_reg_images: logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") From 58bfa36d0275d864d5a2d64c51632e808f789ddd Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 02:00:28 -0500 Subject: [PATCH 52/87] Add seed help clarifying info --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 4bcfc0ac..7d064d21 100644 --- a/train_network.py +++ b/train_network.py @@ -1639,7 +1639,7 @@ def setup_parser() -> argparse.ArgumentParser: "--validation_seed", type=int, default=None, - help="Validation seed / 検証シード" + help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証シード" ) parser.add_argument( "--validation_split", From 6604b36044a83f3531faed508096f3e6bfe48fc9 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 02:04:59 -0500 Subject: [PATCH 53/87] Remove duplicate assignment --- library/train_util.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 77a6a9f9..3710c865 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -86,8 +86,6 @@ import library.sai_model_spec as sai_model_spec import library.deepspeed_utils as deepspeed_utils from library.utils import setup_logging, pil_resize - - setup_logging() import logging @@ -1841,8 +1839,6 @@ class DreamBoothDataset(BaseDataset): assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" - self.validation_split = validation_split - self.validation_seed = validation_seed self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight From 0522070d197d92745dbdb408d74c9c3f869bff76 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 15:20:25 -0500 Subject: [PATCH 54/87] Fix training, validation split, revert to using upstream implemenation --- library/config_util.py | 67 +++----------- library/custom_train_functions.py | 6 +- library/strategy_sd.py | 2 +- library/train_util.py | 143 +++++++++++++++++------------- train_network.py | 94 ++++++++++++-------- 5 files changed, 152 insertions(+), 160 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 5a4d3aa2..63d28c96 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -482,7 +482,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_training_dataset=True, 
**asdict(dataset_blueprint.params)) datasets.append(dataset) val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] @@ -500,16 +500,16 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_training_dataset=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) - def print_info(_datasets): + def print_info(_datasets, dataset_type: str): info = "" for i, dataset in enumerate(_datasets): is_dreambooth = isinstance(dataset, DreamBoothDataset) is_controlnet = isinstance(dataset, ControlNetDataset) info += dedent(f"""\ - [Dataset {i}] + [{dataset_type} {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} enable_bucket: {dataset.enable_bucket} @@ -527,7 +527,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu for j, subset in enumerate(dataset.subsets): info += indent(dedent(f"""\ - [Subset {j} of Dataset {i}] + [Subset {j} of {dataset_type} {i}] image_dir: "{subset.image_dir}" image_count: {subset.img_count} num_repeats: {subset.num_repeats} @@ -544,8 +544,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu random_crop: {subset.random_crop} token_warmup_min: {subset.token_warmup_min}, token_warmup_step: {subset.token_warmup_step}, - alpha_mask: {subset.alpha_mask} - custom_attributes: {subset.custom_attributes} + alpha_mask: {subset.alpha_mask} + custom_attributes: {subset.custom_attributes} """), " ") if is_dreambooth: @@ -561,67 +561,22 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu logger.info(info) - print_info(datasets) + print_info(datasets, "Dataset") if len(val_datasets) > 0: - logger.info("Validation dataset") - print_info(val_datasets) - - if len(val_datasets) > 0: - info = "" - - for i, dataset in enumerate(val_datasets): - info += dedent( - f"""\ - [Validation Dataset {i}] - batch_size: {dataset.batch_size} - resolution: {(dataset.width, dataset.height)} - enable_bucket: {dataset.enable_bucket} - network_multiplier: {dataset.network_multiplier} - """ - ) - - if dataset.enable_bucket: - info += indent( - dedent( - f"""\ - min_bucket_reso: {dataset.min_bucket_reso} - max_bucket_reso: {dataset.max_bucket_reso} - bucket_reso_steps: {dataset.bucket_reso_steps} - bucket_no_upscale: {dataset.bucket_no_upscale} - \n""" - ), - " ", - ) - else: - info += "\n" - - for j, subset in enumerate(dataset.subsets): - info += indent( - dedent( - f"""\ - [Subset {j} of Validation Dataset {i}] - image_dir: "{subset.image_dir}" - image_count: {subset.img_count} - num_repeats: {subset.num_repeats} - """ - ), - " ", - ) - - logger.info(info) + print_info(val_datasets, "Validation Dataset") # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): - logger.info(f"[Dataset {i}]") + logger.info(f"[Prepare dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) for i, dataset in enumerate(val_datasets): - logger.info(f"[Validation Dataset {i}]") + logger.info(f"[Prepare validation dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) diff --git 
a/library/custom_train_functions.py b/library/custom_train_functions.py index 9a7c21a3..ad3e69ff 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -455,7 +455,7 @@ def get_weighted_text_embeddings( # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2 -def pyramid_noise_like(noise, device, iterations=6, discount=0.4): +def pyramid_noise_like(noise, device, iterations=6, discount=0.4) -> torch.FloatTensor: b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant! u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device) for i in range(iterations): @@ -468,7 +468,7 @@ def pyramid_noise_like(noise, device, iterations=6, discount=0.4): # https://www.crosslabs.org//blog/diffusion-with-offset-noise -def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): +def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> torch.FloatTensor: if noise_offset is None: return noise if adaptive_noise_scale is not None: @@ -484,7 +484,7 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): return noise -def apply_masked_loss(loss, batch): +def apply_masked_loss(loss, batch) -> torch.FloatTensor: if "conditioning_images" in batch: # conditioning image is -1 to 1. we need to convert it to 0 to 1 mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel diff --git a/library/strategy_sd.py b/library/strategy_sd.py index d0a3a68b..a44fc409 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -40,7 +40,7 @@ class SdTokenizeStrategy(TokenizeStrategy): text = [text] if isinstance(text, str) else text return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] - def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: text = [text] if isinstance(text, str) else text tokens_list = [] weights_list = [] diff --git a/library/train_util.py b/library/train_util.py index 3710c865..0f16a4f3 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -146,7 +146,15 @@ IMAGE_TRANSFORMS = transforms.Compose( TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" -def split_train_val(paths: List[str], is_train: bool, validation_split: float, validation_seed: int) -> List[str]: +def split_train_val(paths: List[str], is_training_dataset: bool, validation_split: float, validation_seed: int) -> List[str]: + """ + Split the dataset into train and validation + + Shuffle the dataset based on the validation_seed or the current random seed. + For example if the split of 0.2 of 100 images. 
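+    (computed with Python half-open slices, as in the code below:
+    paths[0:math.ceil(N * (1 - validation_split))] for training and
+    paths[N - round(N * validation_split):] for validation)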
+ [0:79] = 80 training images + [80:] = 20 validation images + """ if validation_seed is not None: print(f"Using validation seed: {validation_seed}") prevstate = random.getstate() @@ -156,9 +164,12 @@ def split_train_val(paths: List[str], is_train: bool, validation_split: float, v else: random.shuffle(paths) - if is_train: + # Split the dataset between training and validation + if is_training_dataset: + # Training dataset we split to the first part return paths[0:math.ceil(len(paths) * (1 - validation_split))] else: + # Validation dataset we split to the second part return paths[len(paths) - round(len(paths) * validation_split):] @@ -1822,6 +1833,7 @@ class DreamBoothDataset(BaseDataset): def __init__( self, subsets: Sequence[DreamBoothSubset], + is_training_dataset: bool, batch_size: int, resolution, network_multiplier: float, @@ -1843,6 +1855,7 @@ class DreamBoothDataset(BaseDataset): self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight self.latents_cache = None + self.is_training_dataset = is_training_dataset self.validation_seed = validation_seed self.validation_split = validation_split @@ -1952,6 +1965,9 @@ class DreamBoothDataset(BaseDataset): size_set_count += 1 logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + if self.validation_split > 0.0: + img_paths = split_train_val(img_paths, self.is_training_dataset, self.validation_split, self.validation_seed) + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") if use_cached_info_for_subset: @@ -2046,7 +2062,8 @@ class DreamBoothDataset(BaseDataset): subset.img_count = len(img_paths) self.subsets.append(subset) - logger.info(f"{num_train_images} train images with repeats.") + images_split_name = "train" if self.is_training_dataset else "validation" + logger.info(f"{num_train_images} {images_split_name} images with repeats.") self.num_train_images = num_train_images @@ -2411,8 +2428,12 @@ class ControlNetDataset(BaseDataset): conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair]) - assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" - assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" + assert ( + len(missing_imgs) == 0 + ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}" + assert ( + len(extra_imgs) == 0 + ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}" self.conditioning_image_transforms = IMAGE_TRANSFORMS @@ -4586,7 +4607,6 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar config_args = argparse.Namespace(**ignore_nesting_dict) args = parser.parse_args(namespace=config_args) args.config_file = os.path.splitext(args.config_file)[0] - logger.info(args.config_file) return args @@ -5880,55 +5900,35 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_random_timesteps(args, min_timestep: int, max_timestep: int, batch_size: int, device: torch.device) -> torch.IntTensor: - """ - Get a random timestep between the min and max timesteps - Can error (NotImplementedError) if the loss type is not supported - """ - # TODO: if a huber loss is selected, it will use constant timesteps for each batch - # 
as. In the future there may be a smarter way - if args.loss_type == "huber" or args.loss_type == "smooth_l1": - timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu") - timesteps = timesteps.repeat(batch_size).to(device) - elif args.loss_type == "l2": - timesteps = torch.randint(min_timestep, max_timestep, (batch_size,), device=device) - else: - raise NotImplementedError(f"Unknown loss type {args.loss_type}") - - return typing.cast(torch.IntTensor, timesteps) +def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.IntTensor: + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) + return timesteps -def get_huber_c(args, noise_scheduler: DDPMScheduler, timesteps: torch.IntTensor) -> Optional[float]: - """ - Calculate the Huber convolution (huber_c) value - Huber loss is a loss function used in robust regression, that is less sensitive - to outliers in data than the squared error loss. - https://en.wikipedia.org/wiki/Huber_loss - """ - if args.loss_type == "huber" or args.loss_type == "smooth_l1": - if args.huber_schedule == "exponential": - alpha = -math.log(args.huber_c) / noise_scheduler.config.get('num_train_timesteps', 1000) - huber_c = math.exp(-alpha * timesteps.item()) - elif args.huber_schedule == "snr": - if not hasattr(noise_scheduler, "alphas_cumprod"): - raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") - alphas_cumprod = noise_scheduler.alphas_cumprod.index_select(0, timesteps) - sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 - huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c - elif args.huber_schedule == "constant": - huber_c = args.huber_c - else: - raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - - elif args.loss_type == "l2": +def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: + if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): return None + + b_size = timesteps.shape[0] + if args.huber_schedule == "exponential": + alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps + result = torch.exp(-alpha * timesteps) * args.huber_scale + elif args.huber_schedule == "snr": + if not hasattr(noise_scheduler, "alphas_cumprod"): + raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) + sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 + result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c + result = result.to(timesteps.device) + elif args.huber_schedule == "constant": + result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device) else: - raise NotImplementedError(f"Unknown loss type {args.loss_type}") + raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - return huber_c + return result -def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor): +def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor) -> torch.FloatTensor: """ Apply noise modifications like noise offset and multires noise """ @@ -5964,27 +5964,44 @@ def make_random_timesteps(args, noise_scheduler: DDPMScheduler, batch_size: int, max_timestep = noise_scheduler.config.get('num_train_timesteps', 1000) if args.max_timestep is None else args.max_timestep # Sample a random timestep for each image - 
timesteps = get_random_timesteps(args, min_timestep, max_timestep, batch_size, device) + timesteps = get_timesteps(min_timestep, max_timestep, batch_size, device) return timesteps -def get_noise_noisy_latents_and_timesteps(args, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor, Optional[float]]: - """ - Unified noise, noisy_latents, timesteps and huber loss convolution calculations - """ - batch_size = latents.shape[0] +def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + if args.noise_offset_random_strength: + noise_offset = torch.rand(1, device=latents.device) * args.noise_offset + else: + noise_offset = args.noise_offset + noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale) + if args.multires_noise_iterations: + noise = custom_train_functions.pyramid_noise_like( + noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount + ) + + # Sample a random timestep for each image + b_size = latents.shape[0] min_timestep = 0 if args.min_timestep is None else args.min_timestep - max_timestep = noise_scheduler.config.get("num_train_timesteps", 1000) if args.max_timestep is None else args.max_timestep + max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep - # A random timestep for each image in the batch - timesteps = get_random_timesteps(args, min_timestep, max_timestep, batch_size, latents.device) - huber_c = get_huber_c(args, noise_scheduler, timesteps) + timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device) - noise = make_noise(args, latents) - noisy_latents = get_noisy_latents(args, noise, noise_scheduler, latents, timesteps) + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + if args.ip_noise_gamma_random_strength: + strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma + else: + strength = args.ip_noise_gamma + noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps) + else: + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - return noise, noisy_latents, timesteps, huber_c + return noise, noisy_latents, timesteps def get_noisy_latents(args, noise: torch.FloatTensor, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor: @@ -6015,6 +6032,8 @@ def conditional_loss( elif loss_type == "l1": loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) elif loss_type == "huber": + if huber_c is None: + raise NotImplementedError("huber_c not implemented correctly") huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": @@ -6022,6 +6041,8 @@ def conditional_loss( elif reduction == "sum": loss = torch.sum(loss) elif loss_type == "smooth_l1": + if huber_c is None: + raise NotImplementedError("huber_c not implemented correctly") huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": diff --git a/train_network.py 
b/train_network.py index 7d064d21..f870734f 100644 --- a/train_network.py +++ b/train_network.py @@ -205,10 +205,10 @@ class NetworkTrainer: custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae: AutoencoderKL, images: torch.FloatTensor) -> torch.FloatTensor: return vae.encode(images).latent_dist.sample() - def shift_scale_latents(self, args, latents): + def shift_scale_latents(self, args, latents: torch.FloatTensor) -> torch.FloatTensor: return latents * self.vae_scale_factor def get_noise_pred_and_target( @@ -280,7 +280,7 @@ class NetworkTrainer: return noise_pred, target, timesteps, None - def post_process_loss(self, loss, args, timesteps, noise_scheduler): + def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor: if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: @@ -317,20 +317,21 @@ class NetworkTrainer: # endregion - def process_batch(self, batch, tokenizers, text_encoders, unet, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor: + def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor: with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents: torch.Tensor = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) + latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) else: # latentに変換 - latents: torch.Tensor = typing.cast(torch.FloatTensor, typing.cast(AutoencoderKLOutput, vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype))).latent_dist.sample()) + latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype)) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = typing.cast(torch.FloatTensor, torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)) - latents = typing.cast(torch.FloatTensor, latents * self.vae_scale_factor) + latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents)) + + latents = self.shift_scale_latents(args, latents) text_encoder_conds = [] @@ -384,22 +385,36 @@ class NetworkTrainer: total_loss = torch.zeros((batch_size, 1)).to(latents.device) # Use input timesteps_list or use described timesteps above - for fixed_timestep in chosen_timesteps_list: - fixed_timestep = typing.cast(torch.IntTensor, fixed_timestep) + for fixed_timesteps in chosen_timesteps_list: + fixed_timesteps = typing.cast(torch.IntTensor, fixed_timesteps) # Predict the noise residual # and add noise to the latents # with noise offset and/or multires noise if specified - noisy_latents = train_util.get_noisy_latents(args, noise, 
noise_scheduler, latents, fixed_timestep) + noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timesteps) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + for x in noisy_latents: + x.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents.requires_grad_(train_unet), fixed_timestep, text_encoder_conds, batch, weight_dtype + args, + accelerator, + unet, + noisy_latents.requires_grad_(train_unet), + fixed_timesteps, + text_encoder_conds, + batch, + weight_dtype, ) if args.v_parameterization: # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, fixed_timestep) + target = noise_scheduler.get_velocity(latents, noise, fixed_timesteps) else: target = noise @@ -418,7 +433,7 @@ class NetworkTrainer: accelerator, unet, noisy_latents, - timesteps, + fixed_timesteps, text_encoder_conds, batch, weight_dtype, @@ -427,7 +442,8 @@ class NetworkTrainer: network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + huber_c = train_util.get_huber_threshold_if_needed(args, fixed_timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) # 平均なのでbatch_sizeで割る必要なし if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): @@ -436,14 +452,7 @@ class NetworkTrainer: loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight loss = loss * loss_weights - if args.min_snr_gamma: - loss = apply_snr_weight(loss, fixed_timestep, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, fixed_timestep, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, fixed_timestep, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, fixed_timestep, noise_scheduler) + loss = self.post_process_loss(loss, args, fixed_timesteps, noise_scheduler) total_loss += loss @@ -526,8 +535,12 @@ class NetworkTrainer: collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) if args.debug_dataset: - train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly + train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) + + if val_dataset_group is not None: + val_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly + train_util.debug_dataset(val_dataset_group) return if len(train_dataset_group) == 0: logger.error( @@ -753,10 +766,6 @@ class NetworkTrainer: # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) - # Not for sure here. 
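         # (max_train_steps is only sent to the training dataset group; the validation
         # dataset group is only iterated under torch.no_grad(), so it likely does not need it)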
- # if val_dataset_group is not None: - # val_dataset_group.set_max_train_steps(args.max_train_steps) - # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -1304,7 +1313,7 @@ class NetworkTrainer: clean_memory_on_device(accelerator.device) for epoch in range(epoch_to_start, num_train_epochs): - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 metadata["ss_epoch"] = str(epoch + 1) @@ -1324,7 +1333,7 @@ class NetworkTrainer: continue with accelerator.accumulate(training_model): - loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet) accelerator.backward(loss) if accelerator.sync_gradients: self.all_reduce_network(accelerator, network) # sync DDP grad manually @@ -1384,7 +1393,8 @@ class NetworkTrainer: logs = self.generate_step_logs( args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm ) - accelerator.log(logs, step=global_step) + # accelerator.log(logs, step=global_step) + accelerator.log(logs) # VALIDATION PER STEP should_validate = (args.validation_every_n_step is not None @@ -1401,7 +1411,7 @@ class NetworkTrainer: if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item()) val_progress_bar.update(1) @@ -1409,10 +1419,12 @@ class NetworkTrainer: if is_tracking: logs = {"loss/current_val_loss": loss.detach().item()} - accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + accelerator.log(logs) logs = {"loss/average_val_loss": val_loss_recorder.moving_average} - accelerator.log(logs, step=global_step) + # accelerator.log(logs, step=global_step) + accelerator.log(logs) if global_step >= args.max_train_steps: break @@ -1427,7 +1439,7 @@ class NetworkTrainer: ) for val_step, batch in enumerate(val_dataloader): - loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) @@ -1437,22 +1449,26 @@ 
class NetworkTrainer: if is_tracking: avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation_current": current_loss} - accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + accelerator.log(logs) if is_tracking: avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation_average": avr_loss} - accelerator.log(logs, step=epoch + 1) + # accelerator.log(logs, step=epoch + 1) + accelerator.log(logs) # END OF EPOCH if is_tracking: logs = {"loss/epoch_average": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) + # accelerator.log(logs, step=epoch + 1) + accelerator.log(logs) if len(val_dataloader) > 0 and is_tracking: avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation_epoch_average": avr_loss} - accelerator.log(logs, step=epoch + 1) + # accelerator.log(logs, step=epoch + 1) + accelerator.log(logs) accelerator.wait_for_everyone() From 695f38962ce279adfee3fabb3479b84b1076b4e8 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 15:25:12 -0500 Subject: [PATCH 55/87] Move get_huber_threshold_if_needed --- library/train_util.py | 44 ++++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 0f16a4f3..0907a8c0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5905,27 +5905,6 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor return timesteps -def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: - if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): - return None - - b_size = timesteps.shape[0] - if args.huber_schedule == "exponential": - alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps - result = torch.exp(-alpha * timesteps) * args.huber_scale - elif args.huber_schedule == "snr": - if not hasattr(noise_scheduler, "alphas_cumprod"): - raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") - alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) - sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 - result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c - result = result.to(timesteps.device) - elif args.huber_schedule == "constant": - result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device) - else: - raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - - return result def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor) -> torch.FloatTensor: @@ -6004,6 +5983,29 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch. 
return noise, noisy_latents, timesteps +def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: + if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): + return None + + b_size = timesteps.shape[0] + if args.huber_schedule == "exponential": + alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps + result = torch.exp(-alpha * timesteps) * args.huber_scale + elif args.huber_schedule == "snr": + if not hasattr(noise_scheduler, "alphas_cumprod"): + raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) + sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 + result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c + result = result.to(timesteps.device) + elif args.huber_schedule == "constant": + result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device) + else: + raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") + + return result + + def get_noisy_latents(args, noise: torch.FloatTensor, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor: """ Add noise to the latents according to the noise magnitude at each timestep From 1f9ba40b8b70fd08e6b87a70727d5e789666a925 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 15:32:07 -0500 Subject: [PATCH 56/87] Add step break for validation epoch. Remove unused variable --- train_network.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index f870734f..ce34f26d 100644 --- a/train_network.py +++ b/train_network.py @@ -1439,6 +1439,9 @@ class NetworkTrainer: ) for val_step, batch in enumerate(val_dataloader): + if val_step >= validation_steps: + break + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) current_loss = loss.detach().item() @@ -1447,7 +1450,6 @@ class NetworkTrainer: val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) if is_tracking: - avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation_current": current_loss} # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) accelerator.log(logs) From 1c0ae306e551ede5bd162819debb4d80a7fe620b Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 15:43:02 -0500 Subject: [PATCH 57/87] Add missing functions for training batch --- train_network.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index ce34f26d..377ddf48 100644 --- a/train_network.py +++ b/train_network.py @@ -318,7 +318,7 @@ class NetworkTrainer: # endregion def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor: - + with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) @@ -1333,6 
+1333,11 @@ class NetworkTrainer: continue with accelerator.accumulate(training_model): + on_step_start_for_network(text_encoder, unet) + + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet) accelerator.backward(loss) if accelerator.sync_gradients: From bbf6bbd5ea27231066cec98b8bf2a65f162cb18f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 10:48:38 -0500 Subject: [PATCH 58/87] Use self.get_noise_pred_and_target and drop fixed timesteps --- flux_train_network.py | 7 ++- sd3_train_network.py | 3 +- train_network.py | 116 ++++++++++++------------------------------ 3 files changed, 40 insertions(+), 86 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 75e975ba..b3aebecc 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -339,6 +339,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): network, weight_dtype, train_unet, + is_train=True ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -375,7 +376,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): # if not args.split_mode: # normal forward - with accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( img=img, @@ -420,7 +421,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): intermediate_txt.requires_grad_(True) vec.requires_grad_(True) pe.requires_grad_(True) - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + + with torch.set_grad_enabled(is_train and train_unet): + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) """ return model_pred diff --git a/sd3_train_network.py b/sd3_train_network.py index fb7711bd..c7417802 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -312,6 +312,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): network, weight_dtype, train_unet, + is_train=True ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -339,7 +340,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): t5_attn_mask = None # call model - with accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): # TODO support attention mask model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled) diff --git a/train_network.py b/train_network.py index 377ddf48..61e6369a 100644 --- a/train_network.py +++ b/train_network.py @@ -223,6 +223,7 @@ class NetworkTrainer: network, weight_dtype, train_unet, + is_train=True ): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -236,7 +237,7 @@ class NetworkTrainer: t.requires_grad_(True) # Predict the noise residual - with accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_unet), 
accelerator.autocast(): noise_pred = self.call_unet( args, accelerator, @@ -317,7 +318,7 @@ class NetworkTrainer: # endregion - def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor: + def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True) -> torch.Tensor: with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: @@ -372,91 +373,40 @@ class NetworkTrainer: batch_size = latents.shape[0] - # Sample noise, - noise = train_util.make_noise(args, latents) - def pick_timesteps_list() -> torch.IntTensor: - if timesteps_list is None or timesteps_list == []: - return typing.cast(torch.IntTensor, train_util.make_random_timesteps(args, noise_scheduler, batch_size, latents.device).unsqueeze(1)) - else: - return typing.cast(torch.IntTensor, torch.tensor(timesteps_list).unsqueeze(1).repeat(1, batch_size).to(latents.device)) + # Predict the noise residual + # and add noise to the latents + # with noise offset and/or multires noise if specified - chosen_timesteps_list = pick_timesteps_list() - total_loss = torch.zeros((batch_size, 1)).to(latents.device) + # sample noise, call unet, get target + noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + is_train=is_train + ) - # Use input timesteps_list or use described timesteps above - for fixed_timesteps in chosen_timesteps_list: - fixed_timesteps = typing.cast(torch.IntTensor, fixed_timesteps) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) - # Predict the noise residual - # and add noise to the latents - # with noise offset and/or multires noise if specified - noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timesteps) + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights - # ensure the hidden state will require grad - if args.gradient_checkpointing: - for x in noisy_latents: - x.requires_grad_(True) - for t in text_encoder_conds: - t.requires_grad_(True) + loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): - noise_pred = self.call_unet( - args, - accelerator, - unet, - noisy_latents.requires_grad_(train_unet), - fixed_timesteps, - text_encoder_conds, - batch, - weight_dtype, - ) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, fixed_timesteps) - else: - target = noise - - # differential output preservation - 
if "custom_attributes" in batch: - diff_output_pr_indices = [] - for i, custom_attributes in enumerate(batch["custom_attributes"]): - if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: - diff_output_pr_indices.append(i) - - if len(diff_output_pr_indices) > 0: - network.set_multiplier(0.0) - with torch.no_grad(), accelerator.autocast(): - noise_pred_prior = self.call_unet( - args, - accelerator, - unet, - noisy_latents, - fixed_timesteps, - text_encoder_conds, - batch, - weight_dtype, - indices=diff_output_pr_indices, - ) - network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step - target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) - - huber_c = train_util.get_huber_threshold_if_needed(args, fixed_timesteps, noise_scheduler) - loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) - loss = loss.mean([1, 2, 3]) # 平均なのでbatch_sizeで割る必要なし - - if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - loss = apply_masked_loss(loss, batch) - - loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight - loss = loss * loss_weights - - loss = self.post_process_loss(loss, args, fixed_timesteps, noise_scheduler) - - total_loss += loss - - return total_loss / len(chosen_timesteps_list) + return loss.mean() def train(self, args): session_id = random.randint(0, 2**32) @@ -1416,7 +1366,7 @@ class NetworkTrainer: if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False) val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item()) val_progress_bar.update(1) @@ -1447,7 +1397,7 @@ class NetworkTrainer: if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) From f4840ef29ef67878d7c7ccec92bdce89c3b61c6d Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 10:52:07 -0500 Subject: [PATCH 59/87] Revert train_db.py --- train_db.py | 121 ++-------------------------------------------------- 1 file changed, 3 insertions(+), 118 deletions(-) diff --git a/train_db.py b/train_db.py index 398489ff..ad21f8d1 100644 --- a/train_db.py +++ b/train_db.py @@ -2,6 +2,7 @@ # XXX dropped option: fine_tune import argparse +import itertools import math import os from multiprocessing import Value @@ -41,73 +42,11 @@ import library.strategy_sd as strategy_sd setup_logging() import logging -import itertools logger = logging.getLogger(__name__) # perlin_noise, -def process_val_batch(*training_models, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, 
weight_dtype, accelerator, args): - total_loss = 0.0 - timesteps_list = [10, 350, 500, 650, 990] - - with accelerator.accumulate(*training_models): - with torch.no_grad(): - # latentに変換 - if cache_latents: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) - else: - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - with torch.set_grad_enabled(False), accelerator.autocast(): - if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - - for fixed_timesteps in timesteps_list: - with torch.set_grad_enabled(False), accelerator.autocast(): - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise = torch.randn_like(latents, device=latents.device) - b_size = latents.shape[0] - timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - if args.masked_loss: - loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - total_loss += loss - - average_loss = total_loss / len(timesteps_list) - return average_loss def train(args): train_util.verify_training_args(args) @@ -150,10 +89,9 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) - val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -274,15 +212,6 @@ def train(args): num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) - val_dataloader = torch.utils.data.DataLoader( - val_dataset_group if val_dataset_group is not None else [], - shuffle=False, - batch_size=1, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - cyclic_val_dataloader = itertools.cycle(val_dataloader) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -393,8 +322,6 @@ def train(args): accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() - val_loss_recorder = train_util.LossRecorder() - for epoch in range(num_train_epochs): accelerator.print(f"\nepoch 
{epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -525,25 +452,6 @@ def train(args): avr_loss: float = loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break @@ -634,30 +542,7 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) - parser.add_argument( - "--validation_seed", - type=int, - default=None, - help="Validation seed" - ) - parser.add_argument( - "--validation_split", - type=float, - default=0.0, - help="Split for validation images out of the training dataset" - ) - parser.add_argument( - "--validation_every_n_step", - type=int, - default=None, - help="Number of train steps for counting validation loss. By default, validation per train epoch is performed" - ) - parser.add_argument( - "--max_validation_steps", - type=int, - default=None, - help="Number of max validation steps for counting validation loss. 
By default, validation will run entire validation dataset" - ) + return parser From 1c63e7cc4979b528417b5bfe181e0a9ac119209c Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 11:07:47 -0500 Subject: [PATCH 60/87] Cleanup unused code and formatting --- train_network.py | 85 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 70 insertions(+), 15 deletions(-) diff --git a/train_network.py b/train_network.py index 61e6369a..5a80d825 100644 --- a/train_network.py +++ b/train_network.py @@ -318,8 +318,27 @@ class NetworkTrainer: # endregion - def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True) -> torch.Tensor: - + def process_batch( + self, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, + tokenize_strategy: strategy_sd.SdTokenizeStrategy, + is_train=True, + train_text_encoder=True, + train_unet=True + ) -> torch.Tensor: + """ + Process a batch for the network + """ with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) @@ -334,7 +353,6 @@ class NetworkTrainer: latents = self.shift_scale_latents(args, latents) - text_encoder_conds = [] text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: @@ -371,13 +389,6 @@ class NetworkTrainer: if encoded_text_encoder_conds[i] is not None: text_encoder_conds[i] = encoded_text_encoder_conds[i] - batch_size = latents.shape[0] - - - # Predict the noise residual - # and add noise to the latents - # with noise offset and/or multires noise if specified - # sample noise, call unet, get target noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( args, @@ -1288,7 +1299,23 @@ class NetworkTrainer: # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet) + loss = self.process_batch(batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=True, + train_text_encoder=train_text_encoder, + train_unet=train_unet + ) + accelerator.backward(loss) if accelerator.sync_gradients: self.all_reduce_network(accelerator, network) # sync DDP grad manually @@ -1366,12 +1393,26 @@ class NetworkTrainer: if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False) - + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False + ) + val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item()) 
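Both loss_recorder and val_loss_recorder here are instances of train_util.LossRecorder, whose implementation is not shown in this series. Purely as an illustration of the interface the validation loop relies on (an add(epoch=..., step=..., loss=...) method plus a moving_average property), a simplified stand-in could look like the following; this is not the actual train_util implementation:

class SimpleLossRecorder:
    # Hypothetical stand-in used only to illustrate the recorder interface.
    def __init__(self):
        self.losses = []

    def add(self, *, epoch, step, loss):
        # epoch and step are accepted to match the call sites in this file;
        # this sketch only accumulates the raw loss values.
        self.losses.append(float(loss))

    @property
    def moving_average(self):
        return sum(self.losses) / max(1, len(self.losses))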
val_progress_bar.update(1) val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) - + if is_tracking: logs = {"loss/current_val_loss": loss.detach().item()} # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) @@ -1397,7 +1438,21 @@ class NetworkTrainer: if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False) + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False + ) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) From c64d1a22fc4ff25625873e50d63d480b297301c6 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 11:30:21 -0500 Subject: [PATCH 61/87] Add validate_every_n_epochs, change name validate_every_n_steps --- train_network.py | 69 ++++++++++++++++++++++++++++++------------------ 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/train_network.py b/train_network.py index 5a80d825..f3c8d8c9 100644 --- a/train_network.py +++ b/train_network.py @@ -1199,7 +1199,8 @@ class NetworkTrainer: ) loss_recorder = train_util.LossRecorder() - val_loss_recorder = train_util.LossRecorder() + val_step_loss_recorder = train_util.LossRecorder() + val_epoch_loss_recorder = train_util.LossRecorder() del train_dataset_group if val_dataset_group is not None: @@ -1299,7 +1300,8 @@ class NetworkTrainer: # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - loss = self.process_batch(batch, + loss = self.process_batch( + batch, text_encoders, unet, network, @@ -1373,15 +1375,25 @@ class NetworkTrainer: if is_tracking: logs = self.generate_step_logs( - args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm + args, + current_loss, + avr_loss, + lr_scheduler, + lr_descriptions, + optimizer, + keys_scaled, + mean_norm, + maximum_norm ) # accelerator.log(logs, step=global_step) accelerator.log(logs) # VALIDATION PER STEP - should_validate = (args.validation_every_n_step is not None - and global_step % args.validation_every_n_step == 0) - if validation_steps > 0 and should_validate: + should_validate_epoch = ( + args.validate_every_n_steps is not None + and global_step % args.validate_every_n_steps == 0 + ) + if validation_steps > 0 and should_validate_epoch: accelerator.print("Validating バリデーション処理...") val_progress_bar = tqdm( @@ -1409,16 +1421,17 @@ class NetworkTrainer: is_train=False ) - val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item()) + current_loss = loss.detach().item() + val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_progress_bar.update(1) - val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) + val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average }) if is_tracking: - logs = {"loss/current_val_loss": loss.detach().item()} + logs = {"loss/step_validation_current": current_loss} # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) accelerator.log(logs) - logs = {"loss/average_val_loss": val_loss_recorder.moving_average} + logs = {"loss/step_validation_average": 
val_step_loss_recorder.moving_average} # accelerator.log(logs, step=global_step) accelerator.log(logs) @@ -1426,12 +1439,18 @@ class NetworkTrainer: break # VALIDATION EPOCH - if len(val_dataloader) > 0: + should_validate_epoch = ( + (epoch + 1) % args.validate_every_n_epochs == 0 + if args.validate_every_n_epochs is not None + else False + ) + + if should_validate_epoch and len(val_dataloader) > 0: accelerator.print("Validating バリデーション処理...") val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, - desc="validation steps" + desc="epoch validation steps" ) for val_step, batch in enumerate(val_dataloader): @@ -1455,18 +1474,18 @@ class NetworkTrainer: ) current_loss = loss.detach().item() - val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) + val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_progress_bar.update(1) - val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) + val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average }) if is_tracking: - logs = {"loss/validation_current": current_loss} + logs = {"loss/epoch_validation_current": current_loss} # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) accelerator.log(logs) if is_tracking: - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation_average": avr_loss} + avr_loss: float = val_epoch_loss_recorder.moving_average + logs = {"loss/epoch_validation_average": avr_loss} # accelerator.log(logs, step=epoch + 1) accelerator.log(logs) @@ -1475,12 +1494,6 @@ class NetworkTrainer: logs = {"loss/epoch_average": loss_recorder.moving_average} # accelerator.log(logs, step=epoch + 1) accelerator.log(logs) - - if len(val_dataloader) > 0 and is_tracking: - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation_epoch_average": avr_loss} - # accelerator.log(logs, step=epoch + 1) - accelerator.log(logs) accelerator.wait_for_everyone() @@ -1676,10 +1689,16 @@ def setup_parser() -> argparse.ArgumentParser: help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" ) parser.add_argument( - "--validation_every_n_step", + "--validate_every_n_steps", type=int, default=None, - help="Number of train steps for counting validation loss. By default, validation per train epoch is performed / 学習エポックごとに検証を行う場合はNoneを指定する" + help="Run validation dataset every N steps" + ) + parser.add_argument( + "--validate_every_n_epochs", + type=int, + default=None, + help="Run validation dataset every N epochs. 
By default, validation will run every epoch if a validation dataset is available" ) parser.add_argument( "--max_validation_steps", From f8850296c83ef2091bf1cb0f6e9ba462adfd9045 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 11:34:10 -0500 Subject: [PATCH 62/87] Fix validate epoch, cleanup imports --- train_network.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/train_network.py b/train_network.py index f3c8d8c9..11bba71e 100644 --- a/train_network.py +++ b/train_network.py @@ -3,15 +3,13 @@ import argparse import math import os import typing -from typing import List, Optional, Union +from typing import Any, List import sys import random import time import json from multiprocessing import Value -from typing import Any, List import toml -import itertools from tqdm import tqdm @@ -23,8 +21,8 @@ init_ipex() from accelerate.utils import set_seed from accelerate import Accelerator -from diffusers import DDPMScheduler, AutoencoderKL -from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers import DDPMScheduler +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from library import deepspeed_utils, model_util, strategy_base, strategy_sd import library.train_util as train_util @@ -49,7 +47,6 @@ from library.utils import setup_logging, add_logging_arguments setup_logging() import logging -import itertools logger = logging.getLogger(__name__) @@ -1442,7 +1439,7 @@ class NetworkTrainer: should_validate_epoch = ( (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None - else False + else True ) if should_validate_epoch and len(val_dataloader) > 0: From fcb2ff010cf2e42c50b3745a17317f2d4b4319d9 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 11:39:32 -0500 Subject: [PATCH 63/87] Clean up some validation help documentation --- train_network.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index 11bba71e..af180c45 100644 --- a/train_network.py +++ b/train_network.py @@ -1677,7 +1677,7 @@ def setup_parser() -> argparse.ArgumentParser: "--validation_seed", type=int, default=None, - help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証シード" + help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する" ) parser.add_argument( "--validation_split", @@ -1689,19 +1689,19 @@ def setup_parser() -> argparse.ArgumentParser: "--validate_every_n_steps", type=int, default=None, - help="Run validation dataset every N steps" + help="Run validation on validation dataset every N steps if a validation dataset is available / 検証データセットが利用可能な場合は、Nステップごとに検証データセットの検証を実行します" ) parser.add_argument( "--validate_every_n_epochs", type=int, default=None, - help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available" + help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます" ) parser.add_argument( "--max_validation_steps", type=int, default=None, - help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset / 検証データセット全体を検証する場合はNoneを指定する" + help="Max number of validation dataset items processed. 
By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します" ) return parser From 742bee9738e9d190a39f5a36adf4515fa415e9b7 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 17:34:23 -0500 Subject: [PATCH 64/87] Set validation steps in multiple lines for readability --- train_network.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index af180c45..d0596fca 100644 --- a/train_network.py +++ b/train_network.py @@ -1251,7 +1251,11 @@ class NetworkTrainer: # log empty object to commit the sample images to wandb accelerator.log({}, step=0) - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + validation_steps = ( + min(args.max_validation_steps, len(val_dataloader)) + if args.max_validation_steps is not None + else len(val_dataloader) + ) # training loop if initial_step > 0: # only if skip_until_initial_step is specified @@ -1689,7 +1693,7 @@ def setup_parser() -> argparse.ArgumentParser: "--validate_every_n_steps", type=int, default=None, - help="Run validation on validation dataset every N steps if a validation dataset is available / 検証データセットが利用可能な場合は、Nステップごとに検証データセットの検証を実行します" + help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます" ) parser.add_argument( "--validate_every_n_epochs", From 1231f5114ccd6a0a26a53da82b89083299ccc333 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 7 Jan 2025 22:31:41 -0500 Subject: [PATCH 65/87] Remove unused train_util code, fix accelerate.log for wandb, add init_trackers library code --- library/train_util.py | 70 ++++++++++++++++--------------------------- train_network.py | 66 ++++++++++++++++++++-------------------- 2 files changed, 59 insertions(+), 77 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 0907a8c0..b8894752 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5900,51 +5900,9 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.IntTensor: +def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.Tensor: timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) - return timesteps - - - - -def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor) -> torch.FloatTensor: - """ - Apply noise modifications like noise offset and multires noise - """ - if args.noise_offset: - if args.noise_offset_random_strength: - noise_offset = torch.rand(1, device=latents.device) * args.noise_offset - else: - noise_offset = args.noise_offset - noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale) - if args.multires_noise_iterations: - noise = custom_train_functions.pyramid_noise_like( - noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount - ) - return noise - - -def make_noise(args, latents: torch.Tensor) -> torch.FloatTensor: - """ - Make a noise tensor to denoise and apply noise modifications (noise offset, multires noise). 
See `modify_noise` - """ - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - noise = modify_noise(args, noise, latents) - - return typing.cast(torch.FloatTensor, noise) - - -def make_random_timesteps(args, noise_scheduler: DDPMScheduler, batch_size: int, device: torch.device) -> torch.IntTensor: - """ - From args, produce random timesteps for each image in the batch - """ - min_timestep = 0 if args.min_timestep is None else args.min_timestep - max_timestep = noise_scheduler.config.get('num_train_timesteps', 1000) if args.max_timestep is None else args.max_timestep - - # Sample a random timestep for each image - timesteps = get_timesteps(min_timestep, max_timestep, batch_size, device) - + timesteps = timesteps.long().to(device) return timesteps @@ -6457,6 +6415,30 @@ def sample_image_inference( wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption +def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str): + """ + Initialize experiment trackers with tracker specific behaviors + """ + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + default_tracker_name if args.log_tracker_name is None else args.log_tracker_name, + config=get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if "wandb" in [tracker.name for tracker in accelerator.trackers]: + import wandb + wandb_tracker = accelerator.get_tracker("wandb", unwrap=True) + + # Define specific metrics to handle validation and epochs "steps" + wandb_tracker.define_metric("epoch", hidden=True) + wandb_tracker.define_metric("val_step", hidden=True) + # endregion diff --git a/train_network.py b/train_network.py index d0596fca..199f589b 100644 --- a/train_network.py +++ b/train_network.py @@ -327,8 +327,8 @@ class NetworkTrainer: weight_dtype, accelerator, args, - text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, - tokenize_strategy: strategy_sd.SdTokenizeStrategy, + text_encoding_strategy: strategy_base.TextEncodingStrategy, + tokenize_strategy: strategy_base.TokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True @@ -1183,17 +1183,7 @@ class NetworkTrainer: noise_scheduler = self.get_noise_scheduler(args, accelerator.device) - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": args.wandb_run_name} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - "network_train" if args.log_tracker_name is None else args.log_tracker_name, - config=train_util.get_sanitized_config_or_none(args), - init_kwargs=init_kwargs, - ) + train_util.init_trackers(accelerator, args, "network_train") loss_recorder = train_util.LossRecorder() val_step_loss_recorder = train_util.LossRecorder() @@ -1386,15 +1376,14 @@ class NetworkTrainer: mean_norm, maximum_norm ) - # accelerator.log(logs, step=global_step) - accelerator.log(logs) + accelerator.log(logs, step=global_step) # VALIDATION PER STEP - should_validate_epoch = ( + should_validate_step = ( args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0 ) - if validation_steps > 0 and should_validate_epoch: + if validation_steps > 0 and should_validate_step: 
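Combined with the epoch-level check later in this file (which, after the earlier fix, defaults to validating every epoch when --validate_every_n_epochs is unset), the gating added by these patches can be restated in condensed form. should_run_validation is a hypothetical helper written only to summarize the logic already present in train_network.py, assuming the argparse options defined in this series:

def should_run_validation(args, validation_steps, global_step, epoch, end_of_epoch):
    if validation_steps <= 0:
        # No validation items (empty or missing validation dataloader).
        return False
    if not end_of_epoch:
        # Step-based validation only runs when --validate_every_n_steps is set.
        return args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0
    # Epoch-based validation runs every epoch unless --validate_every_n_epochs says otherwise.
    if args.validate_every_n_epochs is None:
        return True
    return (epoch + 1) % args.validate_every_n_epochs == 0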
accelerator.print("Validating バリデーション処理...") val_progress_bar = tqdm( @@ -1406,6 +1395,9 @@ class NetworkTrainer: if val_step >= validation_steps: break + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + loss = self.process_batch( batch, text_encoders, @@ -1428,18 +1420,22 @@ class NetworkTrainer: val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average }) if is_tracking: - logs = {"loss/step_validation_current": current_loss} - # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) - accelerator.log(logs) + logs = { + "loss/validation/step/current": current_loss, + "val_step": (epoch * validation_steps) + val_step, + } + accelerator.log(logs, step=global_step) - logs = {"loss/step_validation_average": val_step_loss_recorder.moving_average} - # accelerator.log(logs, step=global_step) - accelerator.log(logs) + if is_tracking: + logs = { + "loss/validation/step/average": val_step_loss_recorder.moving_average, + } + accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - # VALIDATION EPOCH + # EPOCH VALIDATION should_validate_epoch = ( (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None @@ -1458,6 +1454,9 @@ class NetworkTrainer: if val_step >= validation_steps: break + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + loss = self.process_batch( batch, text_encoders, @@ -1480,21 +1479,22 @@ class NetworkTrainer: val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average }) if is_tracking: - logs = {"loss/epoch_validation_current": current_loss} - # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) - accelerator.log(logs) + logs = { + "loss/validation/epoch_current": current_loss, + "epoch": epoch + 1, + "val_step": (epoch * validation_steps) + val_step + } + accelerator.log(logs, step=global_step) if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - logs = {"loss/epoch_validation_average": avr_loss} - # accelerator.log(logs, step=epoch + 1) - accelerator.log(logs) + logs = {"loss/validation/epoch_average": avr_loss, "epoch": epoch + 1} + accelerator.log(logs, step=global_step) # END OF EPOCH if is_tracking: - logs = {"loss/epoch_average": loss_recorder.moving_average} - # accelerator.log(logs, step=epoch + 1) - accelerator.log(logs) + logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} + accelerator.log(logs, step=global_step) accelerator.wait_for_everyone() From 556f3f1696eadcc16ee77425243b732a84c7a2aa Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 8 Jan 2025 13:41:15 -0500 Subject: [PATCH 66/87] Fix documentation, remove unused function, fix bucket reso for sd1.5, fix multiple datasets --- library/config_util.py | 6 +++--- library/train_util.py | 25 ++++--------------------- train_network.py | 5 +---- 3 files changed, 8 insertions(+), 28 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 63d28c96..de1e154a 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -481,9 +481,9 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu subset_klass = FineTuningSubset dataset_klass = FineTuningDataset - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = 
dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params)) - datasets.append(dataset) + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params)) + datasets.append(dataset) val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: diff --git a/library/train_util.py b/library/train_util.py index b8894752..62aae37e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -152,11 +152,11 @@ def split_train_val(paths: List[str], is_training_dataset: bool, validation_spli Shuffle the dataset based on the validation_seed or the current random seed. For example if the split of 0.2 of 100 images. - [0:79] = 80 training images + [0:80] = 80 training images [80:] = 20 validation images """ if validation_seed is not None: - print(f"Using validation seed: {validation_seed}") + logging.info(f"Using validation seed: {validation_seed}") prevstate = random.getstate() random.seed(validation_seed) random.shuffle(paths) @@ -5900,8 +5900,8 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.Tensor: - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) +def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor: + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") timesteps = timesteps.long().to(device) return timesteps @@ -5964,23 +5964,6 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler return result -def get_noisy_latents(args, noise: torch.FloatTensor, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor: - """ - Add noise to the latents according to the noise magnitude at each timestep - (this is the forward diffusion process) - """ - if args.ip_noise_gamma: - if args.ip_noise_gamma_random_strength: - strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma - else: - strength = args.ip_noise_gamma - noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps) - else: - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - return noisy_latents - - def conditional_loss( model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None ): diff --git a/train_network.py b/train_network.py index 199f589b..7dbd12e8 100644 --- a/train_network.py +++ b/train_network.py @@ -125,10 +125,7 @@ class NetworkTrainer: return logs def assert_extra_args(self, args, train_dataset_group): - # train_dataset_group.verify_bucket_reso_steps(64) - # TODO: Number of bucket reso steps may differ for each model, so a static number won't work - # and prevents models like SD1.5 with 64 - pass + train_dataset_group.verify_bucket_reso_steps(32) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) From 9fde0d797282c0cb9fcea01682e2e6e2eece47bc Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 8 Jan 2025 18:38:20 -0500 Subject: [PATCH 67/87] 
Handle tuple return from generate_dataset_group_by_blueprint --- fine_tune.py | 4 ++-- flux_train.py | 3 ++- flux_train_control_net.py | 4 ++-- library/config_util.py | 2 +- sd3_train.py | 3 ++- sdxl_train.py | 3 ++- sdxl_train_control_net.py | 2 +- sdxl_train_control_net_lllite.py | 2 +- sdxl_train_control_net_lllite_old.py | 2 +- tools/cache_latents.py | 3 ++- tools/cache_text_encoder_outputs.py | 3 ++- train_control_net.py | 2 +- train_db.py | 3 ++- train_textual_inversion.py | 3 ++- train_textual_inversion_XTI.py | 2 +- 15 files changed, 24 insertions(+), 17 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 17608706..6be2f98c 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -91,9 +91,9 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args) + train_dataset_group, val_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/flux_train.py b/flux_train.py index fced3bef..6f98adea 100644 --- a/flux_train.py +++ b/flux_train.py @@ -138,9 +138,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 9d36a41d..54dec2a7 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -126,9 +126,9 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args) + train_dataset_group, val_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/library/config_util.py b/library/config_util.py index de1e154a..834d6bfa 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -467,7 +467,7 @@ class BlueprintGenerator: return default_value -def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): +def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint) -> Tuple[DatasetGroup, Optional[DatasetGroup]]: datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: diff --git a/sd3_train.py b/sd3_train.py index 120455e7..3bff6a50 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -149,9 +149,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = 
train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/sdxl_train.py b/sdxl_train.py index b9d52924..a60f6df6 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -176,9 +176,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index ffbf03ca..c6e8136f 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -114,7 +114,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 365059b7..00e51a67 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -123,7 +123,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 5b372bef..63457cc6 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -103,7 +103,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/tools/cache_latents.py b/tools/cache_latents.py index c034f949..515ece98 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -116,10 +116,11 @@ def cache_to_disk(args: argparse.Namespace) -> None: } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None # acceleratorを準備する logger.info("prepare accelerator") diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 5888b8e3..00459658 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -103,10 +103,11 @@ def cache_to_disk(args: argparse.Namespace) -> None: } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + 
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None # acceleratorを準備する logger.info("prepare accelerator") diff --git a/train_control_net.py b/train_control_net.py index 177d2b11..ba016ac5 100644 --- a/train_control_net.py +++ b/train_control_net.py @@ -100,7 +100,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/train_db.py b/train_db.py index ad21f8d1..edd67403 100644 --- a/train_db.py +++ b/train_db.py @@ -89,9 +89,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 65da4859..113f3599 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -320,9 +320,10 @@ class TextualInversionTrainer: } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None self.assert_extra_args(args, train_dataset_group) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 2a2b4231..6ff97d03 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -239,7 +239,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings) current_epoch = Value("i", 0) current_step = Value("i", 0) From 1e61392cf2f601e1c66aaede6846ef70f599c34f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 8 Jan 2025 18:43:26 -0500 Subject: [PATCH 68/87] Revert bucket_reso_steps to correct 64 --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 7dbd12e8..7e9f1265 100644 --- a/train_network.py +++ b/train_network.py @@ -125,7 +125,7 @@ class NetworkTrainer: return logs def assert_extra_args(self, args, train_dataset_group): - train_dataset_group.verify_bucket_reso_steps(32) + train_dataset_group.verify_bucket_reso_steps(64) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) From d6f158ddf6a3631df7db10ac97453b12de8eadbe Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 8 Jan 2025 
18:48:05 -0500 Subject: [PATCH 69/87] Fix incorrect destructuring for load_arbitrary_dataset --- fine_tune.py | 3 ++- flux_train_control_net.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 6be2f98c..e1ed4749 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -93,7 +93,8 @@ def train(args): blueprint = blueprint_generator.generate(user_config, args) train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group, val_dataset_group = train_util.load_arbitrary_dataset(args) + train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 54dec2a7..cecd0001 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -128,7 +128,8 @@ def train(args): blueprint = blueprint_generator.generate(user_config, args) train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group, val_dataset_group = train_util.load_arbitrary_dataset(args) + train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) From 264167fa1636c79f106c63c3cdb67b6bee80aceb Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 9 Jan 2025 12:43:58 -0500 Subject: [PATCH 70/87] Apply is_training_dataset only to DreamBoothDataset. Add validation_split check and warning --- library/config_util.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 834d6bfa..a2e07dc6 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -471,36 +471,49 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: + extra_dataset_params = {} + if dataset_blueprint.is_controlnet: subset_klass = ControlNetSubset dataset_klass = ControlNetDataset elif dataset_blueprint.is_dreambooth: subset_klass = DreamBoothSubset dataset_klass = DreamBoothDataset + # DreamBooth datasets support splitting training and validation datasets + extra_dataset_params = {"is_training_dataset": True} else: subset_klass = FineTuningSubset dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params) datasets.append(dataset) val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: - if dataset_blueprint.params.validation_split <= 0.0: + if dataset_blueprint.params.validation_split < 0.0 or dataset_blueprint.params.validation_split > 1.0: + logging.warning(f"Dataset param `validation_split` ({dataset_blueprint.params.validation_split}) is not a valid number between 0.0 and 1.0, skipping validation split...") continue + + # if the dataset isn't setting a validation split, there is no current validation dataset + if dataset_blueprint.params.validation_split == 0.0: + continue + + 
extra_dataset_params = {} if dataset_blueprint.is_controlnet: subset_klass = ControlNetSubset dataset_klass = ControlNetDataset elif dataset_blueprint.is_dreambooth: subset_klass = DreamBoothSubset dataset_klass = DreamBoothDataset + # DreamBooth datasets support splitting training and validation datasets + extra_dataset_params = {"is_training_dataset": False} else: subset_klass = FineTuningSubset dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, is_training_dataset=False, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params) val_datasets.append(dataset) def print_info(_datasets, dataset_type: str): From 4c61adc9965df6861ae3705c96143f4299074744 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 12 Jan 2025 13:18:26 -0500 Subject: [PATCH 71/87] Add divergence to logs Divergence is the difference between training and validation to allow a clear value to indicate the difference between the two in the logs. --- train_network.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 7e9f1265..5ed92b7e 100644 --- a/train_network.py +++ b/train_network.py @@ -1418,14 +1418,16 @@ class NetworkTrainer: if is_tracking: logs = { - "loss/validation/step/current": current_loss, + "loss/validation/step_current": current_loss, "val_step": (epoch * validation_steps) + val_step, } accelerator.log(logs, step=global_step) if is_tracking: + loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average logs = { - "loss/validation/step/average": val_step_loss_recorder.moving_average, + "loss/validation/step_average": val_step_loss_recorder.moving_average, + "loss/validation/step_divergence": loss_validation_divergence, } accelerator.log(logs, step=global_step) @@ -1485,7 +1487,12 @@ class NetworkTrainer: if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - logs = {"loss/validation/epoch_average": avr_loss, "epoch": epoch + 1} + loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss + logs = { + "loss/validation/epoch_average": avr_loss, + "loss/validation/epoch_divergence": loss_validation_divergence, + "epoch": epoch + 1 + } accelerator.log(logs, step=global_step) # END OF EPOCH From 2bbb40ce51d5be3ce8c3e1990d30455201f9e852 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 12 Jan 2025 14:29:50 -0500 Subject: [PATCH 72/87] Fix regularization images with validation Adding metadata recording for validation arguments Add comments about the validation split for clarity of intention --- library/train_util.py | 33 +++++++++++++++++++++++++++++++-- train_network.py | 7 +++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 62aae37e..6d3a772b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -146,7 +146,12 @@ IMAGE_TRANSFORMS = transforms.Compose( TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" -def split_train_val(paths: List[str], is_training_dataset: bool, validation_split: float, validation_seed: int) -> List[str]: +def split_train_val( + paths: List[str], + is_training_dataset: bool, + validation_split: float, + validation_seed: int | None +) -> List[str]: """ Split the dataset into train and validation @@ 
-1830,6 +1835,9 @@ class BaseDataset(torch.utils.data.Dataset): class DreamBoothDataset(BaseDataset): IMAGE_INFO_CACHE_FILE = "metadata_cache.json" + # The is_training_dataset defines the type of dataset, training or validation + # if is_training_dataset is True -> training dataset + # if is_training_dataset is False -> validation dataset def __init__( self, subsets: Sequence[DreamBoothSubset], @@ -1965,8 +1973,29 @@ class DreamBoothDataset(BaseDataset): size_set_count += 1 logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + # We want to create a training and validation split. This should be improved in the future + # to allow a clearer distinction between training and validation. This can be seen as a + # short-term solution to limit what is necessary to implement validation datasets + # + # We split the dataset for the subset based on if we are doing a validation split + # The self.is_training_dataset defines the type of dataset, training or validation + # if self.is_training_dataset is True -> training dataset + # if self.is_training_dataset is False -> validation dataset if self.validation_split > 0.0: - img_paths = split_train_val(img_paths, self.is_training_dataset, self.validation_split, self.validation_seed) + # For regularization images we do not want to split this dataset. + if subset.is_reg is True: + # Skip any validation dataset for regularization images + if self.is_training_dataset is False: + img_paths = [] + # Otherwise the img_paths remain as original img_paths and no split + # required for training images dataset of regularization images + else: + img_paths = split_train_val( + img_paths, + self.is_training_dataset, + self.validation_split, + self.validation_seed + ) logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") diff --git a/train_network.py b/train_network.py index 5ed92b7e..605dbc60 100644 --- a/train_network.py +++ b/train_network.py @@ -898,6 +898,7 @@ class NetworkTrainer: accelerator.print("running training / 学習開始") accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}") accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") @@ -917,6 +918,7 @@ class NetworkTrainer: "ss_text_encoder_lr": text_encoder_lr, "ss_unet_lr": args.unet_lr, "ss_num_train_images": train_dataset_group.num_train_images, + "ss_num_validation_images": val_dataset_group.num_train_images if val_dataset_group is not None else 0, "ss_num_reg_images": train_dataset_group.num_reg_images, "ss_num_batches_per_epoch": len(train_dataloader), "ss_num_epochs": num_train_epochs, @@ -964,6 +966,11 @@ class NetworkTrainer: "ss_huber_c": args.huber_c, "ss_fp8_base": bool(args.fp8_base), "ss_fp8_base_unet": bool(args.fp8_base_unet), + "ss_validation_seed": args.validation_seed, + "ss_validation_split": args.validation_split, + "ss_max_validation_steps": args.max_validation_steps, + "ss_validate_every_n_epochs": args.validate_every_n_epochs, + "ss_validate_every_n_steps": args.validate_every_n_steps, } self.update_metadata(metadata, args) # architecture specific metadata From 0456858992909ca0b821ec1b2ca40fa633113224 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 
12 Jan 2025 14:47:49 -0500 Subject: [PATCH 73/87] Fix validate_every_n_steps always running first step --- train_network.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train_network.py b/train_network.py index 605dbc60..75e36dca 100644 --- a/train_network.py +++ b/train_network.py @@ -1385,6 +1385,7 @@ class NetworkTrainer: # VALIDATION PER STEP should_validate_step = ( args.validate_every_n_steps is not None + and global_step != 0 # Skip first step and global_step % args.validate_every_n_steps == 0 ) if validation_steps > 0 and should_validate_step: From ee9265cf2678df5c9dfa6c1148d20fb738a9e6ce Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 12 Jan 2025 14:56:35 -0500 Subject: [PATCH 74/87] Fix validate_every_n_steps for gradient accumulation --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 75e36dca..2f3203c9 100644 --- a/train_network.py +++ b/train_network.py @@ -1388,7 +1388,7 @@ class NetworkTrainer: and global_step != 0 # Skip first step and global_step % args.validate_every_n_steps == 0 ) - if validation_steps > 0 and should_validate_step: + if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: accelerator.print("Validating バリデーション処理...") val_progress_bar = tqdm( From 25929dd0d733144859008479c374968102e5d3a3 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 12 Jan 2025 15:38:57 -0500 Subject: [PATCH 75/87] Remove Validating... print to fix output layout --- train_network.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/train_network.py b/train_network.py index 2f3203c9..e7d93a10 100644 --- a/train_network.py +++ b/train_network.py @@ -1389,8 +1389,6 @@ class NetworkTrainer: and global_step % args.validate_every_n_steps == 0 ) if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: - accelerator.print("Validating バリデーション処理...") - val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, @@ -1450,7 +1448,6 @@ class NetworkTrainer: ) if should_validate_epoch and len(val_dataloader) > 0: - accelerator.print("Validating バリデーション処理...") val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, From b489082495ba6779385f282797227799413715f5 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 12 Jan 2025 16:42:04 -0500 Subject: [PATCH 76/87] Disable repeats for validation datasets --- library/train_util.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 6d3a772b..4d143c37 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2055,9 +2055,10 @@ class DreamBoothDataset(BaseDataset): num_reg_images = 0 reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = [] for subset in subsets: - if subset.num_repeats < 1: + num_repeats = subset.num_repeats if self.is_training_dataset else 1 + if num_repeats < 1: logger.warning( - f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" + f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {num_repeats}" ) continue @@ -2075,12 +2076,12 @@ class DreamBoothDataset(BaseDataset): continue if subset.is_reg: - num_reg_images += subset.num_repeats * len(img_paths) + num_reg_images += num_repeats * len(img_paths) else: - num_train_images += subset.num_repeats * len(img_paths) + 
num_train_images += num_repeats * len(img_paths) for img_path, caption, size in zip(img_paths, captions, sizes): - info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) if size is not None: info.image_size = size if subset.is_reg: From 345daaa986cdbcaedb6840997390f3d86846d677 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 17 Jan 2025 23:22:38 +0900 Subject: [PATCH 77/87] update README for merging --- README-ja.md | 8 ++++++-- README.md | 23 +++++++++++++++++++---- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/README-ja.md b/README-ja.md index 27cc56c3..60249f61 100644 --- a/README-ja.md +++ b/README-ja.md @@ -36,6 +36,8 @@ Python 3.10.6およびGitが必要です。 - Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe - git: https://git-scm.com/download/win +Python 3.10.x、3.11.x、3.12.xでも恐らく動作しますが、3.10.6でテストしています。 + PowerShellを使う場合、venvを使えるようにするためには以下の手順でセキュリティ設定を変更してください。 (venvに限らずスクリプトの実行が可能になりますので注意してください。) @@ -45,7 +47,7 @@ PowerShellを使う場合、venvを使えるようにするためには以下の ## Windows環境でのインストール -スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.0.1、1.12.1でも動作すると思われます。 +スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.2以降でも恐らく動作します。 (なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。) @@ -67,10 +69,12 @@ accelerate config コマンドプロンプトでも同一です。 -注:`bitsandbytes==0.43.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` は `requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。 +注:`bitsandbytes==0.44.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` は `requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。 この例では PyTorch および xfomers は2.1.2/CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121` としてください。 +PyTorch 2.2以降を用いる場合は、`torch==2.1.2` と `torchvision==0.16.2` 、および `xformers==0.0.23.post1` を適宜変更してください。 + accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。) ```txt diff --git a/README.md b/README.md index 73564cb2..6beee5e3 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ This repository contains the scripts for: The file does not contain requirements for PyTorch. Because the version of PyTorch depends on the environment, it is not included in the file. Please install PyTorch first according to the environment. See installation instructions below. -The scripts are tested with Pytorch 2.1.2. 2.0.1 and 1.12.1 is not tested but should work. +The scripts are tested with Pytorch 2.1.2. PyTorch 2.2 or later will work. Please install the appropriate version of PyTorch and xformers. ## Links to usage documentation @@ -52,6 +52,8 @@ Python 3.10.6 and Git: - Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe - git: https://git-scm.com/download/win +Python 3.10.x, 3.11.x, and 3.12.x will work but not tested. + Give unrestricted script access to powershell so venv can work: - Open an administrator powershell window @@ -78,10 +80,12 @@ accelerate config If `python -m venv` shows only `python`, change `python` to `py`. -__Note:__ Now `bitsandbytes==0.43.0`, `prodigyopt==1.0` and `lion-pytorch==0.0.6` are included in the requirements.txt. If you'd like to use the another version, please install it manually. +Note: Now `bitsandbytes==0.44.0`, `prodigyopt==1.0` and `lion-pytorch==0.0.6` are included in the requirements.txt. 
If you'd like to use another version, please install it manually. This installation is for CUDA 11.8. If you use a different version of CUDA, please install the appropriate version of PyTorch and xformers. For example, if you use CUDA 12, please install `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` and `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121`. +If you use PyTorch 2.2 or later, please change `torch==2.1.2`, `torchvision==0.16.2`, and `xformers==0.0.23.post1` to the appropriate versions. +
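The validation split introduced in patches 70, 72, and 76 above hinges on `split_train_val` and the regularization-image handling in `DreamBoothDataset`. The diffs only show the function signature and its call sites, so the sketch below is a hypothetical reconstruction of that behaviour (a deterministic shuffle keyed by `validation_seed`, with regularization images kept out of the validation dataset); it is not the actual `library/train_util.py` implementation.

```python
import random
from typing import List, Optional


def split_train_val(
    paths: List[str],
    is_training_dataset: bool,
    validation_split: float,
    validation_seed: Optional[int],
) -> List[str]:
    # Hypothetical sketch: shuffle deterministically with validation_seed, then
    # keep the first (1 - validation_split) fraction for training and the
    # remainder for validation.
    paths = list(paths)
    if validation_seed is not None:
        random.Random(validation_seed).shuffle(paths)
    split_index = int(len(paths) * (1.0 - validation_split))
    return paths[:split_index] if is_training_dataset else paths[split_index:]


def select_subset_paths(
    img_paths: List[str],
    is_reg: bool,
    is_training_dataset: bool,
    validation_split: float,
    validation_seed: Optional[int],
) -> List[str]:
    # Mirrors the DreamBoothDataset logic from patch 72: regularization images
    # are never split, and the validation dataset receives no reg images.
    if validation_split <= 0.0:
        return img_paths
    if is_reg:
        return img_paths if is_training_dataset else []
    return split_train_val(img_paths, is_training_dataset, validation_split, validation_seed)
```

Consistent with patch 76, a validation dataset would also ignore a subset's `num_repeats` (effectively treating it as 1), so each validation image contributes to the loss once per pass.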