From fc276a51fbca83f45440a6211b3c372f04319892 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 20 Jul 2023 14:50:57 +0900 Subject: [PATCH] fix invalid args checking in sdxl TI training --- library/sdxl_train_util.py | 15 ++++++++------- sdxl_train_textual_inversion.py | 2 +- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 100b1cf8..34312afc 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -342,7 +342,7 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser): ) -def verify_sdxl_training_args(args: argparse.Namespace): +def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" if args.v_parameterization: print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") @@ -367,12 +367,13 @@ def verify_sdxl_training_args(args: argparse.Namespace): not hasattr(args, "weighted_captions") or not args.weighted_captions ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" - if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: - args.cache_text_encoder_outputs = True - print( - "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / " - + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました" - ) + if supportTextEncoderCaching: + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + args.cache_text_encoder_outputs = True + print( + "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / " + + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました" + ) def sample_images(*args, **kwargs): diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index 54328000..a2515051 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -16,7 +16,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) - sdxl_train_util.verify_sdxl_training_args(args) + sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) def load_target_model(self, args, weight_dtype, accelerator): (