load images in parallel when caching latents

This commit is contained in:
kohya-ss
2024-10-13 18:22:19 +09:00
parent 74228c9953
commit 2244cf5b83

View File

@@ -3,6 +3,7 @@
import argparse
import ast
import asyncio
from concurrent.futures import Future, ThreadPoolExecutor
import datetime
import importlib
import json
@@ -1058,7 +1059,6 @@ class BaseDataset(torch.utils.data.Dataset):
and self.random_crop == other.random_crop
)
batches: List[Tuple[Condition, List[ImageInfo]]] = []
batch: List[ImageInfo] = []
current_condition = None
@@ -1066,7 +1066,22 @@ class BaseDataset(torch.utils.data.Dataset):
num_processes = accelerator.num_processes
process_index = accelerator.process_index
logger.info("checking cache validity...")
# define a function to submit a batch to cache
def submit_batch(batch, cond):
    """Finalize a batch and cache its latents.

    Resolves any pending image-loading Futures in ``batch`` (images are loaded
    asynchronously by the ThreadPoolExecutor), then delegates to the caching
    strategy. Closes over ``caching_strategy`` and ``model`` from the enclosing
    scope.

    batch: list of ImageInfo; each ``.image`` may be a Future submitted to the
           loader pool, or an already-loaded image, or None.
    cond:  Condition carrying the flip_aug / alpha_mask / random_crop flags
           shared by every item in the batch.
    """
    for info in batch:
        # Block until the asynchronously-loaded image is ready, then replace
        # the Future with the actual image so downstream code sees real data.
        if info.image is not None and isinstance(info.image, Future):
            info.image = info.image.result()  # future to image
    caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop)
# define ThreadPoolExecutor to load images in parallel
max_workers = min(os.cpu_count(), len(image_infos))
max_workers = max(1, max_workers // num_processes) # consider multi-gpu
max_workers = min(max_workers, caching_strategy.batch_size)  # max_workers should not exceed batch_size
executor = ThreadPoolExecutor(max_workers)
try:
# iterate images
logger.info("caching latents...")
for i, info in enumerate(tqdm(image_infos)):
subset = self.image_to_subset[info.image_key]
@@ -1094,29 +1109,27 @@ class BaseDataset(torch.utils.data.Dataset):
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
if len(batch) > 0 and current_condition != condition:
batches.append((current_condition, batch))
submit_batch(batch, current_condition)
batch = []
if info.image is None:
# load image in parallel
info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask)
batch.append(info)
current_condition = condition
# if number of data in batch is enough, flush the batch
if len(batch) >= caching_strategy.batch_size:
batches.append((current_condition, batch))
submit_batch(batch, current_condition)
batch = []
current_condition = None
if len(batch) > 0:
batches.append((current_condition, batch))
submit_batch(batch, current_condition)
if len(batches) == 0:
logger.info("no latents to cache")
return
# iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded
logger.info("caching latents...")
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
finally:
executor.shutdown()
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと