diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index f78d9424..5ac9eb3b 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -327,15 +327,18 @@ def save_sd_model_on_epoch_end_or_stepwise( ) -def add_sdxl_training_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" - ) - parser.add_argument( - "--cache_text_encoder_outputs_to_disk", - action="store_true", - help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", - ) +def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True): + if support_text_encoder_caching: + parser.add_argument( + "--cache_text_encoder_outputs", + action="store_true", + help="cache text encoder outputs / text encoderの出力をキャッシュする", + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) parser.add_argument( "--disable_mmap_load_safetensors", action="store_true", @@ -343,7 +346,7 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser): ) -def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): +def verify_sdxl_training_args(args: argparse.Namespace, support_text_encoder_caching: bool = True): assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" if args.clip_skip is not None: @@ -366,7 +369,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin not hasattr(args, "weighted_captions") or not args.weighted_captions ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" - if supportTextEncoderCaching: + if support_text_encoder_caching: if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: args.cache_text_encoder_outputs = True logger.warning( diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index 5df739e2..d8422f08 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -5,6 +5,7 @@ import regex import torch from library.device_utils import init_ipex + init_ipex() from library import sdxl_model_util, sdxl_train_util, train_util @@ -19,8 +20,8 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine self.is_sdxl = True def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) - sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) + # super().assert_extra_args(args, train_dataset_group) # do not call parent because it checks reso steps with 64 + sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False) train_dataset_group.verify_bucket_reso_steps(32) @@ -122,8 +123,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine def setup_parser() -> argparse.ArgumentParser: parser = train_textual_inversion.setup_parser() - # don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching - # sdxl_train_util.add_sdxl_training_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False) return parser