diff --git a/README.md b/README.md index 37fc911f..2b256283 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,16 @@ The command to install PyTorch is as follows: ### Recent Updates +Oct 12, 2024 (update 1): + +- During multi-GPU training, caching of latents and Text Encoder outputs is now done on multiple GPUs. +- `--text_encoder_batch_size` option is enabled for FLUX.1 LoRA training and fine tuning. This option specifies the batch size for caching Text Encoder outputs (not for training). The default is the same as the dataset batch size. If you have enough VRAM, you can increase the batch size to speed up the caching. +- `--skip_cache_check` option is added to each training script. + - When specified, the consistency check of the cache file `*.npz` contents (e.g., image size and flip for latents, mask for Text Encoder outputs) is skipped. + - Specify this option if you have a large number of cache files and the consistency check takes time. + - Even if this option is specified, the cache will be created if the file does not exist. + - `--skip_latents_validity_check` in SD3/FLUX.1 is deprecated. Please use `--skip_cache_check` instead. + Oct 12, 2024: - Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)! diff --git a/fine_tune.py b/fine_tune.py index fd63385b..cdc005d9 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -59,7 +59,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. 
if cache_latents: latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) diff --git a/flux_train.py b/flux_train.py index ecc87c0a..e18a9244 100644 --- a/flux_train.py +++ b/flux_train.py @@ -57,6 +57,10 @@ def train(args): deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + # assert ( # not args.weighted_captions # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" @@ -81,7 +85,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. 
if args.cache_latents: latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( - args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -142,7 +146,7 @@ def train(args): if args.cache_text_encoder_outputs: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( strategy_flux.FluxTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False ) ) t5xxl_max_token_length = ( @@ -181,7 +185,7 @@ def train(args): # load VAE for caching latents ae = None if cache_latents: - ae = flux_utils.load_ae( args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) ae.to(accelerator.device, dtype=weight_dtype) ae.requires_grad_(False) ae.eval() @@ -229,7 +233,7 @@ def train(args): strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator) # cache sample prompt's embeddings to free text encoder's memory if args.sample_prompts is not None: @@ -952,7 +956,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--skip_latents_validity_check", action="store_true", - help="skip latents validity check / latentsの正当性チェックをスキップする", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", ) parser.add_argument( "--blocks_to_swap", diff --git a/flux_train_network.py b/flux_train_network.py index 5d14bd28..3bd8316d 100644 --- a/flux_train_network.py +++ 
b/flux_train_network.py @@ -188,8 +188,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # if the text encoders is trained, we need tokenization, so is_partial is True return strategy_flux.FluxTextEncoderOutputsCachingStrategy( args.cache_text_encoder_outputs_to_disk, - None, - False, + args.text_encoder_batch_size, + args.skip_cache_check, is_partial=self.train_clip_l or self.train_t5xxl, apply_t5_attn_mask=args.apply_t5_attn_mask, ) @@ -222,7 +222,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): text_encoders[1].to(weight_dtype) with accelerator.autocast(): - dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) # cache sample prompts if args.sample_prompts is not None: diff --git a/library/train_util.py b/library/train_util.py index 67eaae41..4e6b3408 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -31,6 +31,7 @@ import hashlib import subprocess from io import BytesIO import toml + # from concurrent.futures import ThreadPoolExecutor, as_completed from tqdm import tqdm @@ -1192,7 +1193,7 @@ class BaseDataset(torch.utils.data.Dataset): for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) - def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator): r""" a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy. 
""" @@ -1207,15 +1208,25 @@ class BaseDataset(torch.utils.data.Dataset): # split by resolution batches = [] batch = [] - logger.info("checking cache validity...") - for info in tqdm(image_infos): - te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) - # check disk cache exists and size of latents + # support multiple-gpus + num_processes = accelerator.num_processes + process_index = accelerator.process_index + + logger.info("checking cache validity...") + for i, info in enumerate(tqdm(image_infos)): + # check disk cache exists and size of text encoder outputs if caching_strategy.cache_to_disk: - info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability/main process + te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) + info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability + + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different text encoder outputs + if i % num_processes != process_index: + continue + cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) - if cache_available or not is_main_process: # do not add to batch + if cache_available: # do not add to batch continue batch.append(info) @@ -2420,6 +2431,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") dataset.new_cache_latents(model, accelerator) + accelerator.wait_for_everyone() def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True @@ -2437,10 +2449,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset): tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size ) - def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + def new_cache_text_encoder_outputs(self, models: 
List[Any], accelerator: Accelerator): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.new_cache_text_encoder_outputs(models, is_main_process) + dataset.new_cache_text_encoder_outputs(models, accelerator) + accelerator.wait_for_everyone() def set_caching_mode(self, caching_mode): for dataset in self.datasets: @@ -4210,6 +4223,12 @@ def add_dataset_arguments( action="store_true", help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)", ) + parser.add_argument( + "--skip_cache_check", + action="store_true", + help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist" + " / cacheの内容の検証をスキップする(latentとテキストエンコーダの出力)。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる", + ) parser.add_argument( "--enable_bucket", action="store_true", @@ -5084,15 +5103,24 @@ def prepare_accelerator(args: argparse.Namespace): dynamo_backend = args.dynamo_backend kwargs_handlers = [ - InitProcessGroupKwargs( - backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", - init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None, - timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None - ) if torch.cuda.device_count() > 1 else None, - DistributedDataParallelKwargs( - gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, - static_graph=args.ddp_static_graph - ) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None + ( + InitProcessGroupKwargs( + backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method=( + "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None + ), + timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None, + ) 
+ if torch.cuda.device_count() > 1 + else None + ), + ( + DistributedDataParallelKwargs( + gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph + ) + if args.ddp_gradient_as_bucket_view or args.ddp_static_graph + else None + ), ] kwargs_handlers = [i for i in kwargs_handlers if i is not None] deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) diff --git a/sd3_train.py b/sd3_train.py index 5120105f..7290956a 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -57,6 +57,10 @@ def train(args): deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + assert ( not args.weighted_captions ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" @@ -103,7 +107,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. 
if args.cache_latents: latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( - args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -312,7 +316,7 @@ def train(args): text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, - False, + args.skip_cache_check, train_clip_g or train_clip_l or args.use_t5xxl_cache_only, args.apply_lg_attn_mask, args.apply_t5_attn_mask, @@ -325,7 +329,7 @@ def train(args): t5xxl.to(t5xxl_device, dtype=t5xxl_dtype) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator) # cache sample prompt's embeddings to free text encoder's memory if args.sample_prompts is not None: @@ -1052,7 +1056,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--skip_latents_validity_check", action="store_true", - help="skip latents validity check / latentsの正当性チェックをスキップする", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", ) parser.add_argument( "--num_last_block_to_freeze", diff --git a/sdxl_train.py b/sdxl_train.py index aeff9c46..9b2d1916 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -131,7 +131,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. 
if args.cache_latents: latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -328,7 +328,7 @@ def train(args): text_encoder1.to(accelerator.device) text_encoder2.to(accelerator.device) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) accelerator.wait_for_everyone() diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 67c8d52c..74b3a64a 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -84,7 +84,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -230,7 +230,7 @@ def train(args): text_encoder1.to(accelerator.device) text_encoder2.to(accelerator.device) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) accelerator.wait_for_everyone() diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 9d1cfc63..14ff7c24 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -93,7 +93,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. 
because dataset may use this strategy for initialization. latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -202,7 +202,7 @@ def train(args): text_encoder1.to(accelerator.device) text_encoder2.to(accelerator.device) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) accelerator.wait_for_everyone() @@ -431,7 +431,6 @@ def train(args): latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: # Text Encoder outputs are cached diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 20e32155..4a16a489 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -67,7 +67,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy @@ -80,7 +80,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions + args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions ) else: return None @@ -102,9 
+102,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): text_encoders[0].to(accelerator.device, dtype=weight_dtype) text_encoders[1].to(accelerator.device, dtype=weight_dtype) with accelerator.autocast(): - dataset.new_cache_text_encoder_outputs( - text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator.is_main_process - ) + dataset.new_cache_text_encoder_outputs(text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator) accelerator.wait_for_everyone() text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index cbfcef55..821a6955 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -49,7 +49,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy diff --git a/train_db.py b/train_db.py index e49a7e70..683b4233 100644 --- a/train_db.py +++ b/train_db.py @@ -64,7 +64,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. 
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) diff --git a/train_network.py b/train_network.py index 7437157b..d5330aef 100644 --- a/train_network.py +++ b/train_network.py @@ -116,7 +116,7 @@ class NetworkTrainer: def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - True, args.cache_latents_to_disk, args.vae_batch_size, False + True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 3b3d3393..4d8a3abb 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -114,7 +114,7 @@ class TextualInversionTrainer: def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - True, args.cache_latents_to_disk, args.vae_batch_size, False + True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy