mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
fix invalid args checking in sdxl TI training
This commit is contained in:
@@ -342,7 +342,7 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
||||
)
|
||||
|
||||
|
||||
def verify_sdxl_training_args(args: argparse.Namespace):
|
||||
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
||||
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
||||
if args.v_parameterization:
|
||||
print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
|
||||
@@ -367,12 +367,13 @@ def verify_sdxl_training_args(args: argparse.Namespace):
|
||||
not hasattr(args, "weighted_captions") or not args.weighted_captions
|
||||
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
||||
|
||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||
args.cache_text_encoder_outputs = True
|
||||
print(
|
||||
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
|
||||
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
|
||||
)
|
||||
if supportTextEncoderCaching:
|
||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||
args.cache_text_encoder_outputs = True
|
||||
print(
|
||||
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
|
||||
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
|
||||
)
|
||||
|
||||
|
||||
def sample_images(*args, **kwargs):
|
||||
|
||||
@@ -16,7 +16,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
(
|
||||
|
||||
Reference in New Issue
Block a user