mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-06 13:47:06 +00:00)
load images in parallel when caching latents
@@ -3,6 +3,7 @@
 import argparse
 import ast
 import asyncio
+from concurrent.futures import Future, ThreadPoolExecutor
 import datetime
 import importlib
 import json
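
For context on the new import: `ThreadPoolExecutor.submit()` schedules a callable on a worker thread and immediately returns a `Future`; `Future.result()` blocks until the value is ready, re-raising any exception from the worker. A minimal sketch of this standard-library API, with `load_bytes` and the file path as hypothetical stand-ins:

```python
from concurrent.futures import Future, ThreadPoolExecutor


def load_bytes(path: str) -> bytes:
    # stand-in for an expensive, I/O-bound load
    with open(path, "rb") as f:
        return f.read()


executor = ThreadPoolExecutor(max_workers=4)
future: Future = executor.submit(load_bytes, "image.png")  # hypothetical path; submit returns immediately
# ... other work can happen here while a worker thread reads the file ...
data = future.result()  # blocks until the worker finishes
executor.shutdown()
```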
@@ -1058,7 +1059,6 @@ class BaseDataset(torch.utils.data.Dataset):
                     and self.random_crop == other.random_crop
                 )
 
-        batches: List[Tuple[Condition, List[ImageInfo]]] = []
         batch: List[ImageInfo] = []
         current_condition = None
 
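
This hunk drops the `batches` accumulator: instead of collecting every `(condition, batch)` pair and encoding them in a second pass, the rewritten loop in the next hunk flushes each batch as soon as it is full or its condition changes. A schematic of that collect-then-process to flush-eagerly shift, with `process_batch` as a hypothetical stand-in for the repo's `submit_batch`:

```python
def process_batch(batch: list) -> None:
    # hypothetical stand-in for submit_batch: encode the batch, then discard it
    print("processing", batch)


items = list(range(10))
batch_size = 4

batch: list = []
for item in items:
    batch.append(item)
    if len(batch) >= batch_size:
        process_batch(batch)  # flush eagerly instead of appending to a batches list
        batch = []
if batch:
    process_batch(batch)  # flush the remainder
```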
@@ -1066,57 +1066,70 @@ class BaseDataset(torch.utils.data.Dataset):
         num_processes = accelerator.num_processes
         process_index = accelerator.process_index
 
-        logger.info("checking cache validity...")
-        for i, info in enumerate(tqdm(image_infos)):
-            subset = self.image_to_subset[info.image_key]
+        # define a function to submit a batch to cache
+        def submit_batch(batch, cond):
+            for info in batch:
+                if info.image is not None and isinstance(info.image, Future):
+                    info.image = info.image.result()  # future to image
+            caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop)
 
-            if info.latents_npz is not None:  # fine tuning dataset
-                continue
+        # define ThreadPoolExecutor to load images in parallel
+        max_workers = min(os.cpu_count(), len(image_infos))
+        max_workers = max(1, max_workers // num_processes)  # consider multi-gpu
+        max_workers = min(max_workers, caching_strategy.batch_size)  # max_workers should be less than batch_size
+        executor = ThreadPoolExecutor(max_workers)
 
-            # check disk cache exists and size of latents
-            if caching_strategy.cache_to_disk:
-                # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
-                info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
+        try:
+            # iterate images
+            logger.info("caching latents...")
+            for i, info in enumerate(tqdm(image_infos)):
+                subset = self.image_to_subset[info.image_key]
 
-                # if the modulo of num_processes is not equal to process_index, skip caching
-                # this makes each process cache different latents
-                if i % num_processes != process_index:
-                    continue
+                if info.latents_npz is not None:  # fine tuning dataset
+                    continue
 
-                # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}")
+                # check disk cache exists and size of latents
+                if caching_strategy.cache_to_disk:
+                    # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
+                    info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
 
-                cache_available = caching_strategy.is_disk_cached_latents_expected(
-                    info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
-                )
-                if cache_available:  # do not add to batch
-                    continue
+                    # if the modulo of num_processes is not equal to process_index, skip caching
+                    # this makes each process cache different latents
+                    if i % num_processes != process_index:
+                        continue
 
-            # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
-            condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
-            if len(batch) > 0 and current_condition != condition:
-                batches.append((current_condition, batch))
-                batch = []
+                    # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}")
 
-            batch.append(info)
-            current_condition = condition
+                    cache_available = caching_strategy.is_disk_cached_latents_expected(
+                        info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
+                    )
+                    if cache_available:  # do not add to batch
+                        continue
 
-            # if number of data in batch is enough, flush the batch
-            if len(batch) >= caching_strategy.batch_size:
-                batches.append((current_condition, batch))
-                batch = []
-                current_condition = None
+                # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
+                condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
+                if len(batch) > 0 and current_condition != condition:
+                    submit_batch(batch, current_condition)
+                    batch = []
 
-        if len(batch) > 0:
-            batches.append((current_condition, batch))
+                if info.image is None:
+                    # load image in parallel
+                    info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask)
 
-        if len(batches) == 0:
-            logger.info("no latents to cache")
-            return
+                batch.append(info)
+                current_condition = condition
 
-        # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded
-        logger.info("caching latents...")
-        for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
-            caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
+                # if number of data in batch is enough, flush the batch
+                if len(batch) >= caching_strategy.batch_size:
+                    submit_batch(batch, current_condition)
+                    batch = []
+                    current_condition = None
+
+            if len(batch) > 0:
+                submit_batch(batch, current_condition)
+
+        finally:
+            executor.shutdown()
 
     def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
         # multi-GPU is not supported here, so use tools/cache_latents.py for that
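
The net effect of the rewrite is pipelining: while the model encodes the current batch, worker threads are already reading the images for the next one, and `submit_batch` resolves the pending `Future`s only at encode time. A self-contained toy version of the pattern; `Item`, `load_image`, `encode_batch`, and the fake paths are hypothetical stand-ins for the repo's `ImageInfo`, `load_image`, and `cache_batch_latents`:

```python
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Item:
    path: str
    image: Optional[object] = None  # holds a Future while loading, then the loaded image


def load_image(path: str) -> str:
    # hypothetical loader; the real code decodes the file into an array
    return f"<pixels of {path}>"


def encode_batch(batch: List[Item]) -> None:
    # resolve any pending Futures just before encoding, as submit_batch does
    for item in batch:
        if isinstance(item.image, Future):
            item.image = item.image.result()
    print("encoded:", [item.image for item in batch])


items = [Item(f"img_{i}.png") for i in range(7)]
batch_size = 3

executor = ThreadPoolExecutor(max_workers=batch_size)
try:
    batch: List[Item] = []
    for item in items:
        # kick off the load on a worker thread; don't wait for it here
        item.image = executor.submit(load_image, item.path)
        batch.append(item)
        if len(batch) >= batch_size:
            encode_batch(batch)  # blocks only on this batch's loads
            batch = []
    if batch:
        encode_batch(batch)
finally:
    executor.shutdown()
```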
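
The round-robin sharding kept from the old code (`if i % num_processes != process_index: continue`) is what lets several accelerator processes cache to disk concurrently: each process keeps only the indices congruent to its own `process_index`, so the shards are disjoint without any coordination. A sketch simulating two processes (in the real code each process sees only its own `accelerator.process_index`):

```python
num_processes = 2  # accelerator.num_processes in the real code
image_paths = [f"img_{i}.png" for i in range(8)]

for process_index in range(num_processes):  # each real process knows only its own index
    assigned = [p for i, p in enumerate(image_paths) if i % num_processes == process_index]
    print(f"process {process_index} caches: {assigned}")
# process 0 caches: ['img_0.png', 'img_2.png', 'img_4.png', 'img_6.png']
# process 1 caches: ['img_1.png', 'img_3.png', 'img_5.png', 'img_7.png']
```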