cache latents to disk in dreambooth method

This commit is contained in:
Kohya S
2023-04-12 23:10:39 +09:00
parent 5050971ac6
commit 2e9f7b5f91
6 changed files with 67 additions and 15 deletions

View File

@@ -142,12 +142,14 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size)
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone()
# 学習を準備する:モデルを適切な状態にする
training_models = []
if args.gradient_checkpointing:

View File

@@ -722,7 +722,7 @@ class BaseDataset(torch.utils.data.Dataset):
def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
def cache_latents(self, vae, vae_batch_size=1):
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
# ちょっと速くした
print("caching latents.")
@@ -740,11 +740,38 @@ class BaseDataset(torch.utils.data.Dataset):
if info.latents_npz is not None:
info.latents = self.load_latents_from_npz(info, False)
info.latents = torch.FloatTensor(info.latents)
info.latents_flipped = self.load_latents_from_npz(info, True) # might be None
# might be None, but that's ok because check is done in dataset
info.latents_flipped = self.load_latents_from_npz(info, True)
if info.latents_flipped is not None:
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
continue
# check disk cache exists and size of latents
if cache_to_disk:
# TODO: refactor to unify with FineTuningDataset
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz"
if not is_main_process:
continue
cache_available = False
expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8) # bucket_resoはWxHなので注意
if os.path.exists(info.latents_npz):
cached_latents = np.load(info.latents_npz)
if cached_latents["latents"].shape[1:3] == expected_latents_size:
cache_available = True
if subset.flip_aug:
cache_available = False
if os.path.exists(info.latents_npz_flipped):
cached_latents_flipped = np.load(info.latents_npz_flipped)
if cached_latents_flipped["latents"].shape[1:3] == expected_latents_size:
cache_available = True
if cache_available:
continue
# if last member of batch has different resolution, flush the batch
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
batches.append(batch)
@@ -760,6 +787,9 @@ class BaseDataset(torch.utils.data.Dataset):
if len(batch) > 0:
batches.append(batch)
if cache_to_disk and not is_main_process: # don't cache latents in non-main process, set to info only
return
# iterate batches
for batch in tqdm(batches, smoothing=1, total=len(batches)):
images = []
@@ -773,14 +803,21 @@ class BaseDataset(torch.utils.data.Dataset):
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents):
info.latents = latent
if cache_to_disk:
np.savez(info.latents_npz, latent.float().numpy())
else:
info.latents = latent
if subset.flip_aug:
img_tensors = torch.flip(img_tensors, dims=[3])
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents):
info.latents_flipped = latent
if cache_to_disk:
np.savez(info.latents_npz_flipped, latent.float().numpy())
else:
info.latents_flipped = latent
def get_image_size(self, image_path):
image = Image.open(image_path)
@@ -873,10 +910,10 @@ class BaseDataset(torch.utils.data.Dataset):
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
# image/latentsを処理する
if image_info.latents is not None:
if image_info.latents is not None: # cache_latents=Trueの場合
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
image = None
elif image_info.latents_npz is not None:
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
latents = torch.FloatTensor(latents)
image = None
@@ -1340,10 +1377,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets:
dataset.enable_XTI(*args, **kwargs)
def cache_latents(self, vae, vae_batch_size=1):
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
for i, dataset in enumerate(self.datasets):
print(f"[Dataset {i}]")
dataset.cache_latents(vae, vae_batch_size)
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
@@ -2144,9 +2181,14 @@ def add_dataset_arguments(
parser.add_argument(
"--cache_latents",
action="store_true",
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)",
help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheする(augmentationは使用不可)",
)
parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
parser.add_argument(
"--cache_latents_to_disk",
action="store_true",
help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)",
)
parser.add_argument(
"--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
)
@@ -3203,4 +3245,4 @@ class collater_class:
# set epoch and step
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]
return examples[0]

View File

@@ -117,12 +117,14 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size)
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone()
# 学習を準備する:モデルを適切な状態にする
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
unet.requires_grad_(True) # 念のため追加

View File

@@ -172,12 +172,14 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size)
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone()
# prepare network
import sys

View File

@@ -233,12 +233,14 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size)
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone()
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()

View File

@@ -267,12 +267,14 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size)
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone()
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()