From 6889ee2b85aa6af04ae3a68250d88d46ab9417bf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 11 Aug 2023 19:02:36 +0900 Subject: [PATCH 1/4] add warning for bucket_reso_steps with SDXL --- finetune/prepare_buckets_latents.py | 4 ++++ library/train_util.py | 13 +++++++++++++ sdxl_train.py | 2 ++ sdxl_train_network.py | 2 ++ sdxl_train_textual_inversion.py | 2 ++ 5 files changed, 23 insertions(+) diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 1dde2294..af08c537 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -52,6 +52,10 @@ def main(args): # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります" if args.bucket_reso_steps % 8 > 0: print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") + if args.bucket_reso_steps % 32 > 0: + print( + f"WARNING: bucket_reso_steps is not divisible by 32. It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません" + ) train_data_dir_path = Path(args.train_data_dir) image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)] diff --git a/library/train_util.py b/library/train_util.py index e88a3dcf..82ac9dbf 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -800,6 +800,12 @@ class BaseDataset(torch.utils.data.Dataset): random.shuffle(self.buckets_indices) self.bucket_manager.shuffle() + def verify_bucket_reso_steps(self, min_steps: int): + assert self.bucket_reso_steps is None or self.bucket_reso_steps % min_steps == 0, ( + f"bucket_reso_steps is {self.bucket_reso_steps}. it must be divisible by {min_steps}.\n" + + f"bucket_reso_stepsが{self.bucket_reso_steps}です。{min_steps}で割り切れる必要があります" + ) + def is_latent_cacheable(self): return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) @@ -1831,6 +1837,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset): for dataset in self.datasets: dataset.set_caching_mode(caching_mode) + def verify_bucket_reso_steps(self, min_steps: int): + for dataset in self.datasets: + dataset.verify_bucket_reso_steps(min_steps) + def is_latent_cacheable(self) -> bool: return all([dataset.is_latent_cacheable() for dataset in self.datasets]) @@ -2020,6 +2030,9 @@ class MinimalDataset(BaseDataset): self.is_reg = False self.image_dir = "dummy" # for metadata + def verify_bucket_reso_steps(self, min_steps: int): + pass + def is_latent_cacheable(self) -> bool: return False diff --git a/sdxl_train.py b/sdxl_train.py index 2ca14931..e62bc377 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -98,6 +98,8 @@ def train(args): ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + train_dataset_group.verify_bucket_reso_steps(32) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group, True) return diff --git a/sdxl_train_network.py b/sdxl_train_network.py index e3254be0..8d3a81c3 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -23,6 +23,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): args.network_train_unet_only or not args.cache_text_encoder_outputs ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + train_dataset_group.verify_bucket_reso_steps(32) + def load_target_model(self, args, weight_dtype, accelerator): ( load_stable_diffusion_format, diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index 1ddfd92b..123ca35a 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -19,6 +19,8 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine super().assert_extra_args(args, train_dataset_group) sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) + train_dataset_group.verify_bucket_reso_steps(32) + def load_target_model(self, args, weight_dtype, accelerator): ( load_stable_diffusion_format, From 3307ccb2dc46c96d370cf57437002ff3228dcfd4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 11 Aug 2023 20:35:46 +0900 Subject: [PATCH 2/4] revert default noise offset to 0 (None) in sdxl --- library/sdxl_train_util.py | 26 +++++++++++++------------- library/train_util.py | 12 ++++++------ 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 1f849275..12bcf6d2 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -13,7 +13,7 @@ from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeigh TOKENIZER1_PATH = "openai/clip-vit-large-patch14" TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" -DEFAULT_NOISE_OFFSET = 0.0357 +# DEFAULT_NOISE_OFFSET = 0.0357 def load_target_model(args, accelerator, model_version: str, weight_dtype): @@ -312,18 +312,18 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin if args.clip_skip is not None: print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") - if args.multires_noise_iterations: - print( - f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" - ) - else: - if args.noise_offset is None: - args.noise_offset = DEFAULT_NOISE_OFFSET - elif args.noise_offset != DEFAULT_NOISE_OFFSET: - print( - f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" - ) - print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") + # if args.multires_noise_iterations: + # print( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" + # ) + # else: + # if args.noise_offset is None: + # args.noise_offset = DEFAULT_NOISE_OFFSET + # elif args.noise_offset != DEFAULT_NOISE_OFFSET: + # print( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" + # ) + # print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") assert ( not hasattr(args, "weighted_captions") or not args.weighted_captions diff --git a/library/train_util.py b/library/train_util.py index 82ac9dbf..0b40e3ed 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2994,11 +2994,11 @@ def verify_training_args(args: argparse.Namespace): ) # noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time - # Listを使って数えてもいいけど並べてしまえ - if args.noise_offset is not None and args.multires_noise_iterations is not None: - raise ValueError( - "noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にできません" - ) + # # Listを使って数えてもいいけど並べてしまえ + # if args.noise_offset is not None and args.multires_noise_iterations is not None: + # raise ValueError( + # "noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にできません" + # ) # if args.noise_offset is not None and args.perlin_noise is not None: # raise ValueError("noise_offset and perlin_noise cannot be enabled at the same time / noise_offsetとperlin_noiseは同時に有効にできません") # if args.perlin_noise is not None and args.multires_noise_iterations is not None: @@ -4281,7 +4281,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) - elif args.multires_noise_iterations: + if args.multires_noise_iterations: noise = custom_train_functions.pyramid_noise_like( noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount ) From 8415014de6979d8f0f67535a7056193e2b25a386 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 11 Aug 2023 21:31:55 +0900 Subject: [PATCH 3/4] suppress waning for scheduler args #748 --- sdxl_gen_img.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 209e71a7..a9d7fc4f 100644 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -1309,7 +1309,10 @@ def main(args): # schedulerを用意する sched_init_args = {} + has_steps_offset = True + has_clip_sample = True scheduler_num_noises_per_step = 1 + if args.sampler == "ddim": scheduler_cls = DDIMScheduler scheduler_module = diffusers.schedulers.scheduling_ddim @@ -1319,32 +1322,48 @@ def main(args): elif args.sampler == "pndm": scheduler_cls = PNDMScheduler scheduler_module = diffusers.schedulers.scheduling_pndm + has_clip_sample = False elif args.sampler == "lms" or args.sampler == "k_lms": scheduler_cls = LMSDiscreteScheduler scheduler_module = diffusers.schedulers.scheduling_lms_discrete + has_clip_sample = False elif args.sampler == "euler" or args.sampler == "k_euler": scheduler_cls = EulerDiscreteScheduler scheduler_module = diffusers.schedulers.scheduling_euler_discrete + has_clip_sample = False elif args.sampler == "euler_a" or args.sampler == "k_euler_a": scheduler_cls = EulerAncestralDiscreteScheduler scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete + has_clip_sample = False elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": scheduler_cls = DPMSolverMultistepScheduler sched_init_args["algorithm_type"] = args.sampler scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep + has_clip_sample = False elif args.sampler == "dpmsingle": scheduler_cls = DPMSolverSinglestepScheduler scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep + has_clip_sample = False + has_steps_offset = False elif args.sampler == "heun": scheduler_cls = HeunDiscreteScheduler scheduler_module = diffusers.schedulers.scheduling_heun_discrete + has_clip_sample = False elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": scheduler_cls = KDPM2DiscreteScheduler scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete + has_clip_sample = False elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": scheduler_cls = KDPM2AncestralDiscreteScheduler scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete scheduler_num_noises_per_step = 2 + has_clip_sample = False + + # 警告を出さないようにする + if has_steps_offset: + sched_init_args["steps_offset"] = 1 + if has_clip_sample: + sched_init_args["clip_sample"] = False # samplerの乱数をあらかじめ指定するための処理 @@ -1397,10 +1416,11 @@ def main(args): **sched_init_args, ) - # clip_sample=Trueにする - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - print("set clip_sample to True") - scheduler.config.clip_sample = True + # ↓以下は結局PipeでFalseに設定されるので意味がなかった + # # clip_sample=Trueにする + # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + # print("set clip_sample to True") + # scheduler.config.clip_sample = True # deviceを決定する device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない From e2c2689f5c82de2aa0aa3ed81348d54b9e9bc288 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 12 Aug 2023 13:13:59 +0900 Subject: [PATCH 4/4] support block lr for U-Net --- sdxl_train.py | 125 ++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 110 insertions(+), 15 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index e62bc377..195467b0 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -5,6 +5,7 @@ import gc import math import os from multiprocessing import Value +from typing import List import toml from tqdm import tqdm @@ -30,6 +31,67 @@ from library.custom_train_functions import ( from library.sdxl_original_unet import SdxlUNet2DConditionModel +UNET_NUM_BLOCKS_FOR_BLOCK_LR = 23 + + +def get_block_params_to_optimize(unet: SdxlUNet2DConditionModel, block_lrs: List[float]) -> List[dict]: + block_params = [[] for _ in range(len(block_lrs))] + + for i, (name, param) in enumerate(unet.named_parameters()): + if name.startswith("time_embed.") or name.startswith("label_emb."): + block_index = 0 # 0 + elif name.startswith("input_blocks."): # 1-9 + block_index = 1 + int(name.split(".")[1]) + elif name.startswith("middle_block."): # 10-12 + block_index = 10 + int(name.split(".")[1]) + elif name.startswith("output_blocks."): # 13-21 + block_index = 13 + int(name.split(".")[1]) + elif name.startswith("out."): # 22 + block_index = 22 + else: + raise ValueError(f"unexpected parameter name: {name}") + + block_params[block_index].append(param) + + params_to_optimize = [] + for i, params in enumerate(block_params): + if block_lrs[i] == 0: # 0のときは学習しない do not optimize when lr is 0 + continue + params_to_optimize.append({"params": params, "lr": block_lrs[i]}) + + return params_to_optimize + + +def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type): + lrs = lr_scheduler.get_last_lr() + + lr_index = 0 + block_index = 0 + while lr_index < len(lrs): + if block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR: + name = f"block{block_index}" + if block_lrs[block_index] == 0: + block_index += 1 + continue + elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR: + name = "text_encoder1" + elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR + 1: + name = "text_encoder2" + else: + raise ValueError(f"unexpected block_index: {block_index}") + + block_index += 1 + + logs["lr/" + name] = float(lrs[lr_index]) + + if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower(): + logs["lr/d*lr/" + name] = ( + lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"] + ) + + lr_index += 1 + + def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) @@ -40,6 +102,14 @@ def train(args): not args.train_text_encoder or not args.cache_text_encoder_outputs ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + if args.block_lr: + block_lrs = [float(lr) for lr in args.block_lr.split(",")] + assert ( + len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR + ), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください" + else: + block_lrs = None + cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None @@ -235,15 +305,28 @@ def train(args): for m in training_models: m.requires_grad_(True) - params = [] - for m in training_models: - params.extend(m.parameters()) - params_to_optimize = params - # calculate number of trainable parameters - n_params = 0 - for p in params: - n_params += p.numel() + if block_lrs is None: + params = [] + for m in training_models: + params.extend(m.parameters()) + params_to_optimize = params + + # calculate number of trainable parameters + n_params = 0 + for p in params: + n_params += p.numel() + else: + params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs) # U-Net + for m in training_models[1:]: # Text Encoders if exists + params_to_optimize.append({"params": m.parameters(), "lr": args.learning_rate}) + + # calculate number of trainable parameters + n_params = 0 + for params in params_to_optimize: + for p in params["params"]: + n_params += p.numel() + accelerator.print(f"number of models: {len(training_models)}") accelerator.print(f"number of trainable parameters: {n_params}") @@ -528,13 +611,18 @@ def train(args): current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず if args.logging_dir is not None: - logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if ( - args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy" - ): # tracking d*lr value - logs["lr/d*lr"] = ( - lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] - ) + logs = {"loss": current_loss} + if block_lrs is None: + logs["lr"] = float(lr_scheduler.get_last_lr()[0]) + if ( + args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() + ): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + else: + append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type) + accelerator.log(logs, step=global_step) # TODO moving averageにする @@ -638,6 +726,13 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + parser.add_argument( + "--block_lr", + type=str, + default=None, + help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " + + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", + ) return parser