From bfc3a65acda7f90abef9c16db279d2952f73fb77 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 13 Oct 2024 19:08:16 +0900 Subject: [PATCH] fix to work cache latents/text encoder outputs --- library/train_util.py | 11 +++++++---- tools/cache_latents.py | 11 ++++++----- tools/cache_text_encoder_outputs.py | 18 +++++++++++++----- 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 1db470d6..92660926 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4064,15 +4064,18 @@ def verify_command_line_training_args(args: argparse.Namespace): ) +def enable_high_vram(args: argparse.Namespace): + if args.highvram: + logger.info("highvram is enabled / highvramが有効です") + global HIGH_VRAM + HIGH_VRAM = True + def verify_training_args(args: argparse.Namespace): r""" Verify training arguments. Also reflect highvram option to global variable 学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する """ - if args.highvram: - print("highvram is enabled / highvramが有効です") - global HIGH_VRAM - HIGH_VRAM = True + enable_high_vram(args) if args.v_parameterization and not args.v2: logger.warning( diff --git a/tools/cache_latents.py b/tools/cache_latents.py index d8154ec3..e2faa58a 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -9,7 +9,7 @@ from accelerate.utils import set_seed import torch from tqdm import tqdm -from library import config_util, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl +from library import config_util, flux_train_utils, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl from library import train_util from library import sdxl_train_util from library.config_util import ( @@ -30,7 +30,7 @@ def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argpa else: is_schnell = False - if is_sd or is_sdxl: + if is_sd: tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) elif is_sdxl: tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) @@ -51,6 +51,7 @@ def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argpa def cache_to_disk(args: argparse.Namespace) -> None: setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) + train_util.enable_high_vram(args) # assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" args.cache_latents = True @@ -161,10 +162,10 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) - parser.add_argument( - "--ae", type=str, default=None, help="Autoencoder model of FLUX to use / 使用するFLUXのオートエンコーダモデル" - ) + flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") parser.add_argument( diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index d294d46c..7be9ad78 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -27,7 +27,7 @@ from library.config_util import ( BlueprintGenerator, ) from library.utils import setup_logging, add_logging_arguments -from tools import cache_latents +from cache_latents import set_tokenize_strategy setup_logging() import logging @@ -38,6 +38,7 @@ logger = logging.getLogger(__name__) def cache_to_disk(args: argparse.Namespace) -> None: setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) + train_util.enable_high_vram(args) args.cache_text_encoder_outputs = True args.cache_text_encoder_outputs_to_disk = True @@ -57,8 +58,8 @@ def cache_to_disk(args: argparse.Namespace) -> None: assert ( is_sdxl or args.weighted_captions is None ), "Weighted captions are only supported for SDXL models / 重み付きキャプションはSDXLモデルでのみ有効です" - - cache_latents.set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) + + set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) # データセットを準備する use_user_config = args.dataset_config is not None @@ -178,7 +179,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: train_dataset_group.new_cache_text_encoder_outputs(text_encoders, accelerator) accelerator.wait_for_everyone() - accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + accelerator.print(f"Finished caching text encoder outputs to disk.") def setup_parser() -> argparse.ArgumentParser: @@ -188,9 +189,10 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) - sdxl_train_util.add_sdxl_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") parser.add_argument( @@ -205,6 +207,12 @@ def setup_parser() -> argparse.ArgumentParser: help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check." " / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。", ) + parser.add_argument( + "--weighted_captions", + action="store_true", + default=False, + help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意", + ) return parser