mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Compare commits
1 Commits
v0.7.0
...
multi_embe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8629e3c1a |
15
README.md
15
README.md
@@ -249,21 +249,6 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
|
||||
|
||||
## Change History
|
||||
|
||||
### Oct 11, 2023 / 2023/10/11
|
||||
- Fix to work `make_captions_by_git.py` with the latest version of transformers.
|
||||
- Improve `gen_img_diffusers.py` and `sdxl_gen_img.py`. Both scripts now support the following options:
|
||||
- `--network_merge_n_models` option can be used to merge some of the models. The remaining models aren't merged, so the multiplier can be changed, and the regional LoRA also works.
|
||||
- `--network_regional_mask_max_color_codes` is added. Now you can use up to 7 regions.
|
||||
- When this option is specified, the mask of the regional LoRA is the color code based instead of the channel based. The value is the maximum number of the color codes (up to 7).
|
||||
- You can specify the mask for each LoRA by colors: 0x0000ff, 0x00ff00, 0x00ffff, 0xff0000, 0xff00ff, 0xffff00, 0xffffff.
|
||||
|
||||
- `make_captions_by_git.py` が最新の transformers で動作するように修正しました。
|
||||
- `gen_img_diffusers.py` と `sdxl_gen_img.py` を更新し、以下のオプションを追加しました。
|
||||
- `--network_merge_n_models` オプションで一部のモデルのみマージできます。残りのモデルはマージされないため、重みを変更したり、領域別LoRAを使用したりできます。
|
||||
- `--network_regional_mask_max_color_codes` を追加しました。最大7つの領域を使用できます。
|
||||
- このオプションを指定すると、領域別LoRAのマスクはチャンネルベースではなくカラーコードベースになります。値はカラーコードの最大数(最大7)です。
|
||||
- 各LoRAに対してマスクをカラーで指定できます:0x0000ff、0x00ff00、0x00ffff、0xff0000、0xff00ff、0xffff00、0xffffff。
|
||||
|
||||
### Oct 9. 2023 / 2023/10/9
|
||||
|
||||
- `tag_images_by_wd_14_tagger.py` now supports Onnx. If you use Onnx, TensorFlow is not required anymore. [#864](https://github.com/kohya-ss/sd-scripts/pull/864) Thanks to Isotr0py!
|
||||
|
||||
@@ -52,9 +52,6 @@ def collate_fn_remove_corrupted(batch):
|
||||
|
||||
|
||||
def main(args):
|
||||
r"""
|
||||
transformers 4.30.2で、バッチサイズ>1でも動くようになったので、以下コメントアウト
|
||||
|
||||
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
|
||||
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
|
||||
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
|
||||
@@ -68,7 +65,6 @@ def main(args):
|
||||
return input_ids
|
||||
|
||||
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
|
||||
"""
|
||||
|
||||
print(f"load images from {args.train_data_dir}")
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
@@ -85,7 +81,7 @@ def main(args):
|
||||
def run_batch(path_imgs):
|
||||
imgs = [im for _, im in path_imgs]
|
||||
|
||||
# curr_batch_size[0] = len(path_imgs)
|
||||
curr_batch_size[0] = len(path_imgs)
|
||||
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
|
||||
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
|
||||
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
@@ -65,13 +65,10 @@ import re
|
||||
import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -957,7 +954,7 @@ class PipelineLike:
|
||||
text_emb_last = torch.stack(text_emb_last)
|
||||
else:
|
||||
text_emb_last = text_embeddings
|
||||
|
||||
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
||||
@@ -2366,19 +2363,12 @@ def main(args):
|
||||
network_default_muls = []
|
||||
network_pre_calc = args.network_pre_calc
|
||||
|
||||
# merge関連の引数を統合する
|
||||
if args.network_merge:
|
||||
network_merge = len(args.network_module) # all networks are merged
|
||||
elif args.network_merge_n_models:
|
||||
network_merge = args.network_merge_n_models
|
||||
else:
|
||||
network_merge = 0
|
||||
|
||||
for i, network_module in enumerate(args.network_module):
|
||||
print("import network module:", network_module)
|
||||
imported_module = importlib.import_module(network_module)
|
||||
|
||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||
network_default_muls.append(network_mul)
|
||||
|
||||
net_kwargs = {}
|
||||
if args.network_args and i < len(args.network_args):
|
||||
@@ -2389,32 +2379,31 @@ def main(args):
|
||||
key, value = net_arg.split("=")
|
||||
net_kwargs[key] = value
|
||||
|
||||
if args.network_weights is None or len(args.network_weights) <= i:
|
||||
if args.network_weights and i < len(args.network_weights):
|
||||
network_weight = args.network_weights[i]
|
||||
print("load network weights from:", network_weight)
|
||||
|
||||
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
with safe_open(network_weight, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is not None:
|
||||
print(f"metadata for: {network_weight}: {metadata}")
|
||||
|
||||
network, weights_sd = imported_module.create_network_from_weights(
|
||||
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError("No weight. Weight is required.")
|
||||
|
||||
network_weight = args.network_weights[i]
|
||||
print("load network weights from:", network_weight)
|
||||
|
||||
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
with safe_open(network_weight, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is not None:
|
||||
print(f"metadata for: {network_weight}: {metadata}")
|
||||
|
||||
network, weights_sd = imported_module.create_network_from_weights(
|
||||
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
|
||||
)
|
||||
if network is None:
|
||||
return
|
||||
|
||||
mergeable = network.is_mergeable()
|
||||
if network_merge and not mergeable:
|
||||
if args.network_merge and not mergeable:
|
||||
print("network is not mergiable. ignore merge option.")
|
||||
|
||||
if not mergeable or i >= network_merge:
|
||||
# not merging
|
||||
if not args.network_merge or not mergeable:
|
||||
network.apply_to(text_encoder, unet)
|
||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||
print(f"weights are loaded: {info}")
|
||||
@@ -2428,7 +2417,6 @@ def main(args):
|
||||
network.backup_weights()
|
||||
|
||||
networks.append(network)
|
||||
network_default_muls.append(network_mul)
|
||||
else:
|
||||
network.merge_to(text_encoder, unet, weights_sd, dtype, device)
|
||||
|
||||
@@ -2724,18 +2712,9 @@ def main(args):
|
||||
|
||||
size = None
|
||||
for i, network in enumerate(networks):
|
||||
if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes:
|
||||
if i < 3:
|
||||
np_mask = np.array(mask_images[0])
|
||||
|
||||
if args.network_regional_mask_max_color_codes:
|
||||
# カラーコードでマスクを指定する
|
||||
ch0 = (i + 1) & 1
|
||||
ch1 = ((i + 1) >> 1) & 1
|
||||
ch2 = ((i + 1) >> 2) & 1
|
||||
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
|
||||
np_mask = np_mask.astype(np.uint8) * 255
|
||||
else:
|
||||
np_mask = np_mask[:, :, i]
|
||||
np_mask = np_mask[:, :, i]
|
||||
size = np_mask.shape
|
||||
else:
|
||||
np_mask = np.full(size, 255, dtype=np.uint8)
|
||||
@@ -3388,19 +3367,10 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
|
||||
)
|
||||
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
||||
parser.add_argument(
|
||||
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
|
||||
)
|
||||
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
|
||||
parser.add_argument(
|
||||
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_regional_mask_max_color_codes",
|
||||
type=int,
|
||||
default=None,
|
||||
help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--textual_inversion_embeddings",
|
||||
type=str,
|
||||
|
||||
@@ -17,13 +17,10 @@ import re
|
||||
import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -1537,20 +1534,12 @@ def main(args):
|
||||
network_default_muls = []
|
||||
network_pre_calc = args.network_pre_calc
|
||||
|
||||
# merge関連の引数を統合する
|
||||
if args.network_merge:
|
||||
network_merge = len(args.network_module) # all networks are merged
|
||||
elif args.network_merge_n_models:
|
||||
network_merge = args.network_merge_n_models
|
||||
else:
|
||||
network_merge = 0
|
||||
print(f"network_merge: {network_merge}")
|
||||
|
||||
for i, network_module in enumerate(args.network_module):
|
||||
print("import network module:", network_module)
|
||||
imported_module = importlib.import_module(network_module)
|
||||
|
||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||
network_default_muls.append(network_mul)
|
||||
|
||||
net_kwargs = {}
|
||||
if args.network_args and i < len(args.network_args):
|
||||
@@ -1561,32 +1550,31 @@ def main(args):
|
||||
key, value = net_arg.split("=")
|
||||
net_kwargs[key] = value
|
||||
|
||||
if args.network_weights is None or len(args.network_weights) <= i:
|
||||
if args.network_weights and i < len(args.network_weights):
|
||||
network_weight = args.network_weights[i]
|
||||
print("load network weights from:", network_weight)
|
||||
|
||||
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
with safe_open(network_weight, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is not None:
|
||||
print(f"metadata for: {network_weight}: {metadata}")
|
||||
|
||||
network, weights_sd = imported_module.create_network_from_weights(
|
||||
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError("No weight. Weight is required.")
|
||||
|
||||
network_weight = args.network_weights[i]
|
||||
print("load network weights from:", network_weight)
|
||||
|
||||
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
with safe_open(network_weight, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is not None:
|
||||
print(f"metadata for: {network_weight}: {metadata}")
|
||||
|
||||
network, weights_sd = imported_module.create_network_from_weights(
|
||||
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
|
||||
)
|
||||
if network is None:
|
||||
return
|
||||
|
||||
mergeable = network.is_mergeable()
|
||||
if network_merge and not mergeable:
|
||||
if args.network_merge and not mergeable:
|
||||
print("network is not mergiable. ignore merge option.")
|
||||
|
||||
if not mergeable or i >= network_merge:
|
||||
# not merging
|
||||
if not args.network_merge or not mergeable:
|
||||
network.apply_to([text_encoder1, text_encoder2], unet)
|
||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||
print(f"weights are loaded: {info}")
|
||||
@@ -1600,7 +1588,6 @@ def main(args):
|
||||
network.backup_weights()
|
||||
|
||||
networks.append(network)
|
||||
network_default_muls.append(network_mul)
|
||||
else:
|
||||
network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device)
|
||||
|
||||
@@ -1877,18 +1864,9 @@ def main(args):
|
||||
|
||||
size = None
|
||||
for i, network in enumerate(networks):
|
||||
if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes:
|
||||
if i < 3:
|
||||
np_mask = np.array(mask_images[0])
|
||||
|
||||
if args.network_regional_mask_max_color_codes:
|
||||
# カラーコードでマスクを指定する
|
||||
ch0 = (i + 1) & 1
|
||||
ch1 = ((i + 1) >> 1) & 1
|
||||
ch2 = ((i + 1) >> 2) & 1
|
||||
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
|
||||
np_mask = np_mask.astype(np.uint8) * 255
|
||||
else:
|
||||
np_mask = np_mask[:, :, i]
|
||||
np_mask = np_mask[:, :, i]
|
||||
size = np_mask.shape
|
||||
else:
|
||||
np_mask = np.full(size, 255, dtype=np.uint8)
|
||||
@@ -2637,19 +2615,10 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
|
||||
)
|
||||
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
||||
parser.add_argument(
|
||||
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
|
||||
)
|
||||
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
|
||||
parser.add_argument(
|
||||
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_regional_mask_max_color_codes",
|
||||
type=int,
|
||||
default=None,
|
||||
help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--textual_inversion_embeddings",
|
||||
type=str,
|
||||
|
||||
84
tools/split_ti_embeddings.py
Normal file
84
tools/split_ti_embeddings.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def split(args):
|
||||
# load embedding
|
||||
if args.embedding.endswith(".safetensors"):
|
||||
embedding = load_file(args.embedding)
|
||||
with safe_open(args.embedding, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
else:
|
||||
embedding = torch.load(args.embedding)
|
||||
metadata = None
|
||||
|
||||
# check format
|
||||
if "emb_params" in embedding:
|
||||
# SD1/2
|
||||
keys = ["emb_params"]
|
||||
elif "clip_l" in embedding:
|
||||
# SDXL
|
||||
keys = ["clip_l", "clip_g"]
|
||||
else:
|
||||
print("Unknown embedding format")
|
||||
exit()
|
||||
num_vectors = embedding[keys[0]].shape[0]
|
||||
|
||||
# prepare output directory
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# prepare splits
|
||||
if args.vectors_per_split is not None:
|
||||
num_splits = (num_vectors + args.vectors_per_split - 1) // args.vectors_per_split
|
||||
vectors_for_split = [args.vectors_per_split] * num_splits
|
||||
if sum(vectors_for_split) > num_vectors:
|
||||
vectors_for_split[-1] -= sum(vectors_for_split) - num_vectors
|
||||
assert sum(vectors_for_split) == num_vectors
|
||||
elif args.vectors is not None:
|
||||
vectors_for_split = args.vectors
|
||||
num_splits = len(vectors_for_split)
|
||||
else:
|
||||
print("Must specify either --vectors_per_split or --vectors / --vectors_per_split または --vectors のどちらかを指定する必要があります")
|
||||
exit()
|
||||
|
||||
assert (
|
||||
sum(vectors_for_split) == num_vectors
|
||||
), "Sum of vectors must be equal to the number of vectors in the embedding / 分割したベクトルの合計はembeddingのベクトル数と等しくなければなりません"
|
||||
|
||||
# split
|
||||
basename = os.path.splitext(os.path.basename(args.embedding))[0]
|
||||
done_vectors = 0
|
||||
for i, num_vectors in enumerate(vectors_for_split):
|
||||
print(f"Splitting {num_vectors} vectors...")
|
||||
|
||||
split_embedding = {}
|
||||
for key in keys:
|
||||
split_embedding[key] = embedding[key][done_vectors : done_vectors + num_vectors]
|
||||
|
||||
output_file = os.path.join(args.output_dir, f"{basename}_{i}.safetensors")
|
||||
save_file(split_embedding, output_file, metadata)
|
||||
print(f"Saved to {output_file}")
|
||||
|
||||
done_vectors += num_vectors
|
||||
|
||||
print("Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Merge models")
|
||||
parser.add_argument("--embedding", type=str, help="Embedding to split")
|
||||
parser.add_argument("--output_dir", type=str, help="Output directory")
|
||||
parser.add_argument(
|
||||
"--vectors_per_split",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of vectors per split. If num_vectors is 8 and vectors_per_split is 3, then 3, 3, 2 vectors will be split",
|
||||
)
|
||||
parser.add_argument("--vectors", type=int, default=None, nargs="*", help="number of vectors for each split. e.g. 3 3 2")
|
||||
args = parser.parse_args()
|
||||
split(args)
|
||||
@@ -7,10 +7,13 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -167,6 +170,13 @@ class TextualInversionTrainer:
|
||||
args.output_name = args.token_string
|
||||
use_template = args.use_object_template or args.use_style_template
|
||||
|
||||
assert (
|
||||
args.token_string is not None or args.token_strings is not None
|
||||
), "token_string or token_strings must be specified / token_stringまたはtoken_stringsを指定してください"
|
||||
assert (
|
||||
not use_template or args.token_strings is None
|
||||
), "token_strings cannot be used with template / token_stringsはテンプレートと一緒に使えません"
|
||||
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
@@ -215,9 +225,17 @@ class TextualInversionTrainer:
|
||||
# add new word to tokenizer, count is num_vectors_per_token
|
||||
# if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added
|
||||
|
||||
self.assert_token_string(args.token_string, tokenizers)
|
||||
if args.token_strings is not None:
|
||||
token_strings = args.token_strings
|
||||
assert (
|
||||
len(token_strings) == args.num_vectors_per_token
|
||||
), f"num_vectors_per_token is mismatch for token_strings / token_stringsの数がnum_vectors_per_tokenと合いません: {len(token_strings)}"
|
||||
for token_string in token_strings:
|
||||
self.assert_token_string(token_string, tokenizers)
|
||||
else:
|
||||
self.assert_token_string(args.token_string, tokenizers)
|
||||
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
|
||||
|
||||
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
|
||||
token_ids_list = []
|
||||
token_embeds_list = []
|
||||
for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
|
||||
@@ -332,7 +350,7 @@ class TextualInversionTrainer:
|
||||
prompt_replacement = None
|
||||
else:
|
||||
# サンプル生成用
|
||||
if args.num_vectors_per_token > 1:
|
||||
if args.num_vectors_per_token > 1 and args.token_strings is None:
|
||||
replace_to = " ".join(token_strings)
|
||||
train_dataset_group.add_replacement(args.token_string, replace_to)
|
||||
prompt_replacement = (args.token_string, replace_to)
|
||||
@@ -752,6 +770,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token_strings",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="token strings used in training for multiple embedding / 複数のembeddingsの個別学習時に使用されるトークン文字列",
|
||||
)
|
||||
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
|
||||
parser.add_argument(
|
||||
"--use_object_template",
|
||||
|
||||
Reference in New Issue
Block a user