Merge pull request #2216 from kohya-ss/fix-sdxl-textual-inversion-training-disable-mmap

fix: `--disable_mmap_load_safetensors` argument not defined in SDXL Textual Inversion (TI) training
This commit is contained in:
Kohya S.
2025-09-29 20:55:02 +09:00
committed by GitHub
2 changed files with 18 additions and 15 deletions

View File

@@ -327,15 +327,18 @@ def save_sd_model_on_epoch_end_or_stepwise(
) )
def add_sdxl_training_arguments(parser: argparse.ArgumentParser): def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True):
parser.add_argument( if support_text_encoder_caching:
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" parser.add_argument(
) "--cache_text_encoder_outputs",
parser.add_argument( action="store_true",
"--cache_text_encoder_outputs_to_disk", help="cache text encoder outputs / text encoderの出力をキャッシュする",
action="store_true", )
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", parser.add_argument(
) "--cache_text_encoder_outputs_to_disk",
action="store_true",
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
)
parser.add_argument( parser.add_argument(
"--disable_mmap_load_safetensors", "--disable_mmap_load_safetensors",
action="store_true", action="store_true",
@@ -343,7 +346,7 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
) )
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): def verify_sdxl_training_args(args: argparse.Namespace, support_text_encoder_caching: bool = True):
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
if args.clip_skip is not None: if args.clip_skip is not None:
@@ -366,7 +369,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
not hasattr(args, "weighted_captions") or not args.weighted_captions not hasattr(args, "weighted_captions") or not args.weighted_captions
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
if supportTextEncoderCaching: if support_text_encoder_caching:
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
args.cache_text_encoder_outputs = True args.cache_text_encoder_outputs = True
logger.warning( logger.warning(

View File

@@ -5,6 +5,7 @@ import regex
import torch import torch
from library.device_utils import init_ipex from library.device_utils import init_ipex
init_ipex() init_ipex()
from library import sdxl_model_util, sdxl_train_util, train_util from library import sdxl_model_util, sdxl_train_util, train_util
@@ -19,8 +20,8 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
self.is_sdxl = True self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group): def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group) # super().assert_extra_args(args, train_dataset_group) # do not call parent because it checks reso steps with 64
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)
train_dataset_group.verify_bucket_reso_steps(32) train_dataset_group.verify_bucket_reso_steps(32)
@@ -122,8 +123,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser:
parser = train_textual_inversion.setup_parser() parser = train_textual_inversion.setup_parser()
# don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False)
# sdxl_train_util.add_sdxl_training_arguments(parser)
return parser return parser