Compare commits

...

27 Commits

Author SHA1 Message Date
Kohya S
2a23713f71 Merge pull request #872 from kohya-ss/dev
fix make_captions_by_git, improve image generation scripts
2023-10-11 07:56:39 +09:00
Kohya S
681034d001 update readme 2023-10-11 07:54:30 +09:00
Kohya S
17813ff5b4 remove workaround for transfomers bs>1 close #869 2023-10-11 07:40:12 +09:00
Kohya S
3e81bd6b67 fix network_merge, add regional mask as color code 2023-10-09 23:07:14 +09:00
Kohya S
23ae358e0f Merge branch 'main' into dev 2023-10-09 21:42:13 +09:00
Kohya S
f611726364 add network_merge_n_models option 2023-10-09 21:41:50 +09:00
Kohya S
33ee0acd35 Merge pull request #867 from kohya-ss/dev
onnx support in wd14 tagger, OFT
2023-10-09 18:04:17 +09:00
Kohya S
8b79e3b06c fix typos 2023-10-09 18:00:45 +09:00
Kohya S
cf49e912fc update readme 2023-10-09 17:59:31 +09:00
Kohya S
66741c035c add OFT 2023-10-09 17:59:24 +09:00
Kohya S
406511c333 add error message if model.onnx doesn't exist 2023-10-09 17:08:58 +09:00
Kohya S
8a2d68d63e Merge pull request #864 from Isotr0py/onnx
Add `--onnx` to wd14 tagger
2023-10-09 15:14:11 +09:00
Kohya S
07d297fdbe Merge branch 'dev' into onnx 2023-10-09 15:13:40 +09:00
Kohya S
0d4e8b50d0 change option to append_tags, minor update 2023-10-09 15:09:54 +09:00
Kohya S
1d7c5c2a98 Merge pull request #858 from a-l-e-x-d-s-9/main
Add append_captions feature to wd14 tagger
2023-10-09 14:31:54 +09:00
Kohya S
0faa350175 Merge pull request #865 from kohya-ss/dev
Support JPEG-XL on windows, dropout for LyCORIS
2023-10-09 14:11:49 +09:00
Kohya S
8a7509db75 Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2023-10-09 14:07:02 +09:00
Kohya S
025368f51c may work dropout in LyCORIS #859 2023-10-09 14:06:58 +09:00
Kohya S
5fe52ed322 Merge pull request #856 from Isotr0py/jxl
Fix JPEG-XL support
2023-10-09 13:55:03 +09:00
Kohya S
8b247a330b Merge pull request #851 from kohya-ss/dependabot/github_actions/actions/checkout-4
Bump actions/checkout from 3 to 4
2023-10-09 11:45:47 +09:00
Isotr0py
d6f458fcb3 fix dependency 2023-10-08 23:51:18 +08:00
Isotr0py
b8b84021e5 fix a typo 2023-10-08 20:49:03 +08:00
Isotr0py
70fe7e18be add onnx to wd14 tagger 2023-10-08 20:31:10 +08:00
alexds9
9378da3c82 Fix comment 2023-10-05 21:29:46 +03:00
alexds9
a4857fa764 Add append_captions feature to wd14 tagger
This feature allows for appending new tags to the existing content of caption files.
If the caption file for an image already exists, the tags generated from the current
run are appended to the existing ones. Duplicate tags are checked and avoided.
2023-10-05 21:26:09 +03:00
Isotr0py
592014923f Support JPEG-XL on windows 2023-10-04 21:48:25 +08:00
dependabot[bot]
6d06b215bf Bump actions/checkout from 3 to 4
Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v3...v4)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-10-01 22:51:32 +00:00
10 changed files with 692 additions and 62 deletions

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: typos-action
uses: crate-ci/typos@v1.16.15

View File

