mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Compare commits
1 Commits
v0.8.6
...
multi_embe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8629e3c1a |
84
tools/split_ti_embeddings.py
Normal file
84
tools/split_ti_embeddings.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def split(args):
|
||||
# load embedding
|
||||
if args.embedding.endswith(".safetensors"):
|
||||
embedding = load_file(args.embedding)
|
||||
with safe_open(args.embedding, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
else:
|
||||
embedding = torch.load(args.embedding)
|
||||
metadata = None
|
||||
|
||||
# check format
|
||||
if "emb_params" in embedding:
|
||||
# SD1/2
|
||||
keys = ["emb_params"]
|
||||
elif "clip_l" in embedding:
|
||||
# SDXL
|
||||
keys = ["clip_l", "clip_g"]
|
||||
else:
|
||||
print("Unknown embedding format")
|
||||
exit()
|
||||
num_vectors = embedding[keys[0]].shape[0]
|
||||
|
||||
# prepare output directory
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# prepare splits
|
||||
if args.vectors_per_split is not None:
|
||||
num_splits = (num_vectors + args.vectors_per_split - 1) // args.vectors_per_split
|
||||
vectors_for_split = [args.vectors_per_split] * num_splits
|
||||
if sum(vectors_for_split) > num_vectors:
|
||||
vectors_for_split[-1] -= sum(vectors_for_split) - num_vectors
|
||||
assert sum(vectors_for_split) == num_vectors
|
||||
elif args.vectors is not None:
|
||||
vectors_for_split = args.vectors
|
||||
num_splits = len(vectors_for_split)
|
||||
else:
|
||||
print("Must specify either --vectors_per_split or --vectors / --vectors_per_split または --vectors のどちらかを指定する必要があります")
|
||||
exit()
|
||||
|
||||
assert (
|
||||
sum(vectors_for_split) == num_vectors
|
||||
), "Sum of vectors must be equal to the number of vectors in the embedding / 分割したベクトルの合計はembeddingのベクトル数と等しくなければなりません"
|
||||
|
||||
# split
|
||||
basename = os.path.splitext(os.path.basename(args.embedding))[0]
|
||||
done_vectors = 0
|
||||
for i, num_vectors in enumerate(vectors_for_split):
|
||||
print(f"Splitting {num_vectors} vectors...")
|
||||
|
||||
split_embedding = {}
|
||||
for key in keys:
|
||||
split_embedding[key] = embedding[key][done_vectors : done_vectors + num_vectors]
|
||||
|
||||
output_file = os.path.join(args.output_dir, f"{basename}_{i}.safetensors")
|
||||
save_file(split_embedding, output_file, metadata)
|
||||
print(f"Saved to {output_file}")
|
||||
|
||||
done_vectors += num_vectors
|
||||
|
||||
print("Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Merge models")
|
||||
parser.add_argument("--embedding", type=str, help="Embedding to split")
|
||||
parser.add_argument("--output_dir", type=str, help="Output directory")
|
||||
parser.add_argument(
|
||||
"--vectors_per_split",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of vectors per split. If num_vectors is 8 and vectors_per_split is 3, then 3, 3, 2 vectors will be split",
|
||||
)
|
||||
parser.add_argument("--vectors", type=int, default=None, nargs="*", help="number of vectors for each split. e.g. 3 3 2")
|
||||
args = parser.parse_args()
|
||||
split(args)
|
||||
@@ -7,10 +7,13 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -167,6 +170,13 @@ class TextualInversionTrainer:
|
||||
args.output_name = args.token_string
|
||||
use_template = args.use_object_template or args.use_style_template
|
||||
|
||||
assert (
|
||||
args.token_string is not None or args.token_strings is not None
|
||||
), "token_string or token_strings must be specified / token_stringまたはtoken_stringsを指定してください"
|
||||
assert (
|
||||
not use_template or args.token_strings is None
|
||||
), "token_strings cannot be used with template / token_stringsはテンプレートと一緒に使えません"
|
||||
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
@@ -215,9 +225,17 @@ class TextualInversionTrainer:
|
||||
# add new word to tokenizer, count is num_vectors_per_token
|
||||
# if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added
|
||||
|
||||
self.assert_token_string(args.token_string, tokenizers)
|
||||
if args.token_strings is not None:
|
||||
token_strings = args.token_strings
|
||||
assert (
|
||||
len(token_strings) == args.num_vectors_per_token
|
||||
), f"num_vectors_per_token is mismatch for token_strings / token_stringsの数がnum_vectors_per_tokenと合いません: {len(token_strings)}"
|
||||
for token_string in token_strings:
|
||||
self.assert_token_string(token_string, tokenizers)
|
||||
else:
|
||||
self.assert_token_string(args.token_string, tokenizers)
|
||||
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
|
||||
|
||||
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
|
||||
token_ids_list = []
|
||||
token_embeds_list = []
|
||||
for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
|
||||
@@ -332,7 +350,7 @@ class TextualInversionTrainer:
|
||||
prompt_replacement = None
|
||||
else:
|
||||
# サンプル生成用
|
||||
if args.num_vectors_per_token > 1:
|
||||
if args.num_vectors_per_token > 1 and args.token_strings is None:
|
||||
replace_to = " ".join(token_strings)
|
||||
train_dataset_group.add_replacement(args.token_string, replace_to)
|
||||
prompt_replacement = (args.token_string, replace_to)
|
||||
@@ -752,6 +770,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token_strings",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="token strings used in training for multiple embedding / 複数のembeddingsの個別学習時に使用されるトークン文字列",
|
||||
)
|
||||
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
|
||||
parser.add_argument(
|
||||
"--use_object_template",
|
||||
|
||||
Reference in New Issue
Block a user