mirror of https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00

Refactor caching in train scripts

README.md

@@ -11,6 +11,16 @@ The command to install PyTorch is as follows:

### Recent Updates

Oct 12, 2024 (update 1):

- During multi-GPU training, caching of latents and Text Encoder outputs is now performed on multiple GPUs.
- The `--text_encoder_batch_size` option is now enabled for FLUX.1 LoRA training and fine-tuning. It specifies the batch size for caching Text Encoder outputs (not for training); the default is the dataset batch size. If you have enough VRAM, increase it to speed up caching.
- A `--skip_cache_check` option has been added to each training script (a sketch of its semantics follows this update list).
  - When specified, the consistency check of the cache file `*.npz` contents (e.g., image size and flip for latents, mask for Text Encoder outputs) is skipped.
  - Specify this option if you have a large number of cache files and the consistency check takes a long time.
  - Even if this option is specified, the cache is still created when the file does not exist.
- `--skip_latents_validity_check` in SD3/FLUX.1 is deprecated. Please use `--skip_cache_check` instead.

Oct 12, 2024:

- Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)!
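A minimal sketch of the `--skip_cache_check` semantics described above, using a hypothetical helper (this is not the repository's actual code): the existence check always runs, and only the content validation is skipped.

```python
import os
import numpy as np

def is_cache_usable(npz_path: str, expected_keys: list, skip_cache_check: bool) -> bool:
    """Decide whether an existing .npz cache file can be reused."""
    if not os.path.exists(npz_path):
        return False  # a missing cache file is always (re)created
    if skip_cache_check:
        return True  # trust the file contents without validating them
    try:
        with np.load(npz_path) as npz:
            # content validation, e.g. latents / flipped latents / masks are present
            return all(key in npz for key in expected_keys)
    except Exception:
        return False  # an unreadable cache file is rebuilt
```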
@@ -59,7 +59,7 @@ def train(args):
     # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
     if cache_latents:
         latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
-            False, args.cache_latents_to_disk, args.vae_batch_size, False
+            False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
         )
         strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -57,6 +57,10 @@ def train(args):
     deepspeed_utils.prepare_deepspeed_args(args)
     setup_logging(args, reset=True)
 
+    # temporary: backward compatibility for deprecated options. remove in the future
+    if not args.skip_cache_check:
+        args.skip_cache_check = args.skip_latents_validity_check
+
     # assert (
     #     not args.weighted_captions
     # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
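The four added lines are a self-contained pattern; here is a runnable sketch of how the shim keeps the deprecated flag working (the `parse_args` input is illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--skip_cache_check", action="store_true")
parser.add_argument("--skip_latents_validity_check", action="store_true")  # deprecated
args = parser.parse_args(["--skip_latents_validity_check"])

# temporary backward compatibility: forward the old flag to the new one
if not args.skip_cache_check:
    args.skip_cache_check = args.skip_latents_validity_check

assert args.skip_cache_check  # the deprecated flag still takes effect
```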
@@ -81,7 +85,7 @@ def train(args):
     # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
     if args.cache_latents:
         latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(
-            args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
+            args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
         )
         strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -142,7 +146,7 @@ def train(args):
     if args.cache_text_encoder_outputs:
         strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
             strategy_flux.FluxTextEncoderOutputsCachingStrategy(
-                args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
+                args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
             )
         )
     t5xxl_max_token_length = (
@@ -181,7 +185,7 @@ def train(args):
     # load VAE for caching latents
     ae = None
     if cache_latents:
-        ae = flux_utils.load_ae( args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
+        ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
         ae.to(accelerator.device, dtype=weight_dtype)
         ae.requires_grad_(False)
         ae.eval()
@@ -229,7 +233,7 @@ def train(args):
         strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
 
         with accelerator.autocast():
-            train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator.is_main_process)
+            train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator)
 
         # cache sample prompt's embeddings to free text encoder's memory
         if args.sample_prompts is not None:
@@ -952,7 +956,7 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--skip_latents_validity_check",
         action="store_true",
-        help="skip latents validity check / latentsの正当性チェックをスキップする",
+        help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
     )
     parser.add_argument(
         "--blocks_to_swap",
@@ -188,8 +188,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
             # if the text encoders is trained, we need tokenization, so is_partial is True
             return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
                 args.cache_text_encoder_outputs_to_disk,
-                None,
-                False,
+                args.text_encoder_batch_size,
+                args.skip_cache_check,
                 is_partial=self.train_clip_l or self.train_t5xxl,
                 apply_t5_attn_mask=args.apply_t5_attn_mask,
             )
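The hard-coded `None` and `False` are replaced by the user-facing options. Per the README note above, a batch size of `None` presumably falls back to the dataset batch size; a hedged sketch of that defaulting, with a hypothetical helper name:

```python
def resolve_te_cache_batch_size(text_encoder_batch_size, dataset_batch_size):
    # assumed behavior: --text_encoder_batch_size overrides when given,
    # otherwise the dataset batch size is used for caching
    return text_encoder_batch_size if text_encoder_batch_size is not None else dataset_batch_size

print(resolve_te_cache_batch_size(None, 4))  # -> 4
print(resolve_te_cache_batch_size(16, 4))    # -> 16
```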
@@ -222,7 +222,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
                 text_encoders[1].to(weight_dtype)
 
             with accelerator.autocast():
-                dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process)
+                dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
 
             # cache sample prompts
             if args.sample_prompts is not None:
@@ -31,6 +31,7 @@ import hashlib
 import subprocess
 from io import BytesIO
 import toml
 
+# from concurrent.futures import ThreadPoolExecutor, as_completed
 
 from tqdm import tqdm
@@ -1192,7 +1193,7 @@ class BaseDataset(torch.utils.data.Dataset):
         for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
             cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
 
-    def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
+    def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator):
         r"""
         a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy.
         """
@@ -1207,15 +1208,25 @@ class BaseDataset(torch.utils.data.Dataset):
         # split by resolution
         batches = []
         batch = []
-        logger.info("checking cache validity...")
-        for info in tqdm(image_infos):
-            te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
-
-            # check disk cache exists and size of latents
+
+        # support multiple-gpus
+        num_processes = accelerator.num_processes
+        process_index = accelerator.process_index
+
+        logger.info("checking cache validity...")
+        for i, info in enumerate(tqdm(image_infos)):
+            # check disk cache exists and size of text encoder outputs
             if caching_strategy.cache_to_disk:
-                info.text_encoder_outputs_npz = te_out_npz  # set npz filename regardless of cache availability/main process
+                te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
+                info.text_encoder_outputs_npz = te_out_npz  # set npz filename regardless of cache availability
+
+                # if the modulo of num_processes is not equal to process_index, skip caching
+                # this makes each process cache different text encoder outputs
+                if i % num_processes != process_index:
+                    continue
+
                 cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
-                if cache_available or not is_main_process:  # do not add to batch
+                if cache_available:  # do not add to batch
                     continue
 
             batch.append(info)
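The added modulo check shards the work round-robin: with N processes, rank p caches exactly the items whose index i satisfies i % N == p, so each cache file is written by exactly one process. A standalone sketch of the same sharding:

```python
def shard_indices(num_items: int, num_processes: int, process_index: int):
    # round-robin sharding as in the loop above: each process takes the
    # items whose index is congruent to its rank modulo the process count
    return [i for i in range(num_items) if i % num_processes == process_index]

# e.g. 10 items across 4 GPUs: rank 1 caches items 1, 5 and 9
assert shard_indices(10, 4, 1) == [1, 5, 9]
```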
@@ -2420,6 +2431,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
         for i, dataset in enumerate(self.datasets):
             logger.info(f"[Dataset {i}]")
             dataset.new_cache_latents(model, accelerator)
+        accelerator.wait_for_everyone()
 
     def cache_text_encoder_outputs(
         self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
@@ -2437,10 +2449,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
             tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size
         )
 
-    def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
+    def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator):
         for i, dataset in enumerate(self.datasets):
             logger.info(f"[Dataset {i}]")
-            dataset.new_cache_text_encoder_outputs(models, is_main_process)
+            dataset.new_cache_text_encoder_outputs(models, accelerator)
+        accelerator.wait_for_everyone()
 
     def set_caching_mode(self, caching_mode):
         for dataset in self.datasets:
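Because each process now caches only its own shard, the added `accelerator.wait_for_everyone()` is the synchronization point: no rank proceeds until every rank has finished writing its cache files. A minimal sketch of the pattern (illustrative, not the trainer itself):

```python
from accelerate import Accelerator

accelerator = Accelerator()

# ... each process caches its own share of latents / Text Encoder outputs ...

# barrier: every rank must reach this point before any rank continues,
# so training never starts against a partially written cache
accelerator.wait_for_everyone()
```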
@@ -4210,6 +4223,12 @@ def add_dataset_arguments(
         action="store_true",
         help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)",
     )
+    parser.add_argument(
+        "--skip_cache_check",
+        action="store_true",
+        help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist"
+        " / cacheの内容の検証をスキップする(latentとテキストエンコーダの出力)。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる",
+    )
     parser.add_argument(
         "--enable_bucket",
         action="store_true",
@@ -5084,15 +5103,24 @@ def prepare_accelerator(args: argparse.Namespace):
     dynamo_backend = args.dynamo_backend
 
     kwargs_handlers = [
-        InitProcessGroupKwargs(
-            backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
-            init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None,
-            timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None
-        ) if torch.cuda.device_count() > 1 else None,
-        DistributedDataParallelKwargs(
-            gradient_as_bucket_view=args.ddp_gradient_as_bucket_view,
-            static_graph=args.ddp_static_graph
-        ) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None
+        (
+            InitProcessGroupKwargs(
+                backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+                init_method=(
+                    "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None
+                ),
+                timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None,
+            )
+            if torch.cuda.device_count() > 1
+            else None
+        ),
+        (
+            DistributedDataParallelKwargs(
+                gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
+            )
+            if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
+            else None
+        ),
     ]
     kwargs_handlers = [i for i in kwargs_handlers if i is not None]
     deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
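This hunk is a formatting-only refactor: each handler is still built by a conditional expression that may yield `None`, and the `None`s are filtered out before the list is handed to `Accelerator`. A compact sketch of the same pattern with illustrative values:

```python
import datetime

import torch
from accelerate import Accelerator
from accelerate.utils import InitProcessGroupKwargs

init_kwargs = (
    InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=30))
    if torch.cuda.device_count() > 1
    else None  # single-GPU runs need no process-group tweaks
)

# drop the Nones so Accelerator only receives real handlers
kwargs_handlers = [h for h in [init_kwargs] if h is not None]
accelerator = Accelerator(kwargs_handlers=kwargs_handlers)
```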
sd3_train.py
@@ -57,6 +57,10 @@ def train(args):
     deepspeed_utils.prepare_deepspeed_args(args)
     setup_logging(args, reset=True)
 
+    # temporary: backward compatibility for deprecated options. remove in the future
+    if not args.skip_cache_check:
+        args.skip_cache_check = args.skip_latents_validity_check
+
     assert (
         not args.weighted_captions
     ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
@@ -103,7 +107,7 @@ def train(args):
     # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
     if args.cache_latents:
         latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy(
-            args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
+            args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
         )
         strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -312,7 +316,7 @@ def train(args):
         text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy(
             args.cache_text_encoder_outputs_to_disk,
             args.text_encoder_batch_size,
-            False,
+            args.skip_cache_check,
             train_clip_g or train_clip_l or args.use_t5xxl_cache_only,
             args.apply_lg_attn_mask,
             args.apply_t5_attn_mask,
@@ -325,7 +329,7 @@ def train(args):
             t5xxl.to(t5xxl_device, dtype=t5xxl_dtype)
 
         with accelerator.autocast():
-            train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process)
+            train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator)
 
         # cache sample prompt's embeddings to free text encoder's memory
         if args.sample_prompts is not None:
@@ -1052,7 +1056,12 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--skip_latents_validity_check",
         action="store_true",
-        help="skip latents validity check / latentsの正当性チェックをスキップする",
+        help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
     )
+    parser.add_argument(
+        "--skip_cache_check",
+        action="store_true",
+        help="skip cache (latents and text encoder outputs) check / キャッシュ(latentsとtext encoder outputs)のチェックをスキップする",
+    )
     parser.add_argument(
         "--num_last_block_to_freeze",
@@ -131,7 +131,7 @@ def train(args):
     # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
     if args.cache_latents:
         latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
-            False, args.cache_latents_to_disk, args.vae_batch_size, False
+            False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
         )
         strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -328,7 +328,7 @@ def train(args):
             text_encoder1.to(accelerator.device)
             text_encoder2.to(accelerator.device)
             with accelerator.autocast():
-                train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process)
+                train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
 
             accelerator.wait_for_everyone()
@@ -84,7 +84,7 @@ def train(args):
 
     # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
     latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
-        False, args.cache_latents_to_disk, args.vae_batch_size, False
+        False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
     )
     strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -230,7 +230,7 @@ def train(args):
             text_encoder1.to(accelerator.device)
             text_encoder2.to(accelerator.device)
             with accelerator.autocast():
-                train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process)
+                train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
 
             accelerator.wait_for_everyone()
@@ -93,7 +93,7 @@ def train(args):
 
     # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
     latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
-        False, args.cache_latents_to_disk, args.vae_batch_size, False
+        False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
     )
     strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -202,7 +202,7 @@ def train(args):
             text_encoder1.to(accelerator.device)
             text_encoder2.to(accelerator.device)
             with accelerator.autocast():
-                train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process)
+                train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
 
             accelerator.wait_for_everyone()
@@ -431,7 +431,6 @@ def train(args):
                 latents = torch.nan_to_num(latents, 0, out=latents)
             latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
 
-
         text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
         if text_encoder_outputs_list is not None:
             # Text Encoder outputs are cached
@@ -67,7 +67,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
 
     def get_latents_caching_strategy(self, args):
         latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
-            False, args.cache_latents_to_disk, args.vae_batch_size, False
+            False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
         )
         return latents_caching_strategy
@@ -80,7 +80,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
     def get_text_encoder_outputs_caching_strategy(self, args):
         if args.cache_text_encoder_outputs:
             return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
-                args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions
+                args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions
             )
         else:
             return None
@@ -102,9 +102,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
                 text_encoders[0].to(accelerator.device, dtype=weight_dtype)
                 text_encoders[1].to(accelerator.device, dtype=weight_dtype)
                 with accelerator.autocast():
-                    dataset.new_cache_text_encoder_outputs(
-                        text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator.is_main_process
-                    )
+                    dataset.new_cache_text_encoder_outputs(text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator)
                 accelerator.wait_for_everyone()
 
                 text_encoders[0].to("cpu", dtype=torch.float32)  # Text Encoder doesn't work with fp16 on CPU
@@ -49,7 +49,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer):
 
     def get_latents_caching_strategy(self, args):
         latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
-            False, args.cache_latents_to_disk, args.vae_batch_size, False
+            False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
        )
         return latents_caching_strategy
@@ -64,7 +64,7 @@ def train(args):
 
     # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
    latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
-        False, args.cache_latents_to_disk, args.vae_batch_size, False
+        False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
     )
     strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -116,7 +116,7 @@ class NetworkTrainer:
 
     def get_latents_caching_strategy(self, args):
         latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
-            True, args.cache_latents_to_disk, args.vae_batch_size, False
+            True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
         )
         return latents_caching_strategy
@@ -114,7 +114,7 @@ class TextualInversionTrainer:
 
     def get_latents_caching_strategy(self, args):
         latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
-            True, args.cache_latents_to_disk, args.vae_batch_size, False
+            True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
         )
         return latents_caching_strategy