Refactor caching in train scripts

This commit is contained in:
kohya-ss
2024-10-12 20:18:41 +09:00
parent ff4083b910
commit c80c304779
14 changed files with 95 additions and 47 deletions

View File

@@ -11,6 +11,16 @@ The command to install PyTorch is as follows:
### Recent Updates
Oct 12, 2024 (update 1):
- During multi-GPU training, caching of latents and Text Encoder outputs is now done in multi-GPU.
- `--text_encoder_batch_size` option is enabled for FLUX.1 LoRA training and fine tuning. This option specifies the batch size for caching Text Encoder outputs (not for training). The default is same as the dataset batch size. If you have enough VRAM, you can increase the batch size to speed up the caching.
- `--skip_cache_check` option is added to each training script.
- When specified, the consistency check of the cache file `*.npz` contents (e.g., image size and flip for latents, mask for Text Encoder outputs) is skipped.
- Specify this option if you have a large number of cache files and the consistency check takes time.
- Even if this option is specified, the cache will be created if the file does not exist.
- `--skip_latents_validity_check` in SD3/FLUX.1 is deprecated. Please use `--skip_cache_check` instead.
Oct 12, 2024:
- Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)!

View File

@@ -59,7 +59,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
if cache_latents:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

View File

@@ -57,6 +57,10 @@ def train(args):
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
# temporary: backward compatibility for deprecated options. remove in the future
if not args.skip_cache_check:
args.skip_cache_check = args.skip_latents_validity_check
# assert (
# not args.weighted_captions
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
@@ -81,7 +85,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
if args.cache_latents:
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -142,7 +146,7 @@ def train(args):
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
)
)
t5xxl_max_token_length = (
@@ -181,7 +185,7 @@ def train(args):
# load VAE for caching latents
ae = None
if cache_latents:
ae = flux_utils.load_ae( args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
ae.to(accelerator.device, dtype=weight_dtype)
ae.requires_grad_(False)
ae.eval()
@@ -229,7 +233,7 @@ def train(args):
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator.is_main_process)
train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator)
# cache sample prompt's embeddings to free text encoder's memory
if args.sample_prompts is not None:
@@ -952,7 +956,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",
help="skip latents validity check / latentsの正当性チェックをスキップする",
help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
)
parser.add_argument(
"--blocks_to_swap",

View File

@@ -188,8 +188,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# if the text encoders is trained, we need tokenization, so is_partial is True
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
None,
False,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=self.train_clip_l or self.train_t5xxl,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
@@ -222,7 +222,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
text_encoders[1].to(weight_dtype)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process)
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
# cache sample prompts
if args.sample_prompts is not None:

View File