@@ -249,6 +249,55 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
## Change History
### Oct 11, 2023 / 2023/10/11
- Fix to work `make_captions_by_git.py` with the latest version of transformers.
- Improve `gen_img_diffusers.py` and `sdxl_gen_img.py`. Both scripts now support the following options:
- `--network_merge_n_models` option can be used to merge some of the models. The remaining models aren't merged, so the multiplier can be changed, and the regional LoRA also works.
- `--network_regional_mask_max_color_codes` is added. Now you can use up to 7 regions.
- When this option is specified, the mask of the regional LoRA is the color code based instead of the channel based. The value is the maximum number of the color codes (up to 7).
- You can specify the mask for each LoRA by colors: 0x0000ff, 0x00ff00, 0x00ffff, 0xff0000, 0xff00ff, 0xffff00, 0xffffff.
- `make_captions_by_git.py` が最新の transformers で動作するように修正しました。
- `gen_img_diffusers.py``sdxl_gen_img.py` を更新し、以下のオプションを追加しました。
- `--network_merge_n_models` オプションで一部のモデルのみマージできます。残りのモデルはマージされないため、重みを変更したり、領域別LoRAを使用したりできます。
- `--network_regional_mask_max_color_codes` を追加しました。最大7つの領域を使用できます。
- このオプションを指定すると、領域別LoRAのマスクはチャンネルベースではなくカラーコードベースになります。値はカラーコードの最大数最大7です。
- 各LoRAに対してマスクをカラーで指定できます0x0000ff、0x00ff00、0x00ffff、0xff0000、0xff00ff、0xffff00、0xffffff。
### Oct 9. 2023 / 2023/10/9
- `tag_images_by_wd_14_tagger.py` now supports Onnx. If you use Onnx, TensorFlow is not required anymore. [#864](https://github.com/kohya-ss/sd-scripts/pull/864) Thanks to Isotr0py!
- `--onnx` option is added. If you use Onnx, specify `--onnx` option.
- Please install Onnx and other required packages.
1. Uninstall TensorFlow.
1. `pip install tensorboard==2.14.1` This is required for the specified version of protobuf.
1. `pip install protobuf==3.20.3` This is required for Onnx.
1. `pip install onnx==1.14.1`
1. `pip install onnxruntime-gpu==1.16.0` or `pip install onnxruntime==1.16.0`
- `--append_tags` option is added to `tag_images_by_wd_14_tagger.py`. This option appends the tags to the existing tags, instead of replacing them. [#858](https://github.com/kohya-ss/sd-scripts/pull/858) Thanks to a-l-e-x-d-s-9!
- [OFT](https://oft.wyliu.com/) is now supported.
- You can use `networks.oft` for the network module in `sdxl_train_network.py`. The usage is the same as `networks.lora`. Some options are not supported.
- `sdxl_gen_img.py` also supports OFT as `--network_module`.
- OFT only supports SDXL currently. Because current OFT tweaks Q/K/V and O in the transformer, and SD1/2 have extremely fewer transformers than SDXL.
- The implementation is heavily based on laksjdjf's [OFT implementation](https://github.com/laksjdjf/sd-trainer/blob/dev/networks/lora_modules.py). Thanks to laksjdjf!
- Other bug fixes and improvements.
- `tag_images_by_wd_14_tagger.py` が Onnx をサポートしました。Onnx を使用する場合は TensorFlow は不要です。[#864](https://github.com/kohya-ss/sd-scripts/pull/864) Isotr0py氏に感謝します。
- Onnxを使用する場合は、`--onnx` オプションを指定してください。
- Onnx とその他の必要なパッケージをインストールしてください。
1. TensorFlow をアンインストールしてください。
1. `pip install tensorboard==2.14.1` protobufの指定バージョンにこれが必要。
1. `pip install protobuf==3.20.3` Onnxのために必要。
1. `pip install onnx==1.14.1`
1. `pip install onnxruntime-gpu==1.16.0` または `pip install onnxruntime==1.16.0`
- `tag_images_by_wd_14_tagger.py``--append_tags` オプションが追加されました。このオプションを指定すると、既存のタグに上書きするのではなく、新しいタグのみが既存のタグに追加されます。 [#858](https://github.com/kohya-ss/sd-scripts/pull/858) a-l-e-x-d-s-9氏に感謝します。
- [OFT](https://oft.wyliu.com/) をサポートしました。
- `sdxl_train_network.py``--network_module``networks.oft` を指定してください。使用方法は `networks.lora` と同様ですが一部のオプションは未サポートです。
- `sdxl_gen_img.py` でも同様に OFT を指定できます。
- OFT は現在 SDXL のみサポートしています。OFT は現在 transformer の Q/K/V と O を変更しますが、SD1/2 は transformer の数が SDXL よりも極端に少ないためです。
- 実装は laksjdjf 氏の [OFT実装](https://github.com/laksjdjf/sd-trainer/blob/dev/networks/lora_modules.py) を多くの部分で参考にしています。laksjdjf 氏に感謝します。
- その他のバグ修正と改善。
### Oct 1. 2023 / 2023/10/1
- SDXL training is now available in the main branch. The sdxl branch is merged into the main branch.

View File

@@ -52,6 +52,9 @@ def collate_fn_remove_corrupted(batch):
def main(args):
r"""
transformers 4.30.2で、バッチサイズ>1でも動くようになったので、以下コメントアウト
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
@@ -65,6 +68,7 @@ def main(args):
return input_ids
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
"""
print(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)
@@ -81,7 +85,7 @@ def main(args):
def run_batch(path_imgs):
imgs = [im for _, im in path_imgs]
curr_batch_size[0] = len(path_imgs)
# curr_batch_size[0] = len(path_imgs)
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)

View File

@@ -1,17 +1,15 @@
import argparse
import csv
import glob
import os
from PIL import Image
import cv2
from tqdm import tqdm
import numpy as np
from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download
import torch
from pathlib import Path
import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from tqdm import tqdm
import library.train_util as train_util
# from wd14 tagger
@@ -20,6 +18,7 @@ IMAGE_SIZE = 448
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
FILES_ONNX = ["model.onnx"]
SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1]
@@ -81,7 +80,10 @@ def main(args):
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(args.model_dir) or args.force_download:
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
for file in FILES:
files = FILES
if args.onnx:
files += FILES_ONNX
for file in files:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES:
hf_hub_download(
@@ -96,7 +98,46 @@ def main(args):
print("using existing wd14 tagger model")
# 画像を読み込む
model = load_model(args.model_dir)
if args.onnx:
import onnx
import onnxruntime as ort
onnx_path = f"{args.model_dir}/model.onnx"
print("Running wd14 tagger with onnx")
print(f"loading onnx model: {onnx_path}")
if not os.path.exists(onnx_path):
raise Exception(
f"onnx model not found: {onnx_path}, please redownload the model with --force_download"
+ " / onnxモデルが見つかりませんでした。--force_downloadで再ダウンロードしてください"
)
model = onnx.load(onnx_path)
input_name = model.graph.input[0].name
try:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
except:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
if args.batch_size != batch_size and type(batch_size) != str:
# some rebatch model may use 'N' as dynamic axes
print(
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
)
args.batch_size = batch_size
del model
ort_sess = ort.InferenceSession(
onnx_path,
providers=["CUDAExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers()
else ["CPUExecutionProvider"],
)
else:
from tensorflow.keras.models import load_model
model = load_model(f"{args.model_dir}")
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
# 依存ライブラリを増やしたくないので自力で読むよ
@@ -124,8 +165,14 @@ def main(args):
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
probs = model(imgs, training=False)
probs = probs.numpy()
if args.onnx:
if len(imgs) < args.batch_size:
imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
probs = probs[: len(path_imgs)]
else:
probs = model(imgs, training=False)
probs = probs.numpy()
for (image_path, _), prob in zip(path_imgs, probs):
# 最初の4つはratingなので無視する
@@ -165,9 +212,27 @@ def main(args):
if len(character_tag_text) > 0:
character_tag_text = character_tag_text[2:]
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
tag_text = ", ".join(combined_tags)
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
if args.append_tags:
# Check if file exists
if os.path.exists(caption_file):
with open(caption_file, "rt", encoding="utf-8") as f:
# Read file and remove new lines
existing_content = f.read().strip("\n") # Remove newlines
# Split the content into tags and store them in a list
existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()]
# Check and remove repeating tags in tag_text
new_tags = [tag for tag in combined_tags if tag not in existing_tags]
# Create new tag_text
tag_text = ", ".join(existing_tags + new_tags)
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
if args.debug:
print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
@@ -283,12 +348,15 @@ def setup_parser() -> argparse.ArgumentParser:
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
)
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
# スペルミスしていたオプションを復元する

View File

@@ -65,10 +65,13 @@ import re
import diffusers
import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
@@ -954,7 +957,7 @@ class PipelineLike:
text_emb_last = torch.stack(text_emb_last)
else:
text_emb_last = text_embeddings
for i, t in enumerate(tqdm(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
@@ -2363,12 +2366,19 @@ def main(args):
network_default_muls = []
network_pre_calc = args.network_pre_calc
# merge関連の引数を統合する
if args.network_merge:
network_merge = len(args.network_module) # all networks are merged
elif args.network_merge_n_models:
network_merge = args.network_merge_n_models
else:
network_merge = 0
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_default_muls.append(network_mul)
net_kwargs = {}
if args.network_args and i < len(args.network_args):
@@ -2379,31 +2389,32 @@ def main(args):
key, value = net_arg.split("=")
net_kwargs[key] = value
if args.network_weights and i < len(args.network_weights):
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
)
else:
if args.network_weights is None or len(args.network_weights) <= i:
raise ValueError("No weight. Weight is required.")
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
)
if network is None:
return
mergeable = network.is_mergeable()
if args.network_merge and not mergeable:
if network_merge and not mergeable:
print("network is not mergiable. ignore merge option.")
if not args.network_merge or not mergeable:
if not mergeable or i >= network_merge:
# not merging
network.apply_to(text_encoder, unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}")
@@ -2417,6 +2428,7 @@ def main(args):
network.backup_weights()
networks.append(network)
network_default_muls.append(network_mul)
else:
network.merge_to(text_encoder, unet, weights_sd, dtype, device)
@@ -2712,9 +2724,18 @@ def main(args):
size = None
for i, network in enumerate(networks):
if i < 3:
if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes:
np_mask = np.array(mask_images[0])
np_mask = np_mask[:, :, i]
if args.network_regional_mask_max_color_codes:
# カラーコードでマスクを指定する
ch0 = (i + 1) & 1
ch1 = ((i + 1) >> 1) & 1
ch2 = ((i + 1) >> 2) & 1
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
np_mask = np_mask.astype(np.uint8) * 255
else:
np_mask = np_mask[:, :, i]
size = np_mask.shape
else:
np_mask = np.full(size, 255, dtype=np.uint8)
@@ -3367,10 +3388,19 @@ def setup_parser() -> argparse.ArgumentParser:
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
)
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument(
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
)
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument(
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
)
parser.add_argument(
"--network_regional_mask_max_color_codes",
type=int,
default=None,
help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数デフォルトはNoneでチャンネルごとのマスク",
)
parser.add_argument(
"--textual_inversion_embeddings",
type=str,

View File

@@ -96,6 +96,7 @@ try:
except:
pass
# JPEG-XL on Linux
try:
from jxlpy import JXLImagePlugin
@@ -103,6 +104,14 @@ try:
except:
pass
# JPEG-XL on Windows
try:
import pillow_jxl
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except:
pass
IMAGE_TRANSFORMS = transforms.Compose(
[
transforms.ToTensor(),

430
networks/oft.py Normal file
View File

@@ -0,0 +1,430 @@
# OFT network module
import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import numpy as np
import torch
import re
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
class OFTModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
def __init__(
self,
oft_name,
org_module: torch.nn.Module,
multiplier=1.0,
dim=4,
alpha=1,
):
"""
dim -> num blocks
alpha -> constraint
"""
super().__init__()
self.oft_name = oft_name
self.num_blocks = dim
if "Linear" in org_module.__class__.__name__:
out_dim = org_module.out_features
elif "Conv" in org_module.__class__.__name__:
out_dim = org_module.out_channels
if type(alpha) == torch.Tensor:
alpha = alpha.detach().numpy()
self.constraint = alpha * out_dim
self.register_buffer("alpha", torch.tensor(alpha))
self.block_size = out_dim // self.num_blocks
self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))
self.out_dim = out_dim
self.shape = org_module.weight.shape
self.multiplier = multiplier
self.org_module = [org_module] # moduleにならないようにlistに入れる
def apply_to(self):
self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward
def get_weight(self, multiplier=None):
if multiplier is None:
multiplier = self.multiplier
block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I
R = torch.block_diag(*block_R_weighted)
return R
def forward(self, x, scale=None):
x = self.org_forward(x)
if self.multiplier == 0.0:
return x
R = self.get_weight().to(x.device, dtype=x.dtype)
if x.dim() == 4:
x = x.permute(0, 2, 3, 1)
x = torch.matmul(x, R)
x = x.permute(0, 3, 1, 2)
else:
x = torch.matmul(x, R)
return x
class OFTInfModule(OFTModule):
def __init__(
self,
oft_name,
org_module: torch.nn.Module,
multiplier=1.0,
dim=4,
alpha=1,
**kwargs,
):
# no dropout for inference
super().__init__(oft_name, org_module, multiplier, dim, alpha)
self.enabled = True
self.network: OFTNetwork = None
def set_network(self, network):
self.network = network
def forward(self, x, scale=None):
if not self.enabled:
return self.org_forward(x)
return super().forward(x, scale)
def merge_to(self, multiplier=None, sign=1):
R = self.get_weight(multiplier) * sign
# get org weight
org_sd = self.org_module[0].state_dict()
org_weight = org_sd["weight"]
R = R.to(org_weight.device, dtype=org_weight.dtype)
if org_weight.dim() == 4:
weight = torch.einsum("oihw, op -> pihw", org_weight, R)
else:
weight = torch.einsum("oi, op -> pi", org_weight, R)
# set weight to org_module
org_sd["weight"] = weight
self.org_module[0].load_state_dict(org_sd)
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae: AutoencoderKL,
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
unet,
neuron_dropout: Optional[float] = None,
**kwargs,
):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
network_alpha = 1.0
enable_all_linear = kwargs.get("enable_all_linear", None)
enable_conv = kwargs.get("enable_conv", None)
if enable_all_linear is not None:
enable_all_linear = bool(enable_all_linear)
if enable_conv is not None:
enable_conv = bool(enable_conv)
network = OFTNetwork(
text_encoder,
unet,
multiplier=multiplier,
dim=network_dim,
alpha=network_alpha,
enable_all_linear=enable_all_linear,
enable_conv=enable_conv,
varbose=True,
)
return network
# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# check dim, alpha and if weights have for conv2d
dim = None
alpha = None
has_conv2d = None
all_linear = None
for name, param in weights_sd.items():
if name.endswith(".alpha"):
if alpha is None:
alpha = param.item()
else:
if dim is None:
dim = param.size()[0]
if has_conv2d is None and param.dim() == 4:
has_conv2d = True
if all_linear is None:
if param.dim() == 3 and "attn" not in name:
all_linear = True
if dim is not None and alpha is not None and has_conv2d is not None:
break
if has_conv2d is None:
has_conv2d = False
if all_linear is None:
all_linear = False
module_class = OFTInfModule if for_inference else OFTModule
network = OFTNetwork(
text_encoder,
unet,
multiplier=multiplier,
dim=dim,
alpha=alpha,
enable_all_linear=all_linear,
enable_conv=has_conv2d,
module_class=module_class,
)
return network, weights_sd
class OFTNetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"]
UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
OFT_PREFIX_UNET = "oft_unet" # これ変えないほうがいいかな
def __init__(
self,
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
unet,
multiplier: float = 1.0,
dim: int = 4,
alpha: float = 1,
enable_all_linear: Optional[bool] = False,
enable_conv: Optional[bool] = False,
module_class: Type[object] = OFTModule,
varbose: Optional[bool] = False,
) -> None:
super().__init__()
self.multiplier = multiplier
self.dim = dim
self.alpha = alpha
print(
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
)
# create module instances
def create_modules(
root_module: torch.nn.Module,
target_replace_modules: List[torch.nn.Module],
) -> List[OFTModule]:
prefix = self.OFT_PREFIX_UNET
ofts = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = "Linear" in child_module.__class__.__name__
is_conv2d = "Conv2d" in child_module.__class__.__name__
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
oft_name = prefix + "." + name + "." + child_name
oft_name = oft_name.replace(".", "_")
# print(oft_name)
oft = module_class(
oft_name,
child_module,
self.multiplier,
dim,
alpha,
)
ofts.append(oft)
return ofts
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
if enable_all_linear:
target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR
else:
target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ATTN_ONLY
if enable_conv:
target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
# assertion
names = set()
for oft in self.unet_ofts:
assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}"
names.add(oft.oft_name)
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for oft in self.unet_ofts:
oft.multiplier = self.multiplier
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
assert apply_unet, "apply_unet must be True"
for oft in self.unet_ofts:
oft.apply_to()
self.add_module(oft.oft_name, oft)
# マージできるかどうかを返す
def is_mergeable(self):
return True
# TODO refactor to common function with apply_to
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
print("enable OFT for U-Net")
for oft in self.unet_ofts:
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(oft.oft_name):
sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key]
oft.load_state_dict(sd_for_lora, False)
oft.merge_to()
print(f"weights are merged")
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []
def enumerate_params(ofts):
params = []
for oft in ofts:
params.extend(oft.parameters())
# print num of params
num_params = 0
for p in params:
num_params += p.numel()
print(f"OFT params: {num_params}")
return params
param_data = {"params": enumerate_params(self.unet_ofts)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
return all_params
def enable_gradient_checkpointing(self):
# not supported
pass
def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)
def on_epoch_start(self, text_encoder, unet):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
from library import train_util
# Precalculate model hashes to save time on indexing
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def backup_weights(self):
# 重みのバックアップを行う
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
if not hasattr(org_module, "_lora_org_weight"):
sd = org_module.state_dict()
org_module._lora_org_weight = sd["weight"].detach().clone()
org_module._lora_restored = True
def restore_weights(self):
# 重みのリストアを行う
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
if not org_module._lora_restored:
sd = org_module.state_dict()
sd["weight"] = org_module._lora_org_weight
org_module.load_state_dict(sd)
org_module._lora_restored = True
def pre_calculation(self):
# 事前計算を行う
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
oft.merge_to()
# sd = org_module.state_dict()
# org_weight = sd["weight"]
# lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype)
# sd["weight"] = org_weight + lora_weight
# assert sd["weight"].shape == org_weight.shape
# org_module.load_state_dict(sd)
org_module._lora_restored = False
oft.enabled = False

View File

@@ -19,8 +19,14 @@ huggingface-hub==0.15.1
# requests==2.28.2
# timm==0.6.12
# fairscale==0.4.13
# for WD14 captioning
# for WD14 captioning (tensorflow)
# tensorflow==2.10.1
# for WD14 captioning (onnx)
# onnx==1.14.1
# onnxruntime-gpu==1.16.0
# onnxruntime==1.16.0
# this is for onnx:
# protobuf==3.20.3
# open clip for SDXL
open-clip-torch==2.20.0
# for kohya_ss library

View File

@@ -17,10 +17,13 @@ import re
import diffusers
import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
@@ -1534,12 +1537,20 @@ def main(args):
network_default_muls = []
network_pre_calc = args.network_pre_calc
# merge関連の引数を統合する
if args.network_merge:
network_merge = len(args.network_module) # all networks are merged
elif args.network_merge_n_models:
network_merge = args.network_merge_n_models
else:
network_merge = 0
print(f"network_merge: {network_merge}")
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_default_muls.append(network_mul)
net_kwargs = {}
if args.network_args and i < len(args.network_args):
@@ -1550,31 +1561,32 @@ def main(args):
key, value = net_arg.split("=")
net_kwargs[key] = value
if args.network_weights and i < len(args.network_weights):
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
)
else:
if args.network_weights is None or len(args.network_weights) <= i:
raise ValueError("No weight. Weight is required.")
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
)
if network is None:
return
mergeable = network.is_mergeable()
if args.network_merge and not mergeable:
if network_merge and not mergeable:
print("network is not mergiable. ignore merge option.")
if not args.network_merge or not mergeable:
if not mergeable or i >= network_merge:
# not merging
network.apply_to([text_encoder1, text_encoder2], unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}")
@@ -1588,6 +1600,7 @@ def main(args):
network.backup_weights()
networks.append(network)
network_default_muls.append(network_mul)
else:
network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device)
@@ -1864,9 +1877,18 @@ def main(args):
size = None
for i, network in enumerate(networks):
if i < 3:
if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes:
np_mask = np.array(mask_images[0])
np_mask = np_mask[:, :, i]
if args.network_regional_mask_max_color_codes:
# カラーコードでマスクを指定する
ch0 = (i + 1) & 1
ch1 = ((i + 1) >> 1) & 1
ch2 = ((i + 1) >> 2) & 1
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
np_mask = np_mask.astype(np.uint8) * 255
else:
np_mask = np_mask[:, :, i]
size = np_mask.shape
else:
np_mask = np.full(size, 255, dtype=np.uint8)
@@ -2615,10 +2637,19 @@ def setup_parser() -> argparse.ArgumentParser:
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
)
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument(
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
)
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument(
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
)
parser.add_argument(
"--network_regional_mask_max_color_codes",
type=int,
default=None,
help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数デフォルトはNoneでチャンネルごとのマスク",
)
parser.add_argument(
"--textual_inversion_embeddings",
type=str,

View File

@@ -283,7 +283,10 @@ class NetworkTrainer:
if args.dim_from_weights:
network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
else:
# LyCORIS will work with this...
if "dropout" not in net_kwargs:
# workaround for LyCORIS (;^ω^)
net_kwargs["dropout"] = args.network_dropout
network = network_module.create_network(
1.0,
args.network_dim,