diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index e2cebe8d..00f847a1 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -130,14 +130,16 @@ def main(args): latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype) for (image_key, reso, _), latent in zip(bucket, latents): - np.savez(os.path.join(args.train_data_dir, os.path.splitext(os.path.basename(image_key))[0]), latent) + npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key + np.savez(os.path.join(args.train_data_dir, npz_file_name), latent) # flip if args.flip_aug: latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない for (image_key, reso, _), latent in zip(bucket, latents): - np.savez(os.path.join(args.train_data_dir, os.path.splitext(os.path.basename(image_key))[0] + '_flip'), latent) + npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key + np.savez(os.path.join(args.train_data_dir, npz_file_name + '_flip'), latent) bucket.clear()