mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
85 lines
3.0 KiB
Python
85 lines
3.0 KiB
Python
import argparse
|
|
import os
|
|
|
|
import torch
|
|
from safetensors import safe_open
|
|
from safetensors.torch import load_file, save_file
|
|
from tqdm import tqdm
|
|
|
|
|
|
def _resolve_split_sizes(num_vectors, vectors_per_split, vectors):
    """Return the number of vectors that goes into each output file.

    Args:
        num_vectors: Total number of vectors in the embedding.
        vectors_per_split: Fixed chunk size, or None. When given it takes
            precedence over ``vectors``; the last chunk holds the remainder
            (e.g. 8 vectors with a chunk size of 3 gives ``[3, 3, 2]``).
        vectors: Explicit list of chunk sizes, or None.

    Returns:
        A list of positive ints whose sum equals ``num_vectors``.

    Raises:
        ValueError: If a size is not positive, or the explicit sizes do not
            add up to ``num_vectors``.
        SystemExit: If neither ``vectors_per_split`` nor ``vectors`` is given.
    """
    if vectors_per_split is not None:
        if vectors_per_split < 1:
            raise ValueError(f"--vectors_per_split must be a positive integer, got {vectors_per_split}")
        full_splits, remainder = divmod(num_vectors, vectors_per_split)
        sizes = [vectors_per_split] * full_splits
        if remainder:
            sizes.append(remainder)
        return sizes

    if vectors is None:
        # SystemExit with a message prints it to stderr and exits with status 1.
        raise SystemExit(
            "Must specify either --vectors_per_split or --vectors / --vectors_per_split または --vectors のどちらかを指定する必要があります"
        )

    sizes = list(vectors)
    if any(size < 1 for size in sizes):
        raise ValueError(f"--vectors must contain only positive integers, got {sizes}")
    # Explicit check instead of `assert`, which is removed when Python runs with -O.
    if sum(sizes) != num_vectors:
        raise ValueError(
            "Sum of vectors must be equal to the number of vectors in the embedding / 分割したベクトルの合計はembeddingのベクトル数と等しくなければなりません"
        )
    return sizes


def split(args):
    """Split one embedding file into several ``.safetensors`` files.

    The embedding is read from ``args.embedding``. Each output file receives a
    consecutive slice (along the first dimension) of every tensor listed for
    the detected format: ``emb_params`` alone, or ``clip_l`` and ``clip_g``.
    Outputs are written to ``args.output_dir`` as ``<basename>_<index>.safetensors``
    regardless of the input extension; metadata read from a ``.safetensors``
    input is copied to every output.

    Args:
        args: Namespace with the attributes ``embedding`` (input path),
            ``output_dir`` (output directory), ``vectors_per_split``
            (int or None) and ``vectors`` (list of int or None).

    Raises:
        SystemExit: If the embedding format is not recognized, or neither
            ``vectors_per_split`` nor ``vectors`` is given.
        ValueError: If the requested split sizes are invalid.
    """
    # load embedding
    if args.embedding.endswith(".safetensors"):
        embedding = load_file(args.embedding)
        with safe_open(args.embedding, framework="pt") as f:
            metadata = f.metadata()
    else:
        # NOTE(security): torch.load unpickles the file; only use it on embeddings from a trusted source.
        # map_location="cpu" lets a file saved on another device be read on a machine without that device.
        embedding = torch.load(args.embedding, map_location="cpu")
        metadata = None

    # check format
    if "emb_params" in embedding:
        # SD1/2
        keys = ["emb_params"]
    elif "clip_l" in embedding:
        # SDXL
        keys = ["clip_l", "clip_g"]
    else:
        raise SystemExit("Unknown embedding format")

    # The first key determines the vector count used to plan and validate the splits.
    num_vectors = embedding[keys[0]].shape[0]

    # prepare splits (validated before anything is written to disk)
    vectors_for_split = _resolve_split_sizes(num_vectors, args.vectors_per_split, args.vectors)

    # prepare output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # split
    basename = os.path.splitext(os.path.basename(args.embedding))[0]
    done_vectors = 0
    for i, split_size in enumerate(vectors_for_split):
        print(f"Splitting {split_size} vectors...")

        split_embedding = {key: embedding[key][done_vectors : done_vectors + split_size] for key in keys}

        output_file = os.path.join(args.output_dir, f"{basename}_{i}.safetensors")
        save_file(split_embedding, output_file, metadata)
        print(f"Saved to {output_file}")

        done_vectors += split_size

    print("Done")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to split().
    parser = argparse.ArgumentParser(description="Split an embedding into multiple files")
    # Both paths are mandatory; without them split() would fail on a None value.
    parser.add_argument("--embedding", type=str, required=True, help="Embedding to split")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument(
        "--vectors_per_split",
        type=int,
        default=None,
        help="Number of vectors per split. If num_vectors is 8 and vectors_per_split is 3, then 3, 3, 2 vectors will be split",
    )
    parser.add_argument("--vectors", type=int, default=None, nargs="*", help="number of vectors for each split. e.g. 3 3 2")
    args = parser.parse_args()
    split(args)
|