mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
don't hold latent on memory for finetuning dataset
This commit is contained in:
@@ -91,6 +91,7 @@ IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".
|
||||
|
||||
try:
|
||||
import pillow_avif
|
||||
|
||||
IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
|
||||
except:
|
||||
pass
|
||||
@@ -853,16 +854,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# split by resolution
|
||||
batches = []
|
||||
batch = []
|
||||
for info in image_infos:
|
||||
print("checking cache validity...")
|
||||
for info in tqdm(image_infos):
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
|
||||
if info.latents_npz is not None:
|
||||
info.latents, info.latents_original_size, info.latents_crop_left_top = self.load_latents_from_npz(info, False)
|
||||
info.latents = torch.FloatTensor(info.latents)
|
||||
|
||||
info.latents_flipped, _, _ = self.load_latents_from_npz(info, True) # might be None
|
||||
if info.latents_flipped is not None:
|
||||
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
|
||||
if info.latents_npz is not None: # fine tuning dataset
|
||||
continue
|
||||
|
||||
# check disk cache exists and size of latents
|
||||
|
||||
Reference in New Issue
Block a user