@@ -31,6 +31,7 @@ import hashlib
import subprocess
from io import BytesIO
import toml
# from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
@@ -1192,7 +1193,7 @@ class BaseDataset(torch.utils.data.Dataset):
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator):
r"""
a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy.
"""
@@ -1207,15 +1208,25 @@ class BaseDataset(torch.utils.data.Dataset):
# split by resolution
batches = []
batch = []
logger.info("checking cache validity...")
for info in tqdm(image_infos):
te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
# check disk cache exists and size of latents
# support multiple-gpus
num_processes = accelerator.num_processes
process_index = accelerator.process_index
logger.info("checking cache validity...")
for i, info in enumerate(tqdm(image_infos)):
# check disk cache exists and size of text encoder outputs
if caching_strategy.cache_to_disk:
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability/main process
te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability
# if the modulo of num_processes is not equal to process_index, skip caching
# this makes each process cache different text encoder outputs
if i % num_processes != process_index:
continue
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
if cache_available or not is_main_process: # do not add to batch
if cache_available: # do not add to batch
continue
batch.append(info)
@@ -2420,6 +2431,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.new_cache_latents(model, accelerator)
accelerator.wait_for_everyone()
def cache_text_encoder_outputs(
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
@@ -2437,10 +2449,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size
)
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.new_cache_text_encoder_outputs(models, is_main_process)
dataset.new_cache_text_encoder_outputs(models, accelerator)
accelerator.wait_for_everyone()
def set_caching_mode(self, caching_mode):
for dataset in self.datasets:
@@ -4210,6 +4223,12 @@ def add_dataset_arguments(
action="store_true",
help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheするaugmentationは使用不可",
)
parser.add_argument(
"--skip_cache_check",
action="store_true",
help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist"
" / cacheの内容の検証をスキップするlatentとテキストエンコーダの出力。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる",
)
parser.add_argument(
"--enable_bucket",
action="store_true",
@@ -5084,15 +5103,24 @@ def prepare_accelerator(args: argparse.Namespace):
dynamo_backend = args.dynamo_backend
kwargs_handlers = [
InitProcessGroupKwargs(
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None,
timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None
) if torch.cuda.device_count() > 1 else None,
DistributedDataParallelKwargs(
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view,
static_graph=args.ddp_static_graph
) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None
(
InitProcessGroupKwargs(
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method=(
"env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None
),
timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None,
)
if torch.cuda.device_count() > 1
else None
),
(
DistributedDataParallelKwargs(
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
)
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
else None
),
]
kwargs_handlers = [i for i in kwargs_handlers if i is not None]
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)

View File

@@ -57,6 +57,10 @@ def train(args):
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
# temporary: backward compatibility for deprecated options. remove in the future
if not args.skip_cache_check:
args.skip_cache_check = args.skip_latents_validity_check
assert (
not args.weighted_captions
), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
@@ -103,7 +107,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
if args.cache_latents:
latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -312,7 +316,7 @@ def train(args):
text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
False,
args.skip_cache_check,
train_clip_g or train_clip_l or args.use_t5xxl_cache_only,
args.apply_lg_attn_mask,
args.apply_t5_attn_mask,
@@ -325,7 +329,7 @@ def train(args):
t5xxl.to(t5xxl_device, dtype=t5xxl_dtype)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process)
train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator)
# cache sample prompt's embeddings to free text encoder's memory
if args.sample_prompts is not None:
@@ -1052,7 +1056,12 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",
help="skip latents validity check / latentsの正当性チェックをスキップする",
help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
)
parser.add_argument(
"--skip_cache_check",
action="store_true",
help="skip cache (latents and text encoder outputs) check / キャッシュlatentsとtext encoder outputsのチェックをスキップする",
)
parser.add_argument(
"--num_last_block_to_freeze",

View File

@@ -131,7 +131,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
if args.cache_latents:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -328,7 +328,7 @@ def train(args):
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process)
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
accelerator.wait_for_everyone()

View File

@@ -84,7 +84,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -230,7 +230,7 @@ def train(args):
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process)
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
accelerator.wait_for_everyone()

View File

@@ -93,7 +93,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -202,7 +202,7 @@ def train(args):
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process)
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
accelerator.wait_for_everyone()
@@ -431,7 +431,6 @@ def train(args):
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
# Text Encoder outputs are cached

View File

@@ -67,7 +67,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return latents_caching_strategy
@@ -80,7 +80,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions
args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions
)
else:
return None
@@ -102,9 +102,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(
text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator.is_main_process
)
dataset.new_cache_text_encoder_outputs(text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator)
accelerator.wait_for_everyone()
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU

View File

@@ -49,7 +49,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return latents_caching_strategy

View File

@@ -64,7 +64,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

View File

@@ -116,7 +116,7 @@ class NetworkTrainer:
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
True, args.cache_latents_to_disk, args.vae_batch_size, False
True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return latents_caching_strategy

View File

@@ -114,7 +114,7 @@ class TextualInversionTrainer:
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
True, args.cache_latents_to_disk, args.vae_batch_size, False
True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return latents_caching_strategy