From 70fe7e18bea63bb2ddc3c8dfdb3a2367d55cb348 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 8 Oct 2023 20:31:10 +0800 Subject: [PATCH 1/3] add onnx to wd14 tagger --- finetune/tag_images_by_wd14_tagger.py | 55 +++++++++++++++++++++------ requirements.txt | 4 +- 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 91e4f573..816aaddb 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -2,16 +2,15 @@ import argparse import csv import glob import os - -from PIL import Image -import cv2 -from tqdm import tqdm -import numpy as np -from tensorflow.keras.models import load_model -from huggingface_hub import hf_hub_download -import torch from pathlib import Path +import cv2 +import numpy as np +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +from tqdm import tqdm + import library.train_util as train_util # from wd14 tagger @@ -81,6 +80,8 @@ def main(args): # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 if not os.path.exists(args.model_dir) or args.force_download: print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") + if args.onnx: + FILES.append("model.onnx") for file in FILES: hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) for file in SUB_DIR_FILES: @@ -96,7 +97,35 @@ def main(args): print("using existing wd14 tagger model") # 画像を読み込む - model = load_model(args.model_dir) + if args.onnx: + import onnx + import onnxruntime as ort + + onnx_path = f"{args.model_dir}/model.onnx" + print("Running wd14 tagger with onnx") + print(f"loading onnx model: {onnx_path}") + model = onnx.load(onnx_path) + input_name = model.graph.input[0].name + try: + batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value + except: + batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param + if args.batch_size != batch_size and type(batch_size) != str: + # some rebatch model may use 'N' as dynamic axes + print( + f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}" + ) + args.batch_size = batch_size + ort_sess = ort.InferenceSession( + model.SerializeToString(), + providers=["CUDAExecutionProvider"] + if "CUDAExecutionProvider" in ort.get_available_providers() + else ["CPUExecutionProvider"], + ) + else: + from tensorflow.keras.models import load_model + + model = load_model(f"{args.model_dir}") # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") # 依存ライブラリを増やしたくないので自力で読むよ @@ -124,8 +153,11 @@ def main(args): def run_batch(path_imgs): imgs = np.array([im for _, im in path_imgs]) - probs = model(imgs, training=False) - probs = probs.numpy() + if args.onnx: + probs = ort_sess.run(None, {input_name: imgs}) # onnx output numpy + else: + probs = model(imgs, training=False) + probs = probs.numpy() for (image_path, _), prob in zip(path_imgs, probs): # 最初の4つはratingなので無視する @@ -283,6 +315,7 @@ def setup_parser() -> argparse.ArgumentParser: help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト", ) parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") + parser.add_argument("--onnx", action="store_true", help="use onnx model for inference") return parser diff --git a/requirements.txt b/requirements.txt index 4ca393f5..fa6005ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,8 +19,10 @@ huggingface-hub==0.15.1 # requests==2.28.2 # timm==0.6.12 # fairscale==0.4.13 -# for WD14 captioning +# for WD14 captioning (tensroflow or onnx) # tensorflow==2.10.1 +# onnx==1.14.1 +# onnxruntime==1.16.0 # open clip for SDXL open-clip-torch==2.20.0 # for kohya_ss library From b8b84021e54b34ed04800e21a18fc67e6e9ce1c1 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 8 Oct 2023 20:49:03 +0800 Subject: [PATCH 2/3] fix a typo --- finetune/tag_images_by_wd14_tagger.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 816aaddb..6b33af51 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -117,7 +117,7 @@ def main(args): ) args.batch_size = batch_size ort_sess = ort.InferenceSession( - model.SerializeToString(), + onnx_path, providers=["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"], @@ -154,7 +154,7 @@ def main(args): imgs = np.array([im for _, im in path_imgs]) if args.onnx: - probs = ort_sess.run(None, {input_name: imgs}) # onnx output numpy + probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy else: probs = model(imgs, training=False) probs = probs.numpy() From d6f458fcb3cda470486a9d0ea3a2dad0c72b46db Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 8 Oct 2023 23:51:18 +0800 Subject: [PATCH 3/3] fix dependency --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index fa6005ac..75de48cb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,6 +22,7 @@ huggingface-hub==0.15.1 # for WD14 captioning (tensroflow or onnx) # tensorflow==2.10.1 # onnx==1.14.1 +# onnxruntime-gpu==1.16.0 # onnxruntime==1.16.0 # open clip for SDXL open-clip-torch==2.20.0