mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
fix: metadata dataset degradation and make it work (#2186)
* fix: support dataset with metadata * feat: support another tagger model * fix: improve handling of image size and caption/tag processing in FineTuningDataset * fix: enhance metadata loading to support JSONL format in FineTuningDataset * feat: enhance image loading and processing in ImageLoadingPrepDataset with batch support and output options * fix: improve image path handling and memory management in dataset classes * Update finetune/tag_images_by_wd14_tagger.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: add return type annotation for process_tag_replacement function and ensure tags are returned * feat: add artist category threshold for tagging * doc: add comment for clarification --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -5,9 +5,11 @@ This document is based on the information from this github page (https://github.
|
||||
Using onnx for inference is recommended. Please install onnx with the following command:
|
||||
|
||||
```powershell
|
||||
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
|
||||
pip install onnx onnxruntime-gpu
|
||||
```
|
||||
|
||||
See [the official documentation](https://onnxruntime.ai/docs/install/#python-installs) for more details.
|
||||
|
||||
The model weights will be automatically downloaded from Hugging Face.
|
||||
|
||||
# Usage
|
||||
@@ -49,6 +51,8 @@ python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagge
|
||||
|
||||
# Options
|
||||
|
||||
All options can be checked with `python tag_images_by_wd14_tagger.py --help`.
|
||||
|
||||
## General Options
|
||||
|
||||
- `--onnx`: Use ONNX for inference. If not specified, TensorFlow will be used. If using TensorFlow, please install TensorFlow separately.
|
||||
|
||||
@@ -5,9 +5,11 @@
|
||||
onnx を用いた推論を推奨します。以下のコマンドで onnx をインストールしてください。
|
||||
|
||||
```powershell
|
||||
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
|
||||
pip install onnx onnxruntime-gpu
|
||||
```
|
||||
|
||||
詳細は[公式ドキュメント](https://onnxruntime.ai/docs/install/#python-installs)をご覧ください。
|
||||
|
||||
モデルの重みはHugging Faceから自動的にダウンロードしてきます。
|
||||
|
||||
# 使い方
|
||||
@@ -48,6 +50,8 @@ python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagge
|
||||
|
||||
# オプション
|
||||
|
||||
全てオプションは `python tag_images_by_wd14_tagger.py --help` で確認できます。
|
||||
|
||||
## 一般オプション
|
||||
|
||||
- `--onnx` : ONNX を使用して推論します。指定しない場合は TensorFlow を使用します。TensorFlow 使用時は別途 TensorFlow をインストールしてください。
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -29,8 +31,22 @@ SUB_DIR = "variables"
|
||||
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
|
||||
CSV_FILE = FILES[-1]
|
||||
|
||||
TAG_JSON_FILE = "tag_mapping.json"
|
||||
|
||||
|
||||
def preprocess_image(image: Image.Image) -> np.ndarray:
|
||||
# If image has transparency, convert to RGBA. If not, convert to RGB
|
||||
if image.mode in ("RGBA", "LA") or "transparency" in image.info:
|
||||
image = image.convert("RGBA")
|
||||
elif image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
# If image is RGBA, combine with white background
|
||||
if image.mode == "RGBA":
|
||||
background = Image.new("RGB", image.size, (255, 255, 255))
|
||||
background.paste(image, mask=image.split()[3]) # Use alpha channel as mask
|
||||
image = background
|
||||
|
||||
def preprocess_image(image):
|
||||
image = np.array(image)
|
||||
image = image[:, :, ::-1] # RGB->BGR
|
||||
|
||||
@@ -49,67 +65,103 @@ def preprocess_image(image):
|
||||
|
||||
|
||||
class ImageLoadingPrepDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, image_paths):
|
||||
self.images = image_paths
|
||||
def __init__(self, image_paths: list[str], batch_size: int):
|
||||
self.image_paths = image_paths
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
return math.ceil(len(self.image_paths) / self.batch_size)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img_path = str(self.images[idx])
|
||||
def __getitem__(self, batch_index: int) -> tuple[str, np.ndarray, tuple[int, int]]:
|
||||
image_index_start = batch_index * self.batch_size
|
||||
image_index_end = min((batch_index + 1) * self.batch_size, len(self.image_paths))
|
||||
|
||||
try:
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
image = preprocess_image(image)
|
||||
# tensor = torch.tensor(image) # これ Tensor に変換する必要ないな……(;・∀・)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
return None
|
||||
batch_image_paths = []
|
||||
images = []
|
||||
image_sizes = []
|
||||
for idx in range(image_index_start, image_index_end):
|
||||
img_path = str(self.image_paths[idx])
|
||||
|
||||
return (image, img_path)
|
||||
try:
|
||||
image = Image.open(img_path)
|
||||
image_size = image.size
|
||||
image = preprocess_image(image)
|
||||
|
||||
batch_image_paths.append(img_path)
|
||||
images.append(image)
|
||||
image_sizes.append(image_size)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
|
||||
images = np.stack(images) if len(images) > 0 else np.zeros((0, IMAGE_SIZE, IMAGE_SIZE, 3))
|
||||
return batch_image_paths, images, image_sizes
|
||||
|
||||
|
||||
def collate_fn_remove_corrupted(batch):
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
def collate_fn_no_op(batch):
|
||||
"""Collate function that does nothing and returns the batch as is."""
|
||||
return batch
|
||||
|
||||
|
||||
def main(args):
|
||||
# model location is model_dir + repo_id
|
||||
# repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash
|
||||
model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_"))
|
||||
# given repo_id may be "namespace/repo_name" or "namespace/repo_name/subdir"
|
||||
# so we split it to "namespace/reponame" and "subdir"
|
||||
tokens = args.repo_id.split("/")
|
||||
|
||||
if len(tokens) > 2:
|
||||
repo_id = "/".join(tokens[:2])
|
||||
subdir = "/".join(tokens[2:])
|
||||
model_location = os.path.join(args.model_dir, repo_id.replace("/", "_"), subdir)
|
||||
onnx_model_name = "model_optimized.onnx"
|
||||
default_format = False
|
||||
else:
|
||||
repo_id = args.repo_id
|
||||
subdir = None
|
||||
model_location = os.path.join(args.model_dir, repo_id.replace("/", "_"))
|
||||
onnx_model_name = "model.onnx"
|
||||
default_format = True
|
||||
|
||||
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
|
||||
# depreacatedの警告が出るけどなくなったらその時
|
||||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
||||
|
||||
if not os.path.exists(model_location) or args.force_download:
|
||||
os.makedirs(args.model_dir, exist_ok=True)
|
||||
logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
||||
files = FILES
|
||||
if args.onnx:
|
||||
files = ["selected_tags.csv"]
|
||||
files += FILES_ONNX
|
||||
else:
|
||||
for file in SUB_DIR_FILES:
|
||||
|
||||
if subdir is None:
|
||||
# SmilingWolf structure
|
||||
files = FILES
|
||||
if args.onnx:
|
||||
files = ["selected_tags.csv"]
|
||||
files += FILES_ONNX
|
||||
else:
|
||||
for file in SUB_DIR_FILES:
|
||||
hf_hub_download(
|
||||
repo_id=args.repo_id,
|
||||
filename=file,
|
||||
subfolder=SUB_DIR,
|
||||
local_dir=os.path.join(model_location, SUB_DIR),
|
||||
force_download=True,
|
||||
)
|
||||
|
||||
for file in files:
|
||||
hf_hub_download(
|
||||
repo_id=args.repo_id,
|
||||
filename=file,
|
||||
subfolder=SUB_DIR,
|
||||
local_dir=os.path.join(model_location, SUB_DIR),
|
||||
local_dir=model_location,
|
||||
force_download=True,
|
||||
)
|
||||
else:
|
||||
# another structure
|
||||
files = [onnx_model_name, "tag_mapping.json"]
|
||||
|
||||
for file in files:
|
||||
hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename=file,
|
||||
subfolder=subdir,
|
||||
local_dir=os.path.join(args.model_dir, repo_id.replace("/", "_")), # because subdir is specified
|
||||
force_download=True,
|
||||
)
|
||||
for file in files:
|
||||
hf_hub_download(
|
||||
repo_id=args.repo_id,
|
||||
filename=file,
|
||||
local_dir=model_location,
|
||||
force_download=True,
|
||||
)
|
||||
else:
|
||||
logger.info("using existing wd14 tagger model")
|
||||
|
||||
@@ -118,7 +170,7 @@ def main(args):
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
|
||||
onnx_path = f"{model_location}/model.onnx"
|
||||
onnx_path = os.path.join(model_location, onnx_model_name)
|
||||
logger.info("Running wd14 tagger with onnx")
|
||||
logger.info(f"loading onnx model: {onnx_path}")
|
||||
|
||||
@@ -150,39 +202,30 @@ def main(args):
|
||||
ort_sess = ort.InferenceSession(
|
||||
onnx_path,
|
||||
providers=(["OpenVINOExecutionProvider"]),
|
||||
provider_options=[{'device_type' : "GPU", "precision": "FP32"}],
|
||||
provider_options=[{"device_type": "GPU", "precision": "FP32"}],
|
||||
)
|
||||
else:
|
||||
ort_sess = ort.InferenceSession(
|
||||
onnx_path,
|
||||
providers=(
|
||||
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
|
||||
["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
|
||||
["CPUExecutionProvider"]
|
||||
),
|
||||
providers = (
|
||||
["CUDAExecutionProvider"]
|
||||
if "CUDAExecutionProvider" in ort.get_available_providers()
|
||||
else (
|
||||
["ROCMExecutionProvider"]
|
||||
if "ROCMExecutionProvider" in ort.get_available_providers()
|
||||
else ["CPUExecutionProvider"]
|
||||
)
|
||||
)
|
||||
logger.info(f"Using onnxruntime providers: {providers}")
|
||||
ort_sess = ort.InferenceSession(onnx_path, providers=providers)
|
||||
else:
|
||||
from tensorflow.keras.models import load_model
|
||||
|
||||
model = load_model(f"{model_location}")
|
||||
|
||||
# We read the CSV file manually to avoid adding dependencies.
|
||||
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
|
||||
# 依存ライブラリを増やしたくないので自力で読むよ
|
||||
|
||||
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
|
||||
reader = csv.reader(f)
|
||||
line = [row for row in reader]
|
||||
header = line[0] # tag_id,name,category,count
|
||||
rows = line[1:]
|
||||
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
|
||||
|
||||
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
|
||||
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
|
||||
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]
|
||||
|
||||
# preprocess tags in advance
|
||||
if args.character_tag_expand:
|
||||
for i, tag in enumerate(character_tags):
|
||||
def expand_character_tags(char_tags):
|
||||
for i, tag in enumerate(char_tags):
|
||||
if tag.endswith(")"):
|
||||
# chara_name_(series) -> chara_name, series
|
||||
# chara_name_(costume)_(series) -> chara_name_(costume), series
|
||||
@@ -191,35 +234,95 @@ def main(args):
|
||||
if character_tag.endswith("_"):
|
||||
character_tag = character_tag[:-1]
|
||||
series_tag = tags[-1].replace(")", "")
|
||||
character_tags[i] = character_tag + args.caption_separator + series_tag
|
||||
char_tags[i] = character_tag + args.caption_separator + series_tag
|
||||
|
||||
if args.remove_underscore:
|
||||
rating_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in rating_tags]
|
||||
general_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in general_tags]
|
||||
character_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in character_tags]
|
||||
def remove_underscore(tags):
|
||||
return [tag.replace("_", " ") if len(tag) > 3 else tag for tag in tags]
|
||||
|
||||
if args.tag_replacement is not None:
|
||||
# escape , and ; in tag_replacement: wd14 tag names may contain , and ;
|
||||
escaped_tag_replacements = args.tag_replacement.replace("\\,", "@@@@").replace("\\;", "####")
|
||||
def process_tag_replacement(tags: list[str], tag_replacements_arg: str) -> list[str]:
|
||||
# escape , and ; in tag_replacement: wd14 tag names may contain , and ;,
|
||||
# so user must be specified them like `aa\,bb,AA\,BB;cc\;dd,CC\;DD` which means
|
||||
# `aa,bb` is replaced with `AA,BB` and `cc;dd` is replaced with `CC;DD`
|
||||
escaped_tag_replacements = tag_replacements_arg.replace("\\,", "@@@@").replace("\\;", "####")
|
||||
tag_replacements = escaped_tag_replacements.split(";")
|
||||
for tag_replacement in tag_replacements:
|
||||
tags = tag_replacement.split(",") # source, target
|
||||
assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
|
||||
|
||||
for tag_replacements_arg in tag_replacements:
|
||||
tags = tag_replacements_arg.split(",") # source, target
|
||||
assert (
|
||||
len(tags) == 2
|
||||
), f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
|
||||
|
||||
source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags]
|
||||
logger.info(f"replacing tag: {source} -> {target}")
|
||||
|
||||
if source in general_tags:
|
||||
general_tags[general_tags.index(source)] = target
|
||||
elif source in character_tags:
|
||||
character_tags[character_tags.index(source)] = target
|
||||
elif source in rating_tags:
|
||||
rating_tags[rating_tags.index(source)] = target
|
||||
if source in tags:
|
||||
tags[tags.index(source)] = target
|
||||
|
||||
return tags
|
||||
|
||||
if default_format:
|
||||
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
|
||||
reader = csv.reader(f)
|
||||
line = [row for row in reader]
|
||||
header = line[0] # tag_id,name,category,count
|
||||
rows = line[1:]
|
||||
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
|
||||
|
||||
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
|
||||
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
|
||||
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]
|
||||
|
||||
if args.character_tag_expand:
|
||||
expand_character_tags(character_tags)
|
||||
if args.remove_underscore:
|
||||
rating_tags = remove_underscore(rating_tags)
|
||||
character_tags = remove_underscore(character_tags)
|
||||
general_tags = remove_underscore(general_tags)
|
||||
if args.tag_replacement is not None:
|
||||
process_tag_replacement(rating_tags, args.tag_replacement)
|
||||
process_tag_replacement(general_tags, args.tag_replacement)
|
||||
process_tag_replacement(character_tags, args.tag_replacement)
|
||||
else:
|
||||
with open(os.path.join(model_location, TAG_JSON_FILE), "r", encoding="utf-8") as f:
|
||||
tag_mapping = json.load(f)
|
||||
|
||||
rating_tags = []
|
||||
general_tags = []
|
||||
character_tags = []
|
||||
|
||||
tag_id_to_tag_mapping = {}
|
||||
tag_id_to_category_mapping = {}
|
||||
for tag_id, tag_info in tag_mapping.items():
|
||||
tag = tag_info["tag"]
|
||||
category = tag_info["category"]
|
||||
assert category in [
|
||||
"Rating",
|
||||
"General",
|
||||
"Character",
|
||||
"Copyright",
|
||||
"Meta",
|
||||
"Model",
|
||||
"Quality",
|
||||
"Artist",
|
||||
], f"unexpected category: {category}"
|
||||
|
||||
if args.remove_underscore:
|
||||
tag = remove_underscore([tag])[0]
|
||||
if args.tag_replacement is not None:
|
||||
tag = process_tag_replacement([tag], args.tag_replacement)[0]
|
||||
if category == "Character" and args.character_tag_expand:
|
||||
tag_list = [tag]
|
||||
expand_character_tags(tag_list)
|
||||
tag = tag_list[0]
|
||||
|
||||
tag_id_to_tag_mapping[int(tag_id)] = tag
|
||||
tag_id_to_category_mapping[int(tag_id)] = category
|
||||
|
||||
# 画像を読み込む
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
logger.info(f"found {len(image_paths)} images.")
|
||||
image_paths = [str(ip) for ip in image_paths]
|
||||
|
||||
tag_freq = {}
|
||||
|
||||
@@ -232,59 +335,150 @@ def main(args):
|
||||
if args.always_first_tags is not None:
|
||||
always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""]
|
||||
|
||||
def run_batch(path_imgs):
|
||||
imgs = np.array([im for _, im in path_imgs])
|
||||
def run_batch(path_imgs: tuple[list[str], np.ndarray, list[tuple[int, int]]]) -> Optional[dict[str, dict]]:
|
||||
nonlocal args, default_format, model, ort_sess, input_name, tag_freq
|
||||
|
||||
imgs = path_imgs[1]
|
||||
result = {}
|
||||
|
||||
if args.onnx:
|
||||
# if len(imgs) < args.batch_size:
|
||||
# imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
|
||||
if not default_format:
|
||||
imgs = imgs.transpose(0, 3, 1, 2) # to NCHW
|
||||
imgs = imgs / 127.5 - 1.0
|
||||
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
|
||||
probs = probs[: len(path_imgs)]
|
||||
probs = probs[: len(imgs)] # remove padding
|
||||
else:
|
||||
probs = model(imgs, training=False)
|
||||
probs = probs.numpy()
|
||||
|
||||
for (image_path, _), prob in zip(path_imgs, probs):
|
||||
for image_path, image_size, prob in zip(path_imgs[0], path_imgs[2], probs):
|
||||
combined_tags = []
|
||||
rating_tag_text = ""
|
||||
character_tag_text = ""
|
||||
general_tag_text = ""
|
||||
other_tag_text = ""
|
||||
|
||||
# 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する
|
||||
# First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold
|
||||
for i, p in enumerate(prob[4:]):
|
||||
if i < len(general_tags) and p >= args.general_threshold:
|
||||
tag_name = general_tags[i]
|
||||
if default_format:
|
||||
# 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する
|
||||
# First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold
|
||||
for i, p in enumerate(prob[4:]):
|
||||
if i < len(general_tags) and p >= args.general_threshold:
|
||||
tag_name = general_tags[i]
|
||||
|
||||
if tag_name not in undesired_tags:
|
||||
if tag_name not in undesired_tags:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
general_tag_text += caption_separator + tag_name
|
||||
combined_tags.append(tag_name)
|
||||
elif i >= len(general_tags) and p >= args.character_threshold:
|
||||
tag_name = character_tags[i - len(general_tags)]
|
||||
|
||||
if tag_name not in undesired_tags:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
character_tag_text += caption_separator + tag_name
|
||||
if args.character_tags_first: # insert to the beginning
|
||||
combined_tags.insert(0, tag_name)
|
||||
else:
|
||||
combined_tags.append(tag_name)
|
||||
|
||||
# 最初の4つはratingなのでargmaxで選ぶ
|
||||
# First 4 labels are actually ratings: pick one with argmax
|
||||
if args.use_rating_tags or args.use_rating_tags_as_last_tag:
|
||||
ratings_probs = prob[:4]
|
||||
rating_index = ratings_probs.argmax()
|
||||
found_rating = rating_tags[rating_index]
|
||||
|
||||
if found_rating not in undesired_tags:
|
||||
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
|
||||
rating_tag_text = found_rating
|
||||
if args.use_rating_tags:
|
||||
combined_tags.insert(0, found_rating) # insert to the beginning
|
||||
else:
|
||||
combined_tags.append(found_rating)
|
||||
else:
|
||||
# apply sigmoid to probabilities
|
||||
prob = 1 / (1 + np.exp(-prob))
|
||||
|
||||
rating_max_prob = -1
|
||||
rating_tag = None
|
||||
quality_max_prob = -1
|
||||
quality_tag = None
|
||||
character_tags = []
|
||||
|
||||
min_thres = min(
|
||||
args.thresh,
|
||||
args.general_threshold,
|
||||
args.character_threshold,
|
||||
args.copyright_threshold,
|
||||
args.meta_threshold,
|
||||
args.model_threshold,
|
||||
args.artist_threshold,
|
||||
)
|
||||
prob_indices = np.where(prob >= min_thres)[0]
|
||||
# for i, p in enumerate(prob):
|
||||
for i in prob_indices:
|
||||
if i not in tag_id_to_tag_mapping:
|
||||
continue
|
||||
p = prob[i]
|
||||
|
||||
tag_name = tag_id_to_tag_mapping[i]
|
||||
category = tag_id_to_category_mapping[i]
|
||||
if tag_name in undesired_tags:
|
||||
continue
|
||||
|
||||
if category == "Rating":
|
||||
if p > rating_max_prob:
|
||||
rating_max_prob = p
|
||||
rating_tag = tag_name
|
||||
rating_tag_text = tag_name
|
||||
continue
|
||||
elif category == "Quality":
|
||||
if p > quality_max_prob:
|
||||
quality_max_prob = p
|
||||
quality_tag = tag_name
|
||||
if args.use_quality_tags or args.use_quality_tags_as_last_tag:
|
||||
other_tag_text += caption_separator + tag_name
|
||||
continue
|
||||
|
||||
if category == "General" and p >= args.general_threshold:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
general_tag_text += caption_separator + tag_name
|
||||
combined_tags.append(tag_name)
|
||||
elif i >= len(general_tags) and p >= args.character_threshold:
|
||||
tag_name = character_tags[i - len(general_tags)]
|
||||
|
||||
if tag_name not in undesired_tags:
|
||||
combined_tags.append((tag_name, p))
|
||||
elif category == "Character" and p >= args.character_threshold:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
character_tag_text += caption_separator + tag_name
|
||||
if args.character_tags_first: # insert to the beginning
|
||||
combined_tags.insert(0, tag_name)
|
||||
if args.character_tags_first: # we separate character tags
|
||||
character_tags.append((tag_name, p))
|
||||
else:
|
||||
combined_tags.append(tag_name)
|
||||
combined_tags.append((tag_name, p))
|
||||
elif (
|
||||
(category == "Copyright" and p >= args.copyright_threshold)
|
||||
or (category == "Meta" and p >= args.meta_threshold)
|
||||
or (category == "Model" and p >= args.model_threshold)
|
||||
or (category == "Artist" and p >= args.artist_threshold)
|
||||
):
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
other_tag_text += f"{caption_separator}{tag_name} ({category})"
|
||||
combined_tags.append((tag_name, p))
|
||||
|
||||
# 最初の4つはratingなのでargmaxで選ぶ
|
||||
# First 4 labels are actually ratings: pick one with argmax
|
||||
if args.use_rating_tags or args.use_rating_tags_as_last_tag:
|
||||
ratings_probs = prob[:4]
|
||||
rating_index = ratings_probs.argmax()
|
||||
found_rating = rating_tags[rating_index]
|
||||
# sort by probability
|
||||
combined_tags.sort(key=lambda x: x[1], reverse=True)
|
||||
if character_tags:
|
||||
character_tags.sort(key=lambda x: x[1], reverse=True)
|
||||
combined_tags = character_tags + combined_tags
|
||||
combined_tags = [t[0] for t in combined_tags] # remove probability
|
||||
|
||||
if found_rating not in undesired_tags:
|
||||
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
|
||||
rating_tag_text = found_rating
|
||||
if args.use_rating_tags:
|
||||
combined_tags.insert(0, found_rating) # insert to the beginning
|
||||
else:
|
||||
combined_tags.append(found_rating)
|
||||
if quality_tag is not None:
|
||||
if args.use_quality_tags_as_last_tag:
|
||||
combined_tags.append(quality_tag)
|
||||
elif args.use_quality_tags:
|
||||
combined_tags.insert(0, quality_tag)
|
||||
if rating_tag is not None:
|
||||
if args.use_rating_tags_as_last_tag:
|
||||
combined_tags.append(rating_tag)
|
||||
elif args.use_rating_tags:
|
||||
combined_tags.insert(0, rating_tag)
|
||||
|
||||
# 一番最初に置くタグを指定する
|
||||
# Always put some tags at the beginning
|
||||
@@ -299,6 +493,8 @@ def main(args):
|
||||
general_tag_text = general_tag_text[len(caption_separator) :]
|
||||
if len(character_tag_text) > 0:
|
||||
character_tag_text = character_tag_text[len(caption_separator) :]
|
||||
if len(other_tag_text) > 0:
|
||||
other_tag_text = other_tag_text[len(caption_separator) :]
|
||||
|
||||
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
|
||||
|
||||
@@ -320,55 +516,79 @@ def main(args):
|
||||
# Create new tag_text
|
||||
tag_text = caption_separator.join(existing_tags + new_tags)
|
||||
|
||||
with open(caption_file, "wt", encoding="utf-8") as f:
|
||||
f.write(tag_text + "\n")
|
||||
if args.debug:
|
||||
logger.info("")
|
||||
logger.info(f"{image_path}:")
|
||||
logger.info(f"\tRating tags: {rating_tag_text}")
|
||||
logger.info(f"\tCharacter tags: {character_tag_text}")
|
||||
logger.info(f"\tGeneral tags: {general_tag_text}")
|
||||
if not args.output_path:
|
||||
with open(caption_file, "wt", encoding="utf-8") as f:
|
||||
f.write(tag_text + "\n")
|
||||
else:
|
||||
entry = {"tags": tag_text, "image_size": list(image_size)}
|
||||
result[image_path] = entry
|
||||
|
||||
if args.debug:
|
||||
logger.info("")
|
||||
logger.info(f"{image_path}:")
|
||||
logger.info(f"\tRating tags: {rating_tag_text}")
|
||||
logger.info(f"\tCharacter tags: {character_tag_text}")
|
||||
logger.info(f"\tGeneral tags: {general_tag_text}")
|
||||
if other_tag_text:
|
||||
logger.info(f"\tOther tags: {other_tag_text}")
|
||||
|
||||
return result
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = ImageLoadingPrepDataset(image_paths)
|
||||
dataset = ImageLoadingPrepDataset(image_paths, args.batch_size)
|
||||
data = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
collate_fn=collate_fn_no_op,
|
||||
drop_last=False,
|
||||
)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
# data = [[(ip, None, None)] for ip in image_paths]
|
||||
data = [[]]
|
||||
for ip in image_paths:
|
||||
if len(data[-1]) >= args.batch_size:
|
||||
data.append([])
|
||||
data[-1].append((ip, None, None))
|
||||
|
||||
b_imgs = []
|
||||
results = {}
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
for data in data_entry:
|
||||
if data is None:
|
||||
continue
|
||||
if data_entry is None or len(data_entry) == 0:
|
||||
continue
|
||||
|
||||
image, image_path = data
|
||||
if image is None:
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
image = preprocess_image(image)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
b_imgs.append((image_path, image))
|
||||
if data_entry[0][1] is None:
|
||||
# No preloaded image, need to load
|
||||
images = []
|
||||
image_sizes = []
|
||||
for image_path, _, _ in data_entry:
|
||||
image = Image.open(image_path)
|
||||
image_size = image.size
|
||||
image = preprocess_image(image)
|
||||
images.append(image)
|
||||
image_sizes.append(image_size)
|
||||
b_imgs = ([ip for ip, _, _ in data_entry], np.stack(images), image_sizes)
|
||||
else:
|
||||
b_imgs = data_entry[0]
|
||||
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
r = run_batch(b_imgs)
|
||||
if args.output_path and r is not None:
|
||||
results.update(r)
|
||||
|
||||
if len(b_imgs) > 0:
|
||||
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
|
||||
run_batch(b_imgs)
|
||||
if args.output_path:
|
||||
if args.output_path.endswith(".jsonl"):
|
||||
# optional JSONL metadata
|
||||
with open(args.output_path, "wt", encoding="utf-8") as f:
|
||||
for image_path, entry in results.items():
|
||||
f.write(
|
||||
json.dumps({"image_path": image_path, "caption": entry["tags"], "image_size": entry["image_size"]}) + "\n"
|
||||
)
|
||||
else:
|
||||
# standard JSON metadata
|
||||
with open(args.output_path, "wt", encoding="utf-8") as f:
|
||||
json.dump(results, f, ensure_ascii=False, indent=4)
|
||||
logger.info(f"captions saved to {args.output_path}")
|
||||
|
||||
if args.frequency_tags:
|
||||
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
|
||||
@@ -381,9 +601,7 @@ def main(args):
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
|
||||
)
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument(
|
||||
"--repo_id",
|
||||
type=str,
|
||||
@@ -401,15 +619,19 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--max_data_loader_n_workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path for output captions (json format). if this is set, captions will be saved to this file / 出力キャプションのパス(json形式)。このオプションが設定されている場合、キャプションはこのファイルに保存されます",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_extention",
|
||||
type=str,
|
||||
@@ -432,7 +654,36 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--character_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ",
|
||||
help="threshold of confidence to add a tag for character category, same as --thres if omitted. set above 1 to disable character tags"
|
||||
" / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとcharacterタグを無効化できる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--meta_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for meta category, same as --thresh if omitted. set above 1 to disable meta tags"
|
||||
" / metaカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとmetaタグを無効化できる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for model category, same as --thresh if omitted. set above 1 to disable model tags"
|
||||
" / modelカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとmodelタグを無効化できる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--copyright_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for copyright category, same as --thresh if omitted. set above 1 to disable copyright tags"
|
||||
" / copyrightカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとcopyrightタグを無効化できる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--artist_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for artist category, same as --thresh if omitted. set above 1 to disable artist tags"
|
||||
" / artistカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとartistタグを無効化できる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する"
|
||||
@@ -442,9 +693,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug", action="store_true", help="debug mode"
|
||||
)
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
parser.add_argument(
|
||||
"--undesired_tags",
|
||||
type=str,
|
||||
@@ -454,20 +703,34 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
|
||||
)
|
||||
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
|
||||
parser.add_argument(
|
||||
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
|
||||
"--use_rating_tags",
|
||||
action="store_true",
|
||||
help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
|
||||
"--use_rating_tags_as_last_tag",
|
||||
action="store_true",
|
||||
help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
|
||||
"--use_quality_tags",
|
||||
action="store_true",
|
||||
help="Adds quality tags as the first tag / クオリティタグを最初のタグとして追加する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_quality_tags_as_last_tag",
|
||||
action="store_true",
|
||||
help="Adds quality tags as the last tag / クオリティタグを最後のタグとして追加する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--character_tags_first",
|
||||
action="store_true",
|
||||
help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--always_first_tags",
|
||||
@@ -512,5 +775,13 @@ if __name__ == "__main__":
|
||||
args.general_threshold = args.thresh
|
||||
if args.character_threshold is None:
|
||||
args.character_threshold = args.thresh
|
||||
if args.meta_threshold is None:
|
||||
args.meta_threshold = args.thresh
|
||||
if args.model_threshold is None:
|
||||
args.model_threshold = args.thresh
|
||||
if args.copyright_threshold is None:
|
||||
args.copyright_threshold = args.thresh
|
||||
if args.artist_threshold is None:
|
||||
args.artist_threshold = args.thresh
|
||||
|
||||
main(args)
|
||||
|
||||
@@ -1131,7 +1131,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
self.reso == other.reso
|
||||
other is not None
|
||||
and self.reso == other.reso
|
||||
and self.flip_aug == other.flip_aug
|
||||
and self.alpha_mask == other.alpha_mask
|
||||
and self.random_crop == other.random_crop
|
||||
@@ -1193,6 +1194,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if len(batch) > 0 and current_condition != condition:
|
||||
submit_batch(batch, current_condition)
|
||||
batch = []
|
||||
if condition != current_condition and HIGH_VRAM: # even with high VRAM, if shape is changed
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
if info.image is None:
|
||||
# load image in parallel
|
||||
@@ -1205,7 +1208,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if len(batch) >= caching_strategy.batch_size:
|
||||
submit_batch(batch, current_condition)
|
||||
batch = []
|
||||
current_condition = None
|
||||
# current_condition = None # keep current_condition to avoid next `clean_memory_on_device` call
|
||||
|
||||
if len(batch) > 0:
|
||||
submit_batch(batch, current_condition)
|
||||
@@ -1768,14 +1771,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
tensors = [converter(x) for x in tensors]
|
||||
if tensors[0].ndim == 1:
|
||||
# input_ids or mask
|
||||
result.append(
|
||||
torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors])
|
||||
)
|
||||
result.append(torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors]))
|
||||
else:
|
||||
# text encoder outputs
|
||||
result.append(
|
||||
torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors])
|
||||
)
|
||||
result.append(torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors]))
|
||||
return result
|
||||
|
||||
# set example
|
||||
@@ -2202,6 +2201,23 @@ class FineTuningDataset(BaseDataset):
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.size = min(self.width, self.height) # 短いほう
|
||||
self.latents_cache = None
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
|
||||
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
|
||||
)
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
self.bucket_reso_steps = bucket_reso_steps
|
||||
self.bucket_no_upscale = bucket_no_upscale
|
||||
else:
|
||||
self.min_bucket_reso = None
|
||||
self.max_bucket_reso = None
|
||||
self.bucket_reso_steps = None # この情報は使われない
|
||||
self.bucket_no_upscale = False
|
||||
|
||||
self.num_train_images = 0
|
||||
self.num_reg_images = 0
|
||||
@@ -2221,9 +2237,25 @@ class FineTuningDataset(BaseDataset):
|
||||
|
||||
# メタデータを読み込む
|
||||
if os.path.exists(subset.metadata_file):
|
||||
logger.info(f"loading existing metadata: {subset.metadata_file}")
|
||||
with open(subset.metadata_file, "rt", encoding="utf-8") as f:
|
||||
metadata = json.load(f)
|
||||
if subset.metadata_file.endswith(".jsonl"):
|
||||
logger.info(f"loading existing JSOL metadata: {subset.metadata_file}")
|
||||
# optional JSONL format
|
||||
# {"image_path": "/path/to/image1.jpg", "caption": "A caption for image1", "image_size": [width, height]}
|
||||
metadata = {}
|
||||
with open(subset.metadata_file, "rt", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line_md = json.loads(line)
|
||||
image_md = {"caption": line_md.get("caption", "")}
|
||||
if "image_size" in line_md:
|
||||
image_md["image_size"] = line_md["image_size"]
|
||||
if "tags" in line_md:
|
||||
image_md["tags"] = line_md["tags"]
|
||||
metadata[line_md["image_path"]] = image_md
|
||||
else:
|
||||
# standard JSON format
|
||||
logger.info(f"loading existing metadata: {subset.metadata_file}")
|
||||
with open(subset.metadata_file, "rt", encoding="utf-8") as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
|
||||
|
||||
@@ -2233,65 +2265,101 @@ class FineTuningDataset(BaseDataset):
|
||||
)
|
||||
continue
|
||||
|
||||
tags_list = []
|
||||
for image_key, img_md in metadata.items():
|
||||
# path情報を作る
|
||||
abs_path = None
|
||||
|
||||
# まず画像を優先して探す
|
||||
if os.path.exists(image_key):
|
||||
abs_path = image_key
|
||||
# Add full path for image
|
||||
image_dirs = set()
|
||||
if subset.image_dir is not None:
|
||||
image_dirs.add(subset.image_dir)
|
||||
for image_key in metadata.keys():
|
||||
if not os.path.isabs(image_key):
|
||||
assert (
|
||||
subset.image_dir is not None
|
||||
), f"image_dir is required when image paths are relative / 画像パスが相対パスの場合、image_dirの指定が必要です: {image_key}"
|
||||
abs_path = os.path.join(subset.image_dir, image_key)
|
||||
else:
|
||||
# わりといい加減だがいい方法が思いつかん
|
||||
paths = glob_images(subset.image_dir, image_key)
|
||||
if len(paths) > 0:
|
||||
abs_path = paths[0]
|
||||
abs_path = image_key
|
||||
image_dirs.add(os.path.dirname(abs_path))
|
||||
metadata[image_key]["abs_path"] = abs_path
|
||||
|
||||
# なければnpzを探す
|
||||
if abs_path is None:
|
||||
if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
|
||||
abs_path = os.path.splitext(image_key)[0] + ".npz"
|
||||
else:
|
||||
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
if os.path.exists(npz_path):
|
||||
abs_path = npz_path
|
||||
# Enumerate existing npz files
|
||||
strategy = LatentsCachingStrategy.get_strategy()
|
||||
npz_paths = []
|
||||
for image_dir in image_dirs:
|
||||
npz_paths.extend(glob.glob(os.path.join(image_dir, "*" + strategy.cache_suffix)))
|
||||
npz_paths = sorted(npz_paths, key=lambda item: len(os.path.basename(item)), reverse=True) # longer paths first
|
||||
|
||||
assert abs_path is not None, f"no image / 画像がありません: {image_key}"
|
||||
# Match image filename longer to shorter because some images share same prefix
|
||||
image_keys_sorted_by_length_desc = sorted(metadata.keys(), key=len, reverse=True)
|
||||
|
||||
# Collect tags and sizes
|
||||
tags_list = []
|
||||
size_set_from_metadata = 0
|
||||
size_set_from_cache_filename = 0
|
||||
for image_key in image_keys_sorted_by_length_desc:
|
||||
img_md = metadata[image_key]
|
||||
caption = img_md.get("caption")
|
||||
tags = img_md.get("tags")
|
||||
image_size = img_md.get("image_size")
|
||||
abs_path = img_md.get("abs_path")
|
||||
|
||||
# search npz if image_size is not given
|
||||
npz_path = None
|
||||
if image_size is None:
|
||||
image_without_ext = os.path.splitext(image_key)[0]
|
||||
for candidate in npz_paths:
|
||||
if candidate.startswith(image_without_ext):
|
||||
npz_path = candidate
|
||||
break
|
||||
if npz_path is not None:
|
||||
npz_paths.remove(npz_path) # remove to avoid matching same file (share prefix)
|
||||
abs_path = npz_path
|
||||
|
||||
if caption is None:
|
||||
caption = tags # could be multiline
|
||||
tags = None
|
||||
caption = ""
|
||||
|
||||
if subset.enable_wildcard:
|
||||
# tags must be single line
|
||||
# tags must be single line (split by caption separator)
|
||||
if tags is not None:
|
||||
tags = tags.replace("\n", subset.caption_separator)
|
||||
|
||||
# add tags to each line of caption
|
||||
if caption is not None and tags is not None:
|
||||
if tags is not None:
|
||||
caption = "\n".join(
|
||||
[f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""]
|
||||
)
|
||||
tags_list.append(tags)
|
||||
else:
|
||||
# use as is
|
||||
if tags is not None and len(tags) > 0:
|
||||
caption = caption + subset.caption_separator + tags
|
||||
if len(caption) > 0:
|
||||
caption = caption + subset.caption_separator
|
||||
caption = caption + tags
|
||||
tags_list.append(tags)
|
||||
|
||||
if caption is None:
|
||||
caption = ""
|
||||
|
||||
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
|
||||
image_info.image_size = img_md.get("train_resolution")
|
||||
image_info.resize_interpolation = (
|
||||
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
|
||||
)
|
||||
|
||||
if not subset.color_aug and not subset.random_crop:
|
||||
# if npz exists, use them
|
||||
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
|
||||
if image_size is not None:
|
||||
image_info.image_size = tuple(image_size) # width, height
|
||||
size_set_from_metadata += 1
|
||||
elif npz_path is not None:
|
||||
# get image size from npz filename
|
||||
w, h = strategy.get_image_size_from_disk_cache_path(abs_path, npz_path)
|
||||
image_info.image_size = (w, h)
|
||||
size_set_from_cache_filename += 1
|
||||
|
||||
self.register_image(image_info, subset)
|
||||
|
||||
if size_set_from_cache_filename > 0:
|
||||
logger.info(
|
||||
f"set image size from cache files: {size_set_from_cache_filename}/{len(image_keys_sorted_by_length_desc)}"
|
||||
)
|
||||
if size_set_from_metadata > 0:
|
||||
logger.info(f"set image size from metadata: {size_set_from_metadata}/{len(image_keys_sorted_by_length_desc)}")
|
||||
self.num_train_images += len(metadata) * subset.num_repeats
|
||||
|
||||
# TODO do not record tag freq when no tag
|
||||
@@ -2299,117 +2367,6 @@ class FineTuningDataset(BaseDataset):
|
||||
subset.img_count = len(metadata)
|
||||
self.subsets.append(subset)
|
||||
|
||||
# check existence of all npz files
|
||||
use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
|
||||
if use_npz_latents:
|
||||
flip_aug_in_subset = False
|
||||
npz_any = False
|
||||
npz_all = True
|
||||
|
||||
for image_info in self.image_data.values():
|
||||
subset = self.image_to_subset[image_info.image_key]
|
||||
|
||||
has_npz = image_info.latents_npz is not None
|
||||
npz_any = npz_any or has_npz
|
||||
|
||||
if subset.flip_aug:
|
||||
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
||||
flip_aug_in_subset = True
|
||||
npz_all = npz_all and has_npz
|
||||
|
||||
if npz_any and not npz_all:
|
||||
break
|
||||
|
||||
if not npz_any:
|
||||
use_npz_latents = False
|
||||
logger.warning(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
|
||||
elif not npz_all:
|
||||
use_npz_latents = False
|
||||
logger.warning(
|
||||
f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します"
|
||||
)
|
||||
if flip_aug_in_subset:
|
||||
logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
||||
# else:
|
||||
# logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
||||
|
||||
# check min/max bucket size
|
||||
sizes = set()
|
||||
resos = set()
|
||||
for image_info in self.image_data.values():
|
||||
if image_info.image_size is None:
|
||||
sizes = None # not calculated
|
||||
break
|
||||
sizes.add(image_info.image_size[0])
|
||||
sizes.add(image_info.image_size[1])
|
||||
resos.add(tuple(image_info.image_size))
|
||||
|
||||
if sizes is None:
|
||||
if use_npz_latents:
|
||||
use_npz_latents = False
|
||||
logger.warning(
|
||||
f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します"
|
||||
)
|
||||
|
||||
assert (
|
||||
resolution is not None
|
||||
), "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
|
||||
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
|
||||
)
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
self.bucket_reso_steps = bucket_reso_steps
|
||||
self.bucket_no_upscale = bucket_no_upscale
|
||||
else:
|
||||
if not enable_bucket:
|
||||
logger.info("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
|
||||
logger.info("using bucket info in metadata / メタデータ内のbucket情報を使います")
|
||||
self.enable_bucket = True
|
||||
|
||||
assert (
|
||||
not bucket_no_upscale
|
||||
), "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
|
||||
|
||||
# bucket情報を初期化しておく、make_bucketsで再作成しない
|
||||
self.bucket_manager = BucketManager(False, None, None, None, None)
|
||||
self.bucket_manager.set_predefined_resos(resos)
|
||||
|
||||
# npz情報をきれいにしておく
|
||||
if not use_npz_latents:
|
||||
for image_info in self.image_data.values():
|
||||
image_info.latents_npz = image_info.latents_npz_flipped = None
|
||||
|
||||
def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
|
||||
base_name = os.path.splitext(image_key)[0]
|
||||
npz_file_norm = base_name + ".npz"
|
||||
|
||||
if os.path.exists(npz_file_norm):
|
||||
# image_key is full path
|
||||
npz_file_flip = base_name + "_flip.npz"
|
||||
if not os.path.exists(npz_file_flip):
|
||||
npz_file_flip = None
|
||||
return npz_file_norm, npz_file_flip
|
||||
|
||||
# if not full path, check image_dir. if image_dir is None, return None
|
||||
if subset.image_dir is None:
|
||||
return None, None
|
||||
|
||||
# image_key is relative path
|
||||
npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
|
||||
|
||||
if not os.path.exists(npz_file_norm):
|
||||
npz_file_norm = None
|
||||
npz_file_flip = None
|
||||
elif not os.path.exists(npz_file_flip):
|
||||
npz_file_flip = None
|
||||
|
||||
return npz_file_norm, npz_file_flip
|
||||
|
||||
|
||||
class ControlNetDataset(BaseDataset):
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user