From 6889ee2b85aa6af04ae3a68250d88d46ab9417bf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 11 Aug 2023 19:02:36 +0900 Subject: [PATCH] add warning for bucket_reso_steps with SDXL --- finetune/prepare_buckets_latents.py | 4 ++++ library/train_util.py | 13 +++++++++++++ sdxl_train.py | 2 ++ sdxl_train_network.py | 2 ++ sdxl_train_textual_inversion.py | 2 ++ 5 files changed, 23 insertions(+) diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 1dde2294..af08c537 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -52,6 +52,10 @@ def main(args): # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります" if args.bucket_reso_steps % 8 > 0: print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") + if args.bucket_reso_steps % 32 > 0: + print( + f"WARNING: bucket_reso_steps is not divisible by 32. It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません" + ) train_data_dir_path = Path(args.train_data_dir) image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)] diff --git a/library/train_util.py b/library/train_util.py index e88a3dcf..82ac9dbf 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -800,6 +800,12 @@ class BaseDataset(torch.utils.data.Dataset): random.shuffle(self.buckets_indices) self.bucket_manager.shuffle() + def verify_bucket_reso_steps(self, min_steps: int): + assert self.bucket_reso_steps is None or self.bucket_reso_steps % min_steps == 0, ( + f"bucket_reso_steps is {self.bucket_reso_steps}. it must be divisible by {min_steps}.\n" + + f"bucket_reso_stepsが{self.bucket_reso_steps}です。{min_steps}で割り切れる必要があります" + ) + def is_latent_cacheable(self): return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) @@ -1831,6 +1837,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset): for dataset in self.datasets: dataset.set_caching_mode(caching_mode) + def verify_bucket_reso_steps(self, min_steps: int): + for dataset in self.datasets: + dataset.verify_bucket_reso_steps(min_steps) + def is_latent_cacheable(self) -> bool: return all([dataset.is_latent_cacheable() for dataset in self.datasets]) @@ -2020,6 +2030,9 @@ class MinimalDataset(BaseDataset): self.is_reg = False self.image_dir = "dummy" # for metadata + def verify_bucket_reso_steps(self, min_steps: int): + pass + def is_latent_cacheable(self) -> bool: return False diff --git a/sdxl_train.py b/sdxl_train.py index 2ca14931..e62bc377 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -98,6 +98,8 @@ def train(args): ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + train_dataset_group.verify_bucket_reso_steps(32) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group, True) return diff --git a/sdxl_train_network.py b/sdxl_train_network.py index e3254be0..8d3a81c3 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -23,6 +23,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): args.network_train_unet_only or not args.cache_text_encoder_outputs ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + train_dataset_group.verify_bucket_reso_steps(32) + def load_target_model(self, args, weight_dtype, accelerator): ( load_stable_diffusion_format, diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index 1ddfd92b..123ca35a 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -19,6 +19,8 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine super().assert_extra_args(args, train_dataset_group) sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) + train_dataset_group.verify_bucket_reso_steps(32) + def load_target_model(self, args, weight_dtype, accelerator): ( load_stable_diffusion_format,