don't hold latents in memory for fine-tuning dataset

Kohya S
2023-07-10 08:46:15 +09:00
parent 5c80117fbd
commit b6e328ea8f


@@ -91,6 +91,7 @@ IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".
 try:
     import pillow_avif
     IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
 except:
     pass
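For context, the try/except above makes AVIF support optional: importing pillow_avif registers the codec with Pillow as a side effect, and the extension list then gates which files the dataset scanner picks up. A minimal sketch of that pattern, using ImportError instead of a bare except (glob_images is a hypothetical helper for illustration, not necessarily the repository's own):

import os

IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]

try:
    import pillow_avif  # side effect: registers the AVIF plugin with Pillow

    IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
except ImportError:
    pass  # AVIF support is optional; skip it when the plugin is absent


def glob_images(directory: str) -> list[str]:
    # Hypothetical helper: collect files whose suffix is a known image extension.
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if os.path.splitext(name)[1] in IMAGE_EXTENSIONS
    ]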
@@ -853,16 +854,11 @@ class BaseDataset(torch.utils.data.Dataset):
         # split by resolution
         batches = []
         batch = []
-        for info in image_infos:
+        print("checking cache validity...")
+        for info in tqdm(image_infos):
             subset = self.image_to_subset[info.image_key]
-            if info.latents_npz is not None:
-                info.latents, info.latents_original_size, info.latents_crop_left_top = self.load_latents_from_npz(info, False)
-                info.latents = torch.FloatTensor(info.latents)
-                info.latents_flipped, _, _ = self.load_latents_from_npz(info, True)  # might be None
-                if info.latents_flipped is not None:
-                    info.latents_flipped = torch.FloatTensor(info.latents_flipped)
+            if info.latents_npz is not None:  # fine tuning dataset
+                continue
             # check disk cache exists and size of latents
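The change above is the heart of the commit: when a fine-tuning dataset already has an on-disk npz cache (info.latents_npz is set), the loop no longer decodes every cached latent into info.latents up front; it just continues, and latents are read from disk per sample at fetch time. A minimal sketch of such a lazy loader, assuming the npz file stores arrays under "latents" and "latents_flipped" keys (the real key names and signature in sd-scripts may differ):

import numpy as np
import torch

def load_latents_from_npz(npz_path: str, flipped: bool):
    # Read one cached latent from disk on demand instead of holding
    # every latent tensor resident in memory for the whole run.
    npz = np.load(npz_path)
    key = "latents_flipped" if flipped else "latents"
    if key not in npz:
        return None  # flipped latents exist only when flip_aug was enabled at cache time
    return torch.FloatTensor(npz[key])

With the continue in place, peak memory stays flat regardless of how many cached latents the fine-tuning dataset holds, at the cost of one small npz read per sample when the item is fetched.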