Fix npz file name for images with dots #12

This commit is contained in:
Kohya S
2022-12-24 21:23:40 +09:00
parent 3800e145bd
commit da05ad6339

View File

@@ -130,14 +130,16 @@ def main(args):
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
for (image_key, reso, _), latent in zip(bucket, latents):
np.savez(os.path.join(args.train_data_dir, os.path.splitext(os.path.basename(image_key))[0]), latent)
npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key
np.savez(os.path.join(args.train_data_dir, npz_file_name), latent)
# flip
if args.flip_aug:
latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない
for (image_key, reso, _), latent in zip(bucket, latents):
np.savez(os.path.join(args.train_data_dir, os.path.splitext(os.path.basename(image_key))[0] + '_flip'), latent)
npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key
np.savez(os.path.join(args.train_data_dir, npz_file_name + '_flip'), latent)
bucket.clear()