fix invalid args checking in sdxl TI training

This commit is contained in:
Kohya S
2023-07-20 14:50:57 +09:00
parent 771f33d17d
commit fc276a51fb
2 changed files with 9 additions and 8 deletions

View File

@@ -342,7 +342,7 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
)
def verify_sdxl_training_args(args: argparse.Namespace):
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
if args.v_parameterization:
print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
@@ -367,12 +367,13 @@ def verify_sdxl_training_args(args: argparse.Namespace):
not hasattr(args, "weighted_captions") or not args.weighted_captions
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
args.cache_text_encoder_outputs = True
print(
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
)
if supportTextEncoderCaching:
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
args.cache_text_encoder_outputs = True
print(
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
)
def sample_images(*args, **kwargs):

View File

@@ -16,7 +16,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args)
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
def load_target_model(self, args, weight_dtype, accelerator):
(