mirror of https://github.com/kohya-ss/sd-scripts.git
load images in parallel when caching latents
@@ -3,6 +3,7 @@
 import argparse
 import ast
 import asyncio
+from concurrent.futures import Future, ThreadPoolExecutor
 import datetime
 import importlib
 import json
@@ -1058,7 +1059,6 @@ class BaseDataset(torch.utils.data.Dataset):
                     and self.random_crop == other.random_crop
                 )
 
-        batches: List[Tuple[Condition, List[ImageInfo]]] = []
         batch: List[ImageInfo] = []
         current_condition = None
 
@@ -1066,57 +1066,70 @@ class BaseDataset(torch.utils.data.Dataset):
         num_processes = accelerator.num_processes
         process_index = accelerator.process_index
 
-        logger.info("checking cache validity...")
-        for i, info in enumerate(tqdm(image_infos)):
-            subset = self.image_to_subset[info.image_key]
-
-            if info.latents_npz is not None:  # fine tuning dataset
-                continue
-
-            # check disk cache exists and size of latents
-            if caching_strategy.cache_to_disk:
-                # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
-                info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
-
-                # if the modulo of num_processes is not equal to process_index, skip caching
-                # this makes each process cache different latents
-                if i % num_processes != process_index:
-                    continue
-
-                # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}")
-                cache_available = caching_strategy.is_disk_cached_latents_expected(
-                    info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
-                )
-                if cache_available:  # do not add to batch
-                    continue
-
-            # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
-            condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
-            if len(batch) > 0 and current_condition != condition:
-                batches.append((current_condition, batch))
-                batch = []
-
-            batch.append(info)
-            current_condition = condition
-
-            # if number of data in batch is enough, flush the batch
-            if len(batch) >= caching_strategy.batch_size:
-                batches.append((current_condition, batch))
-                batch = []
-                current_condition = None
-
-        if len(batch) > 0:
-            batches.append((current_condition, batch))
-
-        if len(batches) == 0:
-            logger.info("no latents to cache")
-            return
-
-        # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded
-        logger.info("caching latents...")
-        for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
-            caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
+        # define a function to submit a batch to cache
+        def submit_batch(batch, cond):
+            for info in batch:
+                if info.image is not None and isinstance(info.image, Future):
+                    info.image = info.image.result()  # future to image
+            caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop)
+
+        # define ThreadPoolExecutor to load images in parallel
+        max_workers = min(os.cpu_count(), len(image_infos))
+        max_workers = max(1, max_workers // num_processes)  # consider multi-gpu
+        max_workers = min(max_workers, caching_strategy.batch_size)  # max_workers should be less than batch_size
+        executor = ThreadPoolExecutor(max_workers)
+
+        try:
+            # iterate images
+            logger.info("caching latents...")
+            for i, info in enumerate(tqdm(image_infos)):
+                subset = self.image_to_subset[info.image_key]
+
+                if info.latents_npz is not None:  # fine tuning dataset
+                    continue
+
+                # check disk cache exists and size of latents
+                if caching_strategy.cache_to_disk:
+                    # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
+                    info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
+
+                    # if the modulo of num_processes is not equal to process_index, skip caching
+                    # this makes each process cache different latents
+                    if i % num_processes != process_index:
+                        continue
+
+                    # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}")
+                    cache_available = caching_strategy.is_disk_cached_latents_expected(
+                        info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
+                    )
+                    if cache_available:  # do not add to batch
+                        continue
+
+                # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
+                condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
+                if len(batch) > 0 and current_condition != condition:
+                    submit_batch(batch, current_condition)
+                    batch = []
+
+                if info.image is None:
+                    # load image in parallel
+                    info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask)
+
+                batch.append(info)
+                current_condition = condition
+
+                # if number of data in batch is enough, flush the batch
+                if len(batch) >= caching_strategy.batch_size:
+                    submit_batch(batch, current_condition)
+                    batch = []
+                    current_condition = None
+
+            if len(batch) > 0:
+                submit_batch(batch, current_condition)
+
+        finally:
+            executor.shutdown()
 
     def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
        # multi-GPU is not supported here, so use tools/cache_latents.py for that
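
The core of the change: instead of collecting batches first and letting cache_batch_latents load every image synchronously, each image load is submitted to a ThreadPoolExecutor as soon as the image joins a batch, and submit_batch resolves the Futures only when the batch is flushed to the encoder. Below is a minimal, self-contained sketch of that pattern; read_image and encode_batch are illustrative stand-ins, not the repository's actual load_image and cache_batch_latents.

import os
from concurrent.futures import ThreadPoolExecutor

def read_image(path):
    # stand-in loader: the real code decodes an image file into an array
    with open(path, "rb") as f:
        return f.read()

def encode_batch(images):
    # stand-in for cache_batch_latents: consumes fully loaded images
    print(f"encoding {len(images)} images")

def cache_all(paths, batch_size=4, num_processes=1):
    # same three-step worker sizing as the commit
    max_workers = min(os.cpu_count() or 1, len(paths))
    max_workers = max(1, max_workers // num_processes)
    max_workers = min(max_workers, batch_size)

    executor = ThreadPoolExecutor(max_workers)
    try:
        batch = []
        for path in paths:
            # submit the load immediately; a Future goes into the batch
            batch.append(executor.submit(read_image, path))
            if len(batch) >= batch_size:
                # resolve the Futures only at flush time, so the disk reads
                # for one batch run concurrently with the scanning loop
                encode_batch([f.result() for f in batch])
                batch = []
        if batch:
            encode_batch([f.result() for f in batch])
    finally:
        executor.shutdown()

Since Futures are resolved one batch at a time, at most one batch of decoded images is held in memory before it is consumed, which is why capping max_workers at batch_size costs nothing.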
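The three clamps on max_workers are easier to read with concrete numbers. Assuming, for illustration, 16 CPU cores, 1000 images, 2 accelerator processes, and a caching batch size of 8:

max_workers = min(16, 1000)    # 16: never more threads than images
max_workers = max(1, 16 // 2)  # 8: split the cores across the 2 GPU processes
max_workers = min(8, 8)        # 8: loads beyond one batch would just queue up

So each of the two processes runs 8 loader threads, exactly enough to keep one batch of 8 in flight.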
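The "i % num_processes != process_index" guard shards images round-robin across accelerator processes, so each GPU caches a disjoint subset and no latent is written twice. A tiny illustration, with two hypothetical processes splitting ten images:

num_processes = 2
for process_index in range(num_processes):
    mine = [i for i in range(10) if i % num_processes == process_index]
    print(process_index, mine)
# 0 [0, 2, 4, 6, 8]
# 1 [1, 3, 5, 7, 9]

Note the check is applied under cache_to_disk: sharding only pays off when each process persists its slice, since an in-memory cache is per-process and every process would still need all of its own latents.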