load images in parallel when caching latents

This commit is contained in:
kohya-ss
2024-10-13 18:22:19 +09:00
parent 74228c9953
commit 2244cf5b83

View File

@@ -3,6 +3,7 @@
import argparse import argparse
import ast import ast
import asyncio import asyncio
from concurrent.futures import Future, ThreadPoolExecutor
import datetime import datetime
import importlib import importlib
import json import json
@@ -1058,7 +1059,6 @@ class BaseDataset(torch.utils.data.Dataset):
and self.random_crop == other.random_crop and self.random_crop == other.random_crop
) )
batches: List[Tuple[Condition, List[ImageInfo]]] = []
batch: List[ImageInfo] = [] batch: List[ImageInfo] = []
current_condition = None current_condition = None
@@ -1066,57 +1066,70 @@ class BaseDataset(torch.utils.data.Dataset):
num_processes = accelerator.num_processes num_processes = accelerator.num_processes
process_index = accelerator.process_index process_index = accelerator.process_index
logger.info("checking cache validity...") # define a function to submit a batch to cache
for i, info in enumerate(tqdm(image_infos)): def submit_batch(batch, cond):
subset = self.image_to_subset[info.image_key] for info in batch:
if info.image is not None and isinstance(info.image, Future):
info.image = info.image.result() # future to image
caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop)
if info.latents_npz is not None: # fine tuning dataset # define ThreadPoolExecutor to load images in parallel
continue max_workers = min(os.cpu_count(), len(image_infos))
max_workers = max(1, max_workers // num_processes) # consider multi-gpu
max_workers = min(max_workers, caching_strategy.batch_size) # max_workers should be less than batch_size
executor = ThreadPoolExecutor(max_workers)
# check disk cache exists and size of latents try:
if caching_strategy.cache_to_disk: # iterate images
# info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix logger.info("caching latents...")
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) for i, info in enumerate(tqdm(image_infos)):
subset = self.image_to_subset[info.image_key]
# if the modulo of num_processes is not equal to process_index, skip caching if info.latents_npz is not None: # fine tuning dataset
# this makes each process cache different latents
if i % num_processes != process_index:
continue continue
# print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") # check disk cache exists and size of latents
if caching_strategy.cache_to_disk:
# info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
cache_available = caching_strategy.is_disk_cached_latents_expected( # if the modulo of num_processes is not equal to process_index, skip caching
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask # this makes each process cache different latents
) if i % num_processes != process_index:
if cache_available: # do not add to batch continue
continue
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}")
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
if len(batch) > 0 and current_condition != condition:
batches.append((current_condition, batch))
batch = []
batch.append(info) cache_available = caching_strategy.is_disk_cached_latents_expected(
current_condition = condition info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
)
if cache_available: # do not add to batch
continue
# if number of data in batch is enough, flush the batch # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
if len(batch) >= caching_strategy.batch_size: condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
batches.append((current_condition, batch)) if len(batch) > 0 and current_condition != condition:
batch = [] submit_batch(batch, current_condition)
current_condition = None batch = []
if len(batch) > 0: if info.image is None:
batches.append((current_condition, batch)) # load image in parallel
info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask)
if len(batches) == 0: batch.append(info)
logger.info("no latents to cache") current_condition = condition
return
# iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded # if number of data in batch is enough, flush the batch
logger.info("caching latents...") if len(batch) >= caching_strategy.batch_size:
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): submit_batch(batch, current_condition)
caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) batch = []
current_condition = None
if len(batch) > 0:
submit_batch(batch, current_condition)
finally:
executor.shutdown()
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと