mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Fix caching of latents and text encoder outputs so that it works
This commit is contained in:
@@ -4064,15 +4064,18 @@ def verify_command_line_training_args(args: argparse.Namespace):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def enable_high_vram(args: argparse.Namespace):
    """Reflect the ``highvram`` option in the module-level ``HIGH_VRAM`` flag.

    When ``args.highvram`` is truthy, an informational message is logged and
    the global ``HIGH_VRAM`` is set to ``True``. When it is falsy, nothing
    happens; in particular this function never sets the flag to ``False``.

    Args:
        args: Parsed command-line arguments; only ``args.highvram`` is read.
    """
    # Declared at the top of the function so the assignment below rebinds the
    # module-level name instead of creating a local variable.
    global HIGH_VRAM

    if args.highvram:
        logger.info("highvram is enabled / highvramが有効です")
        HIGH_VRAM = True
|
||||||
|
|
||||||
def verify_training_args(args: argparse.Namespace):
|
def verify_training_args(args: argparse.Namespace):
|
||||||
r"""
|
r"""
|
||||||
Verify training arguments. Also reflect highvram option to global variable
|
Verify training arguments. Also reflect highvram option to global variable
|
||||||
学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する
|
学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する
|
||||||
"""
|
"""
|
||||||
if args.highvram:
|
enable_high_vram(args)
|
||||||
print("highvram is enabled / highvramが有効です")
|
|
||||||
global HIGH_VRAM
|
|
||||||
HIGH_VRAM = True
|
|
||||||
|
|
||||||
if args.v_parameterization and not args.v2:
|
if args.v_parameterization and not args.v2:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from accelerate.utils import set_seed
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from library import config_util, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl
|
from library import config_util, flux_train_utils, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl
|
||||||
from library import train_util
|
from library import train_util
|
||||||
from library import sdxl_train_util
|
from library import sdxl_train_util
|
||||||
from library.config_util import (
|
from library.config_util import (
|
||||||
@@ -30,7 +30,7 @@ def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argpa
|
|||||||
else:
|
else:
|
||||||
is_schnell = False
|
is_schnell = False
|
||||||
|
|
||||||
if is_sd or is_sdxl:
|
if is_sd:
|
||||||
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
|
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
|
||||||
elif is_sdxl:
|
elif is_sdxl:
|
||||||
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
|
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
|
||||||
@@ -51,6 +51,7 @@ def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argpa
|
|||||||
def cache_to_disk(args: argparse.Namespace) -> None:
|
def cache_to_disk(args: argparse.Namespace) -> None:
|
||||||
setup_logging(args, reset=True)
|
setup_logging(args, reset=True)
|
||||||
train_util.prepare_dataset_args(args, True)
|
train_util.prepare_dataset_args(args, True)
|
||||||
|
train_util.enable_high_vram(args)
|
||||||
|
|
||||||
# assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
|
# assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
|
||||||
args.cache_latents = True
|
args.cache_latents = True
|
||||||
@@ -161,10 +162,10 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
train_util.add_sd_models_arguments(parser)
|
train_util.add_sd_models_arguments(parser)
|
||||||
train_util.add_training_arguments(parser, True)
|
train_util.add_training_arguments(parser, True)
|
||||||
train_util.add_dataset_arguments(parser, True, True, True)
|
train_util.add_dataset_arguments(parser, True, True, True)
|
||||||
|
train_util.add_masked_loss_arguments(parser)
|
||||||
config_util.add_config_arguments(parser)
|
config_util.add_config_arguments(parser)
|
||||||
parser.add_argument(
|
flux_train_utils.add_flux_train_arguments(parser)
|
||||||
"--ae", type=str, default=None, help="Autoencoder model of FLUX to use / 使用するFLUXのオートエンコーダモデル"
|
|
||||||
)
|
|
||||||
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
|
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
|
||||||
parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する")
|
parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from library.config_util import (
|
|||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
from library.utils import setup_logging, add_logging_arguments
|
from library.utils import setup_logging, add_logging_arguments
|
||||||
from tools import cache_latents
|
from cache_latents import set_tokenize_strategy
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
@@ -38,6 +38,7 @@ logger = logging.getLogger(__name__)
|
|||||||
def cache_to_disk(args: argparse.Namespace) -> None:
|
def cache_to_disk(args: argparse.Namespace) -> None:
|
||||||
setup_logging(args, reset=True)
|
setup_logging(args, reset=True)
|
||||||
train_util.prepare_dataset_args(args, True)
|
train_util.prepare_dataset_args(args, True)
|
||||||
|
train_util.enable_high_vram(args)
|
||||||
|
|
||||||
args.cache_text_encoder_outputs = True
|
args.cache_text_encoder_outputs = True
|
||||||
args.cache_text_encoder_outputs_to_disk = True
|
args.cache_text_encoder_outputs_to_disk = True
|
||||||
@@ -57,8 +58,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
|||||||
assert (
|
assert (
|
||||||
is_sdxl or args.weighted_captions is None
|
is_sdxl or args.weighted_captions is None
|
||||||
), "Weighted captions are only supported for SDXL models / 重み付きキャプションはSDXLモデルでのみ有効です"
|
), "Weighted captions are only supported for SDXL models / 重み付きキャプションはSDXLモデルでのみ有効です"
|
||||||
|
|
||||||
cache_latents.set_tokenize_strategy(is_sd, is_sdxl, is_flux, args)
|
set_tokenize_strategy(is_sd, is_sdxl, is_flux, args)
|
||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
use_user_config = args.dataset_config is not None
|
use_user_config = args.dataset_config is not None
|
||||||
@@ -178,7 +179,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
|||||||
train_dataset_group.new_cache_text_encoder_outputs(text_encoders, accelerator)
|
train_dataset_group.new_cache_text_encoder_outputs(text_encoders, accelerator)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
|
accelerator.print(f"Finished caching text encoder outputs to disk.")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
@@ -188,9 +189,10 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
train_util.add_sd_models_arguments(parser)
|
train_util.add_sd_models_arguments(parser)
|
||||||
train_util.add_training_arguments(parser, True)
|
train_util.add_training_arguments(parser, True)
|
||||||
train_util.add_dataset_arguments(parser, True, True, True)
|
train_util.add_dataset_arguments(parser, True, True, True)
|
||||||
|
train_util.add_masked_loss_arguments(parser)
|
||||||
config_util.add_config_arguments(parser)
|
config_util.add_config_arguments(parser)
|
||||||
sdxl_train_util.add_sdxl_training_arguments(parser)
|
|
||||||
flux_train_utils.add_flux_train_arguments(parser)
|
flux_train_utils.add_flux_train_arguments(parser)
|
||||||
|
|
||||||
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
|
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
|
||||||
parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する")
|
parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -205,6 +207,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check."
|
help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check."
|
||||||
" / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。",
|
" / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--weighted_captions",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user