mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Compare commits
10 Commits
7a979b1527
...
new_cache
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f3a85060ef | ||
|
|
f2322a23e2 | ||
|
|
70423ec61d | ||
|
|
28e9352cc5 | ||
|
|
b72b9eaf11 | ||
|
|
744cf03136 | ||
|
|
2238b94e7b | ||
|
|
665c04e649 | ||
|
|
3677094256 | ||
|
|
bdac55ebbc |
@@ -178,7 +178,7 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
232
finetune/caption_images_by_florence2.py
Normal file
232
finetune/caption_images_by_florence2.py
Normal file
@@ -0,0 +1,232 @@
|
||||
# add caption to images by Florence-2
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import glob
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoProcessor, AutoModelForCausalLM
|
||||
|
||||
from library import device_utils, train_util, dataset_metadata_utils
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import tagger_utils
|
||||
|
||||
TASK_PROMPT = "<MORE_DETAILED_CAPTION>"
|
||||
|
||||
|
||||
def main(args):
|
||||
assert args.load_archive == (
|
||||
args.metadata is not None
|
||||
), "load_archive must be used with metadata / load_archiveはmetadataと一緒に使う必要があります"
|
||||
|
||||
device = args.device if args.device is not None else device_utils.get_preferred_device()
|
||||
if type(device) is str:
|
||||
device = torch.device(device)
|
||||
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
|
||||
logger.info(f"device: {device}, dtype: {torch_dtype}")
|
||||
|
||||
logger.info("Loading Florence-2-large model / Florence-2-largeモデルをロード中")
|
||||
|
||||
support_flash_attn = False
|
||||
try:
|
||||
import flash_attn
|
||||
|
||||
support_flash_attn = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if support_flash_attn:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True
|
||||
).to(device)
|
||||
else:
|
||||
logger.info(
|
||||
"flash_attn is not available. Trying to load without it / flash_attnが利用できません。flash_attnを使わずにロードを試みます"
|
||||
)
|
||||
|
||||
# https://github.com/huggingface/transformers/issues/31793#issuecomment-2295797330
|
||||
# Removing the unnecessary flash_attn import which causes issues on CPU or MPS backends
|
||||
from transformers.dynamic_module_utils import get_imports
|
||||
from unittest.mock import patch
|
||||
|
||||
def fixed_get_imports(filename) -> list[str]:
|
||||
if not str(filename).endswith("modeling_florence2.py"):
|
||||
return get_imports(filename)
|
||||
imports = get_imports(filename)
|
||||
imports.remove("flash_attn")
|
||||
return imports
|
||||
|
||||
# workaround for unnecessary flash_attn requirement
|
||||
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True
|
||||
).to(device)
|
||||
|
||||
model.eval()
|
||||
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
|
||||
|
||||
# 画像を読み込む
|
||||
if not args.load_archive:
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
logger.info(f"found {len(image_paths)} images.")
|
||||
else:
|
||||
archive_files = glob.glob(os.path.join(args.train_data_dir, "*.zip")) + glob.glob(
|
||||
os.path.join(args.train_data_dir, "*.tar")
|
||||
)
|
||||
image_paths = [Path(archive_file) for archive_file in archive_files]
|
||||
|
||||
# load metadata if needed
|
||||
if args.metadata is not None:
|
||||
metadata = dataset_metadata_utils.load_metadata(args.metadata, create_new=True)
|
||||
images_metadata = metadata["images"]
|
||||
else:
|
||||
images_metadata = metadata = None
|
||||
|
||||
# define preprocess_image function
|
||||
def preprocess_image(image: Image.Image):
|
||||
inputs = processor(text=TASK_PROMPT, images=image, return_tensors="pt").to(device, torch_dtype)
|
||||
return inputs
|
||||
|
||||
# prepare DataLoader or something similar :)
|
||||
# Loader returns: list of (image_path, processed_image_or_something, image_size)
|
||||
if args.load_archive:
|
||||
loader = tagger_utils.ArchiveImageLoader([str(p) for p in image_paths], args.batch_size, preprocess_image, args.debug)
|
||||
else:
|
||||
# we cannot use DataLoader with ImageLoadingPrepDataset because processor is not pickleable
|
||||
loader = tagger_utils.ImageLoader(image_paths, args.batch_size, preprocess_image, args.debug)
|
||||
|
||||
def run_batch(
|
||||
list_of_path_inputs_size: list[tuple[str, dict[str, torch.Tensor], tuple[int, int]]],
|
||||
images_metadata: Optional[dict[str, Any]],
|
||||
caption_index: Optional[int] = None,
|
||||
):
|
||||
input_ids = torch.cat([inputs["input_ids"] for _, inputs, _ in list_of_path_inputs_size])
|
||||
pixel_values = torch.cat([inputs["pixel_values"] for _, inputs, _ in list_of_path_inputs_size])
|
||||
|
||||
if args.debug:
|
||||
logger.info(f"input_ids: {input_ids.shape}, pixel_values: {pixel_values.shape}")
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
num_beams=args.num_beams,
|
||||
)
|
||||
if args.debug:
|
||||
logger.info(f"generate done: {generated_ids.shape}")
|
||||
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=False)
|
||||
if args.debug:
|
||||
logger.info(f"decode done: {len(generated_texts)}")
|
||||
|
||||
for generated_text, (image_path, _, image_size) in zip(generated_texts, list_of_path_inputs_size):
|
||||
parsed_answer = processor.post_process_generation(generated_text, task=TASK_PROMPT, image_size=image_size)
|
||||
caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"]
|
||||
|
||||
caption_text = caption_text.strip().replace("<pad>", "")
|
||||
original_caption_text = caption_text
|
||||
|
||||
if args.remove_mood:
|
||||
p = caption_text.find("The overall ")
|
||||
if p != -1:
|
||||
caption_text = caption_text[:p].strip()
|
||||
|
||||
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
|
||||
|
||||
if images_metadata is None:
|
||||
with open(caption_file, "wt", encoding="utf-8") as f:
|
||||
f.write(caption_text + "\n")
|
||||
else:
|
||||
image_md = images_metadata.get(image_path, None)
|
||||
if image_md is None:
|
||||
image_md = {"image_size": list(image_size)}
|
||||
images_metadata[image_path] = image_md
|
||||
if "caption" not in image_md:
|
||||
image_md["caption"] = []
|
||||
if caption_index is None:
|
||||
image_md["caption"].append(caption_text)
|
||||
else:
|
||||
while len(image_md["caption"]) <= caption_index:
|
||||
image_md["caption"].append("")
|
||||
image_md["caption"][caption_index] = caption_text
|
||||
|
||||
if args.debug:
|
||||
logger.info("")
|
||||
logger.info(f"{image_path}:")
|
||||
logger.info(f"\tCaption: {caption_text}")
|
||||
if args.remove_mood and original_caption_text != caption_text:
|
||||
logger.info(f"\tCaption (prior to removing mood): {original_caption_text}")
|
||||
|
||||
for data_entry in tqdm(loader, smoothing=0.0):
|
||||
b_imgs = data_entry
|
||||
b_imgs = [(str(image_path), image, size) for image_path, image, size in b_imgs] # Convert image_path to string
|
||||
run_batch(b_imgs, images_metadata, args.caption_index)
|
||||
|
||||
if args.metadata is not None:
|
||||
logger.info(f"saving metadata file: {args.metadata}")
|
||||
with open(args.metadata, "wt", encoding="utf-8") as f:
|
||||
json.dump(metadata, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info("done!")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子"
|
||||
)
|
||||
parser.add_argument("--recursive", action="store_true", help="search images recursively / 画像を再帰的に検索する")
|
||||
parser.add_argument(
|
||||
"--remove_mood", action="store_true", help="remove mood from the caption / キャプションからムードを削除する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_new_tokens",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="maximum number of tokens to generate. default is 1024 / 生成するトークンの最大数。デフォルトは1024",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_beams",
|
||||
type=int,
|
||||
default=3,
|
||||
help="number of beams for beam search. default is 3 / ビームサーチのビーム数。デフォルトは3",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default=None,
|
||||
help="device for model. default is None, which means using an appropriate device / モデルのデバイス。デフォルトはNoneで、適切なデバイスを使用する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_index",
|
||||
type=int,
|
||||
default=None,
|
||||
help="index of the caption in the metadata file. default is None, which means adding caption to the existing captions. 0>= to replace the caption"
|
||||
" / メタデータファイル内のキャプションのインデックス。デフォルトはNoneで、新しく追加する。0以上でキャプションを置き換える",
|
||||
)
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
tagger_utils.add_archive_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -180,7 +180,7 @@ def main(args):
|
||||
|
||||
# バッチへ追加
|
||||
image_info = train_util.ImageInfo(image_key, 1, "", False, image_path)
|
||||
image_info.latents_npz = npz_file_name
|
||||
image_info.latents_cache_path = npz_file_name
|
||||
image_info.bucket_reso = reso
|
||||
image_info.resized_size = resized_size
|
||||
image_info.image = image
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import argparse
|
||||
import csv
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -10,14 +13,18 @@ from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
import library.train_util as train_util
|
||||
from library.utils import setup_logging, pil_resize
|
||||
from library import dataset_metadata_utils
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import library.train_util as train_util
|
||||
from library.utils import pil_resize
|
||||
import tagger_utils
|
||||
|
||||
# from wd14 tagger
|
||||
IMAGE_SIZE = 448
|
||||
|
||||
@@ -63,13 +70,14 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
|
||||
|
||||
try:
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
size = image.size
|
||||
image = preprocess_image(image)
|
||||
# tensor = torch.tensor(image) # これ Tensor に変換する必要ないな……(;・∀・)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
return None
|
||||
|
||||
return (image, img_path)
|
||||
return (image, img_path, size)
|
||||
|
||||
|
||||
def collate_fn_remove_corrupted(batch):
|
||||
@@ -83,6 +91,10 @@ def collate_fn_remove_corrupted(batch):
|
||||
|
||||
|
||||
def main(args):
|
||||
assert args.load_archive == (
|
||||
args.metadata is not None
|
||||
), "load_archive must be used with metadata / load_archiveはmetadataと一緒に使う必要があります"
|
||||
|
||||
# model location is model_dir + repo_id
|
||||
# repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash
|
||||
model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_"))
|
||||
@@ -149,15 +161,19 @@ def main(args):
|
||||
ort_sess = ort.InferenceSession(
|
||||
onnx_path,
|
||||
providers=(["OpenVINOExecutionProvider"]),
|
||||
provider_options=[{'device_type' : "GPU_FP32"}],
|
||||
provider_options=[{"device_type": "GPU_FP32"}],
|
||||
)
|
||||
else:
|
||||
ort_sess = ort.InferenceSession(
|
||||
onnx_path,
|
||||
providers=(
|
||||
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
|
||||
["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
|
||||
["CPUExecutionProvider"]
|
||||
["CUDAExecutionProvider"]
|
||||
if "CUDAExecutionProvider" in ort.get_available_providers()
|
||||
else (
|
||||
["ROCMExecutionProvider"]
|
||||
if "ROCMExecutionProvider" in ort.get_available_providers()
|
||||
else ["CPUExecutionProvider"]
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
@@ -203,7 +219,9 @@ def main(args):
|
||||
tag_replacements = escaped_tag_replacements.split(";")
|
||||
for tag_replacement in tag_replacements:
|
||||
tags = tag_replacement.split(",") # source, target
|
||||
assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
|
||||
assert (
|
||||
len(tags) == 2
|
||||
), f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
|
||||
|
||||
source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags]
|
||||
logger.info(f"replacing tag: {source} -> {target}")
|
||||
@@ -216,9 +234,15 @@ def main(args):
|
||||
rating_tags[rating_tags.index(source)] = target
|
||||
|
||||
# 画像を読み込む
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
logger.info(f"found {len(image_paths)} images.")
|
||||
if not args.load_archive:
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
logger.info(f"found {len(image_paths)} images.")
|
||||
else:
|
||||
archive_files = glob.glob(os.path.join(args.train_data_dir, "*.zip")) + glob.glob(
|
||||
os.path.join(args.train_data_dir, "*.tar")
|
||||
)
|
||||
image_paths = [Path(archive_file) for archive_file in archive_files]
|
||||
|
||||
tag_freq = {}
|
||||
|
||||
@@ -231,19 +255,23 @@ def main(args):
|
||||
if args.always_first_tags is not None:
|
||||
always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""]
|
||||
|
||||
def run_batch(path_imgs):
|
||||
imgs = np.array([im for _, im in path_imgs])
|
||||
def run_batch(
|
||||
list_of_path_img_size: list[tuple[str, np.ndarray, tuple[int, int]]],
|
||||
images_metadata: Optional[dict[str, Any]],
|
||||
tags_index: Optional[int] = None,
|
||||
):
|
||||
imgs = np.array([im for _, im, _ in list_of_path_img_size])
|
||||
|
||||
if args.onnx:
|
||||
# if len(imgs) < args.batch_size:
|
||||
# imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
|
||||
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
|
||||
probs = probs[: len(path_imgs)]
|
||||
probs = probs[: len(list_of_path_img_size)]
|
||||
else:
|
||||
probs = model(imgs, training=False)
|
||||
probs = probs.numpy()
|
||||
|
||||
for (image_path, _), prob in zip(path_imgs, probs):
|
||||
for (image_path, _, image_size), prob in zip(list_of_path_img_size, probs):
|
||||
combined_tags = []
|
||||
rating_tag_text = ""
|
||||
character_tag_text = ""
|
||||
@@ -265,7 +293,7 @@ def main(args):
|
||||
if tag_name not in undesired_tags:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
character_tag_text += caption_separator + tag_name
|
||||
if args.character_tags_first: # insert to the beginning
|
||||
if args.character_tags_first: # insert to the beginning
|
||||
combined_tags.insert(0, tag_name)
|
||||
else:
|
||||
combined_tags.append(tag_name)
|
||||
@@ -281,7 +309,7 @@ def main(args):
|
||||
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
|
||||
rating_tag_text = found_rating
|
||||
if args.use_rating_tags:
|
||||
combined_tags.insert(0, found_rating) # insert to the beginning
|
||||
combined_tags.insert(0, found_rating) # insert to the beginning
|
||||
else:
|
||||
combined_tags.append(found_rating)
|
||||
|
||||
@@ -304,12 +332,24 @@ def main(args):
|
||||
tag_text = caption_separator.join(combined_tags)
|
||||
|
||||
if args.append_tags:
|
||||
# Check if file exists
|
||||
if os.path.exists(caption_file):
|
||||
with open(caption_file, "rt", encoding="utf-8") as f:
|
||||
# Read file and remove new lines
|
||||
existing_content = f.read().strip("\n") # Remove newlines
|
||||
existing_content = None
|
||||
if images_metadata is None:
|
||||
# Check if file exists
|
||||
if os.path.exists(caption_file):
|
||||
with open(caption_file, "rt", encoding="utf-8") as f:
|
||||
# Read file and remove new lines
|
||||
existing_content = f.read().strip("\n") # Remove newlines
|
||||
else:
|
||||
image_md = images_metadata.get(image_path, None)
|
||||
if image_md is not None:
|
||||
tags = image_md.get("tags", None)
|
||||
if tags is not None:
|
||||
if tags_index is None and len(tags) > 0:
|
||||
existing_content = tags[-1]
|
||||
elif tags_index is not None and tags_index < len(tags):
|
||||
existing_content = tags[tags_index]
|
||||
|
||||
if existing_content is not None:
|
||||
# Split the content into tags and store them in a list
|
||||
existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()]
|
||||
|
||||
@@ -319,19 +359,46 @@ def main(args):
|
||||
# Create new tag_text
|
||||
tag_text = caption_separator.join(existing_tags + new_tags)
|
||||
|
||||
with open(caption_file, "wt", encoding="utf-8") as f:
|
||||
f.write(tag_text + "\n")
|
||||
if args.debug:
|
||||
logger.info("")
|
||||
logger.info(f"{image_path}:")
|
||||
logger.info(f"\tRating tags: {rating_tag_text}")
|
||||
logger.info(f"\tCharacter tags: {character_tag_text}")
|
||||
logger.info(f"\tGeneral tags: {general_tag_text}")
|
||||
if images_metadata is None:
|
||||
with open(caption_file, "wt", encoding="utf-8") as f:
|
||||
f.write(tag_text + "\n")
|
||||
else:
|
||||
image_md = images_metadata.get(image_path, None)
|
||||
if image_md is None:
|
||||
image_md = {"image_size": list(image_size)}
|
||||
images_metadata[image_path] = image_md
|
||||
if "tags" not in image_md:
|
||||
image_md["tags"] = []
|
||||
if tags_index is None:
|
||||
image_md["tags"].append(tag_text)
|
||||
else:
|
||||
while len(image_md["tags"]) <= tags_index:
|
||||
image_md["tags"].append("")
|
||||
image_md["tags"][tags_index] = tag_text
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
if args.debug:
|
||||
logger.info("")
|
||||
logger.info(f"{image_path}:")
|
||||
logger.info(f"\tRating tags: {rating_tag_text}")
|
||||
logger.info(f"\tCharacter tags: {character_tag_text}")
|
||||
logger.info(f"\tGeneral tags: {general_tag_text}")
|
||||
|
||||
# load metadata if needed
|
||||
if args.metadata is not None:
|
||||
metadata = dataset_metadata_utils.load_metadata(args.metadata, create_new=True)
|
||||
images_metadata = metadata["images"]
|
||||
else:
|
||||
images_metadata = metadata = None
|
||||
|
||||
# prepare DataLoader or something similar :)
|
||||
use_loader = False
|
||||
if args.load_archive:
|
||||
loader = tagger_utils.ArchiveImageLoader([str(p) for p in image_paths], args.batch_size, preprocess_image, args.debug)
|
||||
use_loader = True
|
||||
elif args.max_data_loader_n_workers is not None:
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
dataset = ImageLoadingPrepDataset(image_paths)
|
||||
data = torch.utils.data.DataLoader(
|
||||
loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
@@ -339,35 +406,37 @@ def main(args):
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
drop_last=False,
|
||||
)
|
||||
use_loader = True
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
# make batch of image paths
|
||||
loader = []
|
||||
for i in range(0, len(image_paths), args.batch_size):
|
||||
loader.append(image_paths[i : i + args.batch_size])
|
||||
|
||||
b_imgs = []
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
for data in data_entry:
|
||||
if data is None:
|
||||
continue
|
||||
|
||||
image, image_path = data
|
||||
if image is None:
|
||||
for data_entry in tqdm(loader, smoothing=0.0):
|
||||
if use_loader:
|
||||
b_imgs = data_entry
|
||||
else:
|
||||
b_imgs = []
|
||||
for image_path in data_entry:
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
size = image.size
|
||||
image = preprocess_image(image)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
b_imgs.append((image_path, image))
|
||||
b_imgs.append((image_path, image, size))
|
||||
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
b_imgs = [(str(image_path), image, size) for image_path, image, size in b_imgs] # Convert image_path to string
|
||||
run_batch(b_imgs, images_metadata, args.tags_index)
|
||||
|
||||
if len(b_imgs) > 0:
|
||||
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
|
||||
run_batch(b_imgs)
|
||||
if args.metadata is not None:
|
||||
logger.info(f"saving metadata file: {args.metadata}")
|
||||
with open(args.metadata, "wt", encoding="utf-8") as f:
|
||||
json.dump(metadata, f, ensure_ascii=False, indent=2)
|
||||
|
||||
if args.frequency_tags:
|
||||
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
|
||||
@@ -380,9 +449,7 @@ def main(args):
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
|
||||
)
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument(
|
||||
"--repo_id",
|
||||
type=str,
|
||||
@@ -400,9 +467,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--max_data_loader_n_workers",
|
||||
type=int,
|
||||
@@ -441,9 +506,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug", action="store_true", help="debug mode"
|
||||
)
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
parser.add_argument(
|
||||
"--undesired_tags",
|
||||
type=str,
|
||||
@@ -453,20 +516,24 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
|
||||
)
|
||||
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
|
||||
parser.add_argument(
|
||||
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
|
||||
"--use_rating_tags",
|
||||
action="store_true",
|
||||
help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
|
||||
"--use_rating_tags_as_last_tag",
|
||||
action="store_true",
|
||||
help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
|
||||
"--character_tags_first",
|
||||
action="store_true",
|
||||
help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--always_first_tags",
|
||||
@@ -495,6 +562,15 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
+ " / キャラクタタグの末尾の括弧を別のタグに展開する。`chara_name_(series)` は `chara_name, series` になる",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tags_index",
|
||||
type=int,
|
||||
default=None,
|
||||
help="index of the tags in the metadata file. default is None, which means adding tags to the existing tags. 0>= to replace the tags"
|
||||
" / メタデータファイル内のタグのインデックス。デフォルトはNoneで、既存のタグにタグを追加する。0以上でタグを置き換える",
|
||||
)
|
||||
tagger_utils.add_archive_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
150
finetune/tagger_utils.py
Normal file
150
finetune/tagger_utils.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Callable, Union
|
||||
import zipfile
|
||||
import tarfile
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import dataset_metadata_utils, train_util
|
||||
|
||||
|
||||
class ArchiveImageLoader:
|
||||
def __init__(self, archive_paths: list[str], batch_size: int, preprocess: Callable, debug: bool = False):
|
||||
self.archive_paths = archive_paths
|
||||
self.batch_size = batch_size
|
||||
self.preprocess = preprocess
|
||||
self.debug = debug
|
||||
self.current_archive = None
|
||||
self.archive_index = 0
|
||||
self.image_index = 0
|
||||
self.files = None
|
||||
self.executor = ThreadPoolExecutor()
|
||||
self.image_exts = set(train_util.IMAGE_EXTENSIONS)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
images = []
|
||||
while len(images) < self.batch_size:
|
||||
if self.current_archive is None:
|
||||
if self.archive_index >= len(self.archive_paths):
|
||||
if len(images) == 0:
|
||||
raise StopIteration
|
||||
else:
|
||||
break # return the remaining images
|
||||
|
||||
if self.debug:
|
||||
logger.info(f"loading archive: {self.archive_paths[self.archive_index]}")
|
||||
|
||||
current_archive_path = self.archive_paths[self.archive_index]
|
||||
if current_archive_path.endswith(".zip"):
|
||||
self.current_archive = zipfile.ZipFile(current_archive_path)
|
||||
self.files = self.current_archive.namelist()
|
||||
elif current_archive_path.endswith(".tar"):
|
||||
self.current_archive = tarfile.open(current_archive_path, "r")
|
||||
self.files = self.current_archive.getnames()
|
||||
else:
|
||||
raise ValueError(f"unsupported archive file: {self.current_archive_path}")
|
||||
|
||||
self.image_index = 0
|
||||
|
||||
# filter by image extensions
|
||||
self.files = [file for file in self.files if os.path.splitext(file)[1].lower() in self.image_exts]
|
||||
|
||||
if self.debug:
|
||||
logger.info(f"found {len(self.files)} images in the archive")
|
||||
|
||||
new_images = []
|
||||
while len(images) + len(new_images) < self.batch_size:
|
||||
if self.image_index >= len(self.files):
|
||||
break
|
||||
|
||||
file = self.files[self.image_index]
|
||||
archive_and_image_path = (
|
||||
f"{self.archive_paths[self.archive_index]}{dataset_metadata_utils.ARCHIVE_PATH_SEPARATOR}{file}"
|
||||
)
|
||||
self.image_index += 1
|
||||
|
||||
def load_image(file, archive: Union[zipfile.ZipFile, tarfile.TarFile]):
|
||||
with archive.open(file) as f:
|
||||
image = Image.open(f).convert("RGB")
|
||||
size = image.size
|
||||
image = self.preprocess(image)
|
||||
return image, size
|
||||
|
||||
new_images.append((archive_and_image_path, self.executor.submit(load_image, file, self.current_archive)))
|
||||
|
||||
# wait for all new_images to load to close the archive
|
||||
new_images = [(image_path, future.result()) for image_path, future in new_images]
|
||||
|
||||
if self.image_index >= len(self.files):
|
||||
self.current_archive.close()
|
||||
self.current_archive = None
|
||||
self.archive_index += 1
|
||||
|
||||
images.extend(new_images)
|
||||
|
||||
return [(image_path, image, size) for image_path, (image, size) in images]
|
||||
|
||||
|
||||
class ImageLoader:
|
||||
def __init__(self, image_paths: list[str], batch_size: int, preprocess: Callable, debug: bool = False):
|
||||
self.image_paths = image_paths
|
||||
self.batch_size = batch_size
|
||||
self.preprocess = preprocess
|
||||
self.debug = debug
|
||||
self.image_index = 0
|
||||
self.executor = ThreadPoolExecutor()
|
||||
|
||||
def __len__(self):
|
||||
return math.ceil(len(self.image_paths) / self.batch_size)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.image_index >= len(self.image_paths):
|
||||
raise StopIteration
|
||||
|
||||
images = []
|
||||
while len(images) < self.batch_size and self.image_index < len(self.image_paths):
|
||||
|
||||
def load_image(file):
|
||||
image = Image.open(file).convert("RGB")
|
||||
size = image.size
|
||||
image = self.preprocess(image)
|
||||
return image, size
|
||||
|
||||
image_path = self.image_paths[self.image_index]
|
||||
images.append((image_path, self.executor.submit(load_image, image_path)))
|
||||
self.image_index += 1
|
||||
|
||||
images = [(image_path, future.result()) for image_path, future in images]
|
||||
return [(image_path, image, size) for image_path, (image, size) in images]
|
||||
|
||||
|
||||
def add_archive_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--metadata",
|
||||
type=str,
|
||||
default=None,
|
||||
help="metadata file for the dataset. write tags to this file instead of the caption file / データセットのメタデータファイル。キャプションファイルの代わりにこのファイルにタグを書き込む",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_archive",
|
||||
action="store_true",
|
||||
help="load archive file such as .zip instead of image files. currently .zip and .tar are supported. must be used with --metadata"
|
||||
" / 画像ファイルではなく.zipなどのアーカイブファイルを読み込む。現在.zipと.tarをサポート。--metadataと一緒に使う必要があります",
|
||||
)
|
||||
@@ -152,15 +152,20 @@ def train(args):
|
||||
|
||||
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
||||
if args.debug_dataset:
|
||||
if args.cache_text_encoder_outputs:
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
|
||||
strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
|
||||
)
|
||||
)
|
||||
t5xxl_max_token_length = (
|
||||
args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
|
||||
)
|
||||
if args.cache_text_encoder_outputs:
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
|
||||
strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
args.text_encoder_batch_size,
|
||||
args.skip_cache_check,
|
||||
t5xxl_max_token_length,
|
||||
args.apply_t5_attn_mask,
|
||||
False,
|
||||
)
|
||||
)
|
||||
strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))
|
||||
|
||||
train_dataset_group.set_current_strategies()
|
||||
@@ -199,7 +204,7 @@ def train(args):
|
||||
ae.requires_grad_(False)
|
||||
ae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(ae, accelerator)
|
||||
train_dataset_group.new_cache_latents(ae, accelerator, args.force_cache_precision)
|
||||
|
||||
ae.to("cpu") # if no sampling, vae can be deleted
|
||||
clean_memory_on_device(accelerator.device)
|
||||
@@ -237,7 +242,12 @@ def train(args):
|
||||
t5xxl.to(accelerator.device)
|
||||
|
||||
text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
args.text_encoder_batch_size,
|
||||
args.skip_cache_check,
|
||||
t5xxl_max_token_length,
|
||||
args.apply_t5_attn_mask,
|
||||
False,
|
||||
)
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
|
||||
|
||||
|
||||
@@ -11,16 +11,6 @@ from library.device_utils import clean_memory_on_device, init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
import train_network
|
||||
from library import (
|
||||
flux_models,
|
||||
flux_train_utils,
|
||||
flux_utils,
|
||||
sd3_train_utils,
|
||||
strategy_base,
|
||||
strategy_flux,
|
||||
train_util,
|
||||
)
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -28,6 +18,9 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
|
||||
import train_network
|
||||
|
||||
|
||||
class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
def __init__(self):
|
||||
@@ -185,13 +178,17 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||
if args.cache_text_encoder_outputs:
|
||||
fluxTokenizeStrategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
t5xxl_max_token_length = fluxTokenizeStrategy.t5xxl_max_length
|
||||
|
||||
# if the text encoders is trained, we need tokenization, so is_partial is True
|
||||
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
args.text_encoder_batch_size,
|
||||
args.skip_cache_check,
|
||||
t5xxl_max_token_length,
|
||||
args.apply_t5_attn_mask,
|
||||
is_partial=self.train_clip_l or self.train_t5xxl,
|
||||
apply_t5_attn_mask=args.apply_t5_attn_mask,
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
58
library/dataset_metadata_utils.py
Normal file
58
library/dataset_metadata_utils.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import os
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
METADATA_VERSION = [1, 0, 0]
|
||||
VERSION_STRING = ".".join(str(v) for v in METADATA_VERSION)
|
||||
|
||||
ARCHIVE_PATH_SEPARATOR = "////"
|
||||
|
||||
|
||||
def load_metadata(metadata_file: str, create_new: bool = False) -> Optional[dict[str, Any]]:
|
||||
if os.path.exists(metadata_file):
|
||||
logger.info(f"loading metadata file: {metadata_file}")
|
||||
with open(metadata_file, "rt", encoding="utf-8") as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
# version check
|
||||
major, minor, patch = metadata.get("format_version", "0.0.0").split(".")
|
||||
major, minor, patch = int(major), int(minor), int(patch)
|
||||
if major > METADATA_VERSION[0] or (major == METADATA_VERSION[0] and minor > METADATA_VERSION[1]):
|
||||
logger.warning(
|
||||
f"metadata format version {major}.{minor}.{patch} is higher than supported version {VERSION_STRING}. Some features may not work."
|
||||
)
|
||||
|
||||
if "images" not in metadata:
|
||||
metadata["images"] = {}
|
||||
else:
|
||||
if not create_new:
|
||||
return None
|
||||
logger.info(f"metadata file not found: {metadata_file}, creating new metadata")
|
||||
metadata = {"format_version": VERSION_STRING, "images": {}}
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
def is_archive_path(archive_and_image_path: str) -> bool:
|
||||
return archive_and_image_path.count(ARCHIVE_PATH_SEPARATOR) == 1
|
||||
|
||||
|
||||
def get_inner_path(archive_and_image_path: str) -> str:
|
||||
return archive_and_image_path.split(ARCHIVE_PATH_SEPARATOR, 1)[1]
|
||||
|
||||
|
||||
def get_archive_digest(archive_and_image_path: str) -> str:
|
||||
"""
|
||||
calculate a 8-digits hex digest for the archive path to avoid collisions for different archives with the same name.
|
||||
"""
|
||||
archive_path = archive_and_image_path.split(ARCHIVE_PATH_SEPARATOR, 1)[0]
|
||||
return f"{hash(archive_path) & 0xFFFFFFFF:08x}"
|
||||
@@ -2,16 +2,14 @@
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from safetensors.torch import safe_open, save_file
|
||||
import torch
|
||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
|
||||
# TODO remove circular import by moving ImageInfo to a separate file
|
||||
# from library.train_util import ImageInfo
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -19,6 +17,81 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import dataset_metadata_utils, utils
|
||||
|
||||
|
||||
def get_compatible_dtypes(dtype: Optional[Union[str, torch.dtype]]) -> List[torch.dtype]:
|
||||
if dtype is None:
|
||||
# all dtypes are acceptable
|
||||
return get_available_dtypes()
|
||||
|
||||
dtype = utils.str_to_dtype(dtype) if isinstance(dtype, str) else dtype
|
||||
compatible_dtypes = [torch.float32]
|
||||
if dtype.itemsize == 1: # fp8
|
||||
compatible_dtypes.append(torch.bfloat16)
|
||||
compatible_dtypes.append(torch.float16)
|
||||
compatible_dtypes.append(dtype) # add the specified: bf16, fp16, one of fp8
|
||||
return compatible_dtypes
|
||||
|
||||
|
||||
def get_available_dtypes() -> List[torch.dtype]:
|
||||
"""
|
||||
Returns the list of available dtypes for latents caching. Higher precision is preferred.
|
||||
"""
|
||||
return [torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn, torch.float8_e5m2]
|
||||
|
||||
|
||||
def remove_lower_precision_values(tensor_dict: Dict[str, torch.Tensor], keys_without_dtype: list[str]) -> None:
|
||||
"""
|
||||
Removes lower precision values from tensor_dict.
|
||||
"""
|
||||
available_dtypes = get_available_dtypes()
|
||||
available_dtype_suffixes = [f"_{utils.dtype_to_normalized_str(dtype)}" for dtype in available_dtypes]
|
||||
|
||||
for key_without_dtype in keys_without_dtype:
|
||||
available_itemsize = None
|
||||
for dtype, dtype_suffix in zip(available_dtypes, available_dtype_suffixes):
|
||||
key = key_without_dtype + dtype_suffix
|
||||
|
||||
if key in tensor_dict:
|
||||
if available_itemsize is None:
|
||||
available_itemsize = dtype.itemsize
|
||||
elif available_itemsize > dtype.itemsize:
|
||||
# if higher precision latents are already cached, remove lower precision latents
|
||||
del tensor_dict[key]
|
||||
|
||||
|
||||
def get_compatible_dtype_keys(
|
||||
dict_keys: set[str], keys_without_dtype: list[str], dtype: Optional[Union[str, torch.dtype]]
|
||||
) -> list[Optional[str]]:
|
||||
"""
|
||||
Returns the list of keys with the specified dtype or higher precision dtype. If the specified dtype is None, any dtype is acceptable.
|
||||
If the key is not found, it returns None.
|
||||
If the key in dict_keys doesn't have dtype suffix, it is acceptable, because it it long tensor.
|
||||
|
||||
:param dict_keys: set of keys in the dictionary
|
||||
:param keys_without_dtype: list of keys without dtype suffix to check
|
||||
:param dtype: dtype to check, or None for any dtype
|
||||
:return: list of keys with the specified dtype or higher precision dtype. If the key is not found, it returns None for that key.
|
||||
"""
|
||||
compatible_dtypes = get_compatible_dtypes(dtype)
|
||||
dtype_suffixes = [f"_{utils.dtype_to_normalized_str(dt)}" for dt in compatible_dtypes]
|
||||
|
||||
available_keys = []
|
||||
for key_without_dtype in keys_without_dtype:
|
||||
available_key = None
|
||||
if key_without_dtype in dict_keys:
|
||||
available_key = key_without_dtype
|
||||
else:
|
||||
for dtype_suffix in dtype_suffixes:
|
||||
key = key_without_dtype + dtype_suffix
|
||||
if key in dict_keys:
|
||||
available_key = key
|
||||
break
|
||||
available_keys.append(available_key)
|
||||
|
||||
return available_keys
|
||||
|
||||
|
||||
class TokenizeStrategy:
|
||||
_strategy = None # strategy instance: actual strategy class
|
||||
@@ -324,17 +397,26 @@ class TextEncoderOutputsCachingStrategy:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
architecture: str,
|
||||
cache_to_disk: bool,
|
||||
batch_size: Optional[int],
|
||||
skip_disk_cache_validity_check: bool,
|
||||
max_token_length: int,
|
||||
masked: bool = False,
|
||||
is_partial: bool = False,
|
||||
is_weighted: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
max_token_length: maximum token length for the model. Including/excluding starting and ending tokens depends on the model.
|
||||
"""
|
||||
self._architecture = architecture
|
||||
self._cache_to_disk = cache_to_disk
|
||||
self._batch_size = batch_size
|
||||
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
|
||||
self._max_token_length = max_token_length
|
||||
self._masked = masked
|
||||
self._is_partial = is_partial
|
||||
self._is_weighted = is_weighted
|
||||
self._is_weighted = is_weighted # enable weighting by `()` or `[]` in the prompt
|
||||
|
||||
@classmethod
|
||||
def set_strategy(cls, strategy):
|
||||
@@ -346,6 +428,18 @@ class TextEncoderOutputsCachingStrategy:
|
||||
def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
|
||||
return cls._strategy
|
||||
|
||||
@property
|
||||
def architecture(self):
|
||||
return self._architecture
|
||||
|
||||
@property
|
||||
def max_token_length(self):
|
||||
return self._max_token_length
|
||||
|
||||
@property
|
||||
def masked(self):
|
||||
return self._masked
|
||||
|
||||
@property
|
||||
def cache_to_disk(self):
|
||||
return self._cache_to_disk
|
||||
@@ -354,6 +448,11 @@ class TextEncoderOutputsCachingStrategy:
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
@property
|
||||
def cache_suffix(self):
|
||||
suffix_masked = "_m" if self.masked else ""
|
||||
return f"_{self.architecture.lower()}_{self.max_token_length}{suffix_masked}_te.safetensors"
|
||||
|
||||
@property
|
||||
def is_partial(self):
|
||||
return self._is_partial
|
||||
@@ -362,31 +461,159 @@ class TextEncoderOutputsCachingStrategy:
|
||||
def is_weighted(self):
|
||||
return self._is_weighted
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
def get_cache_path(self, absolute_path: str) -> str:
|
||||
return os.path.splitext(absolute_path)[0] + self.cache_suffix
|
||||
|
||||
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
def load_from_disk_for_keys(self, cache_path: str, caption_index: int, base_keys: list[str]) -> list[Optional[torch.Tensor]]:
|
||||
"""
|
||||
get tensors for keys_without_dtype, without dtype suffix. if the key is not found, it returns None.
|
||||
all dtype tensors are returned, because cache validation is done in advance.
|
||||
"""
|
||||
with safe_open(cache_path, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
version = metadata.get("format_version", "0.0.0")
|
||||
major, minor, patch = map(int, version.split("."))
|
||||
if major > 1: # or (major == 1 and minor > 0):
|
||||
if not self.load_version_warning_printed:
|
||||
self.load_version_warning_printed = True
|
||||
logger.warning(
|
||||
f"Existing latents cache file has a higher version {version} for {cache_path}. This may cause issues."
|
||||
)
|
||||
|
||||
dict_keys = f.keys()
|
||||
results = []
|
||||
compatible_keys = self.get_compatible_output_keys(dict_keys, caption_index, base_keys, None)
|
||||
for key in compatible_keys:
|
||||
results.append(f.get_tensor(key) if key is not None else None)
|
||||
|
||||
return results
|
||||
|
||||
def is_disk_cached_outputs_expected(
|
||||
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
|
||||
) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
|
||||
raise NotImplementedError
|
||||
def get_key_suffix(self, prompt_id: int, dtype: Optional[Union[str, torch.dtype]] = None) -> str:
|
||||
"""
|
||||
masked: may be False even if self.masked is True. It is False for some outputs.
|
||||
"""
|
||||
key_suffix = f"_{prompt_id}"
|
||||
if dtype is not None and dtype.is_floating_point: # float tensor only
|
||||
key_suffix += "_" + utils.dtype_to_normalized_str(dtype)
|
||||
return key_suffix
|
||||
|
||||
def get_compatible_output_keys(
|
||||
self, dict_keys: set[str], caption_index: int, base_keys: list[str], dtype: Optional[Union[str, torch.dtype]]
|
||||
) -> list[Optional[str], Optional[str]]:
|
||||
"""
|
||||
returns the list of keys with the specified dtype or higher precision dtype. If the specified dtype is None, any dtype is acceptable.
|
||||
"""
|
||||
key_suffix = self.get_key_suffix(caption_index, None)
|
||||
keys_without_dtype = [k + key_suffix for k in base_keys]
|
||||
return get_compatible_dtype_keys(dict_keys, keys_without_dtype, dtype)
|
||||
|
||||
def _default_is_disk_cached_outputs_expected(
|
||||
self,
|
||||
cache_path: str,
|
||||
captions: list[str],
|
||||
base_keys: list[tuple[str, bool]],
|
||||
preferred_dtype: Optional[Union[str, torch.dtype]],
|
||||
):
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(cache_path):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
try:
|
||||
with utils.MemoryEfficientSafeOpen(cache_path) as f:
|
||||
keys = f.keys()
|
||||
metadata = f.metadata()
|
||||
|
||||
# check captions in metadata
|
||||
for i, caption in enumerate(captions):
|
||||
if metadata.get(f"caption{i+1}") != caption:
|
||||
return False
|
||||
|
||||
compatible_keys = self.get_compatible_output_keys(keys, i, base_keys, preferred_dtype)
|
||||
if any(key is None for key in compatible_keys):
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {cache_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
def cache_batch_outputs(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: list[Any],
|
||||
text_encoding_strategy: TextEncodingStrategy,
|
||||
batch: list[tuple[utils.ImageInfo, int, str]],
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_outputs_to_disk(self, cache_path: str, caption_index: int, caption: str, keys: list[str], outputs: list[torch.Tensor]):
|
||||
tensor_dict = {}
|
||||
|
||||
overwrite = False
|
||||
if os.path.exists(cache_path):
|
||||
# load existing safetensors and update it
|
||||
overwrite = True
|
||||
|
||||
with utils.MemoryEfficientSafeOpen(cache_path) as f:
|
||||
metadata = f.metadata()
|
||||
keys = f.keys()
|
||||
for key in keys:
|
||||
tensor_dict[key] = f.get_tensor(key)
|
||||
assert metadata["architecture"] == self.architecture
|
||||
|
||||
file_version = metadata.get("format_version", "0.0.0")
|
||||
major, minor, patch = map(int, file_version.split("."))
|
||||
if major > 1 or (major == 1 and minor > 0):
|
||||
self.save_version_warning_printed = True
|
||||
logger.warning(
|
||||
f"Existing latents cache file has a higher version {file_version} for {cache_path}. This may cause issues."
|
||||
)
|
||||
else:
|
||||
metadata = {}
|
||||
metadata["architecture"] = self.architecture
|
||||
metadata["format_version"] = "1.0.0"
|
||||
|
||||
metadata[f"caption{caption_index+1}"] = caption
|
||||
|
||||
for key, output in zip(keys, outputs):
|
||||
dtype = output.dtype # long or one of float
|
||||
key_suffix = self.get_key_suffix(caption_index, dtype)
|
||||
tensor_dict[key + key_suffix] = output
|
||||
|
||||
# remove lower precision latents if higher precision latents are already cached
|
||||
if overwrite:
|
||||
suffix_without_dtype = self.get_key_suffix(caption_index, None)
|
||||
remove_lower_precision_values(tensor_dict, [key + suffix_without_dtype])
|
||||
|
||||
save_file(tensor_dict, cache_path, metadata=metadata)
|
||||
|
||||
|
||||
class LatentsCachingStrategy:
|
||||
# TODO commonize utillity functions to this class, such as npz handling etc.
|
||||
|
||||
_strategy = None # strategy instance: actual strategy class
|
||||
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
def __init__(
|
||||
self, architecture: str, latents_stride: int, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
|
||||
) -> None:
|
||||
self._architecture = architecture
|
||||
self._latents_stride = latents_stride
|
||||
self._cache_to_disk = cache_to_disk
|
||||
self._batch_size = batch_size
|
||||
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
|
||||
|
||||
self.load_version_warning_printed = False
|
||||
self.save_version_warning_printed = False
|
||||
|
||||
@classmethod
|
||||
def set_strategy(cls, strategy):
|
||||
if cls._strategy is not None:
|
||||
@@ -397,6 +624,14 @@ class LatentsCachingStrategy:
|
||||
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
|
||||
return cls._strategy
|
||||
|
||||
@property
|
||||
def architecture(self):
|
||||
return self._architecture
|
||||
|
||||
@property
|
||||
def latents_stride(self):
|
||||
return self._latents_stride
|
||||
|
||||
@property
|
||||
def cache_to_disk(self):
|
||||
return self._cache_to_disk
|
||||
@@ -407,54 +642,126 @@ class LatentsCachingStrategy:
|
||||
|
||||
@property
|
||||
def cache_suffix(self):
|
||||
raise NotImplementedError
|
||||
return f"_{self.architecture.lower()}.safetensors"
|
||||
|
||||
def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]:
|
||||
w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x")
|
||||
def get_image_size_from_disk_cache_path(self, absolute_path: str, cache_path: str) -> Tuple[Optional[int], Optional[int]]:
|
||||
w, h = os.path.splitext(cache_path)[0].rsplit("_", 2)[-2].split("x")
|
||||
return int(w), int(h)
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
raise NotImplementedError
|
||||
def get_latents_cache_path_from_info(self, info: utils.ImageInfo) -> str:
|
||||
return self.get_latents_cache_path(info.absolute_path, info.image_size, info.latents_cache_dir)
|
||||
|
||||
def get_latents_cache_path(
|
||||
self, absolute_path_or_archive_img_path: str, image_size: Tuple[int, int], cache_dir: Optional[str] = None
|
||||
) -> str:
|
||||
if cache_dir is not None:
|
||||
if dataset_metadata_utils.is_archive_path(absolute_path_or_archive_img_path):
|
||||
inner_path = dataset_metadata_utils.get_inner_path(absolute_path_or_archive_img_path)
|
||||
archive_digest = dataset_metadata_utils.get_archive_digest(absolute_path_or_archive_img_path)
|
||||
cache_file_base = os.path.join(cache_dir, f"{archive_digest}_{inner_path}")
|
||||
else:
|
||||
cache_file_base = os.path.join(cache_dir, os.path.basename(absolute_path_or_archive_img_path))
|
||||
else:
|
||||
cache_file_base = absolute_path_or_archive_img_path
|
||||
|
||||
return os.path.splitext(cache_file_base)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
|
||||
|
||||
def is_disk_cached_latents_expected(
|
||||
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
|
||||
self,
|
||||
bucket_reso: Tuple[int, int],
|
||||
cache_path: str,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
preferred_dtype: Optional[Union[str, torch.dtype]],
|
||||
) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_key_suffix(
|
||||
self,
|
||||
bucket_reso: Optional[Tuple[int, int]] = None,
|
||||
latents_size: Optional[Tuple[int, int]] = None,
|
||||
dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
if dtype is None, it returns "_32x64" for example.
|
||||
"""
|
||||
if latents_size is not None:
|
||||
expected_latents_size = latents_size # H, W
|
||||
else:
|
||||
# bucket_reso is (W, H)
|
||||
expected_latents_size = (bucket_reso[1] // self.latents_stride, bucket_reso[0] // self.latents_stride) # H, W
|
||||
|
||||
if dtype is None:
|
||||
dtype_suffix = ""
|
||||
else:
|
||||
dtype_suffix = "_" + utils.dtype_to_normalized_str(dtype)
|
||||
|
||||
# e.g. "_32x64_float16", HxW, dtype
|
||||
key_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}{dtype_suffix}"
|
||||
|
||||
return key_suffix
|
||||
|
||||
def get_compatible_latents_keys(
|
||||
self,
|
||||
keys: set[str],
|
||||
dtype: Optional[Union[str, torch.dtype]],
|
||||
flip_aug: bool,
|
||||
bucket_reso: Optional[Tuple[int, int]] = None,
|
||||
latents_size: Optional[Tuple[int, int]] = None,
|
||||
) -> list[Optional[str], Optional[str]]:
|
||||
"""
|
||||
bucket_reso is (W, H), latents_size is (H, W)
|
||||
"""
|
||||
|
||||
key_suffix = self.get_key_suffix(bucket_reso, latents_size, None)
|
||||
keys_without_dtype = ["latents" + key_suffix]
|
||||
if flip_aug:
|
||||
keys_without_dtype.append("latents_flipped" + key_suffix)
|
||||
|
||||
compatible_keys = get_compatible_dtype_keys(keys, keys_without_dtype, dtype)
|
||||
return compatible_keys if flip_aug else compatible_keys[0] + [None]
|
||||
|
||||
def _default_is_disk_cached_latents_expected(
|
||||
self,
|
||||
latents_stride: int,
|
||||
bucket_reso: Tuple[int, int],
|
||||
npz_path: str,
|
||||
latents_cache_path: str,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
multi_resolution: bool = False,
|
||||
preferred_dtype: Optional[Union[str, torch.dtype]],
|
||||
):
|
||||
# multi_resolution is always enabled for any strategy
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(npz_path):
|
||||
if not os.path.exists(latents_cache_path):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
|
||||
|
||||
# e.g. "_32x64", HxW
|
||||
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
|
||||
key_suffix_without_dtype = self.get_key_suffix(bucket_reso=bucket_reso, dtype=None)
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "latents" + key_reso_suffix not in npz:
|
||||
# safe_open locks the file, so we cannot use it for checking keys
|
||||
# with safe_open(latents_cache_path, framework="pt") as f:
|
||||
# keys = f.keys()
|
||||
with utils.MemoryEfficientSafeOpen(latents_cache_path) as f:
|
||||
keys = f.keys()
|
||||
|
||||
if alpha_mask and "alpha_mask" + key_suffix_without_dtype not in keys:
|
||||
# print(f"alpha_mask not found: {latents_cache_path}")
|
||||
return False
|
||||
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
|
||||
return False
|
||||
if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
|
||||
|
||||
# preferred_dtype is None if any dtype is acceptable
|
||||
latents_key, flipped_latents_key = self.get_compatible_latents_keys(
|
||||
keys, preferred_dtype, flip_aug, bucket_reso=bucket_reso
|
||||
)
|
||||
if latents_key is None or (flip_aug and flipped_latents_key is None):
|
||||
# print(f"Precise dtype not found: {latents_cache_path}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
logger.error(f"Error loading file: {latents_cache_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
@@ -465,11 +772,10 @@ class LatentsCachingStrategy:
|
||||
encode_by_vae,
|
||||
vae_device,
|
||||
vae_dtype,
|
||||
image_infos: List,
|
||||
image_infos: List[utils.ImageInfo],
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
random_crop: bool,
|
||||
multi_resolution: bool = False,
|
||||
):
|
||||
"""
|
||||
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
|
||||
@@ -499,13 +805,8 @@ class LatentsCachingStrategy:
|
||||
original_size = original_sizes[i]
|
||||
crop_ltrb = crop_ltrbs[i]
|
||||
|
||||
latents_size = latents.shape[1:3] # H, W
|
||||
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW
|
||||
|
||||
if self.cache_to_disk:
|
||||
self.save_latents_to_disk(
|
||||
info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix
|
||||
)
|
||||
self.save_latents_to_disk(info.latents_cache_path, latents, original_size, crop_ltrb, flipped_latent, alpha_mask)
|
||||
else:
|
||||
info.latents_original_size = original_size
|
||||
info.latents_crop_ltrb = crop_ltrb
|
||||
@@ -515,56 +816,96 @@ class LatentsCachingStrategy:
|
||||
info.alpha_mask = alpha_mask
|
||||
|
||||
def load_latents_from_disk(
|
||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
"""
|
||||
for SD/SDXL
|
||||
"""
|
||||
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
|
||||
self, cache_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _default_load_latents_from_disk(
|
||||
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
if latents_stride is None:
|
||||
key_reso_suffix = ""
|
||||
else:
|
||||
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
|
||||
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW
|
||||
self, cache_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
with safe_open(cache_path, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
version = metadata.get("format_version", "0.0.0")
|
||||
major, minor, patch = map(int, version.split("."))
|
||||
if major > 1: # or (major == 1 and minor > 0):
|
||||
if not self.load_version_warning_printed:
|
||||
self.load_version_warning_printed = True
|
||||
logger.warning(
|
||||
f"Existing latents cache file has a higher version {version} for {cache_path}. This may cause issues."
|
||||
)
|
||||
|
||||
npz = np.load(npz_path)
|
||||
if "latents" + key_reso_suffix not in npz:
|
||||
raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
|
||||
keys = f.keys()
|
||||
|
||||
latents_key, flipped_latents_key = self.get_compatible_latents_keys(keys, None, flip_aug=True, bucket_reso=bucket_reso)
|
||||
|
||||
key_suffix_without_dtype = self.get_key_suffix(bucket_reso=bucket_reso, dtype=None)
|
||||
alpha_mask_key = "alpha_mask" + key_suffix_without_dtype
|
||||
|
||||
latents = f.get_tensor(latents_key)
|
||||
flipped_latents = f.get_tensor(flipped_latents_key) if flipped_latents_key is not None else None
|
||||
alpha_mask = f.get_tensor(alpha_mask_key) if alpha_mask_key in keys else None
|
||||
|
||||
original_size = [int(metadata["width"]), int(metadata["height"])]
|
||||
crop_ltrb = metadata[f"crop_ltrb" + key_suffix_without_dtype]
|
||||
crop_ltrb = list(map(int, crop_ltrb.split(",")))
|
||||
|
||||
latents = npz["latents" + key_reso_suffix]
|
||||
original_size = npz["original_size" + key_reso_suffix].tolist()
|
||||
crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
|
||||
flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
|
||||
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
|
||||
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
|
||||
|
||||
def save_latents_to_disk(
|
||||
self,
|
||||
npz_path,
|
||||
latents_tensor,
|
||||
original_size,
|
||||
crop_ltrb,
|
||||
flipped_latents_tensor=None,
|
||||
alpha_mask=None,
|
||||
key_reso_suffix="",
|
||||
cache_path: str,
|
||||
latents_tensor: torch.Tensor,
|
||||
original_size: Tuple[int, int],
|
||||
crop_ltrb: List[int],
|
||||
flipped_latents_tensor: Optional[torch.Tensor] = None,
|
||||
alpha_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
kwargs = {}
|
||||
dtype = latents_tensor.dtype
|
||||
latents_size = latents_tensor.shape[1:3] # H, W
|
||||
tensor_dict = {}
|
||||
|
||||
if os.path.exists(npz_path):
|
||||
# load existing npz and update it
|
||||
npz = np.load(npz_path)
|
||||
for key in npz.files:
|
||||
kwargs[key] = npz[key]
|
||||
overwrite = False
|
||||
if os.path.exists(cache_path):
|
||||
# load existing safetensors and update it
|
||||
overwrite = True
|
||||
|
||||
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
|
||||
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
|
||||
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
|
||||
# we cannot use safe_open here because it locks the file
|
||||
# with safe_open(cache_path, framework="pt") as f:
|
||||
with utils.MemoryEfficientSafeOpen(cache_path) as f:
|
||||
metadata = f.metadata()
|
||||
keys = f.keys()
|
||||
for key in keys:
|
||||
tensor_dict[key] = f.get_tensor(key)
|
||||
assert metadata["architecture"] == self.architecture
|
||||
|
||||
file_version = metadata.get("format_version", "0.0.0")
|
||||
major, minor, patch = map(int, file_version.split("."))
|
||||
if major > 1 or (major == 1 and minor > 0):
|
||||
self.save_version_warning_printed = True
|
||||
logger.warning(
|
||||
f"Existing latents cache file has a higher version {file_version} for {cache_path}. This may cause issues."
|
||||
)
|
||||
else:
|
||||
metadata = {}
|
||||
metadata["architecture"] = self.architecture
|
||||
metadata["width"] = f"{original_size[0]}"
|
||||
metadata["height"] = f"{original_size[1]}"
|
||||
metadata["format_version"] = "1.0.0"
|
||||
|
||||
metadata[f"crop_ltrb_{latents_size[0]}x{latents_size[1]}"] = ",".join(map(str, crop_ltrb))
|
||||
|
||||
key_suffix = self.get_key_suffix(latents_size=latents_size, dtype=dtype)
|
||||
if latents_tensor is not None:
|
||||
tensor_dict["latents" + key_suffix] = latents_tensor
|
||||
if flipped_latents_tensor is not None:
|
||||
kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy()
|
||||
tensor_dict["latents_flipped" + key_suffix] = flipped_latents_tensor
|
||||
if alpha_mask is not None:
|
||||
kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
|
||||
np.savez(npz_path, **kwargs)
|
||||
key_suffix_without_dtype = self.get_key_suffix(latents_size=latents_size, dtype=None)
|
||||
tensor_dict["alpha_mask" + key_suffix_without_dtype] = alpha_mask
|
||||
|
||||
# remove lower precision latents if higher precision latents are already cached
|
||||
if overwrite:
|
||||
suffix_without_dtype = self.get_key_suffix(latents_size=latents_size, dtype=None)
|
||||
remove_lower_precision_values(tensor_dict, ["latents" + suffix_without_dtype, "latents_flipped" + suffix_without_dtype])
|
||||
|
||||
save_file(tensor_dict, cache_path, metadata=metadata)
|
||||
|
||||
@@ -5,9 +5,6 @@ import torch
|
||||
import numpy as np
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
|
||||
from library import flux_utils, train_util
|
||||
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -15,6 +12,8 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import flux_utils, train_util, utils
|
||||
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
||||
|
||||
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
|
||||
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
|
||||
@@ -86,64 +85,56 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
|
||||
|
||||
|
||||
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"
|
||||
KEYS = ["l_pooled", "t5_out", "txt_ids"]
|
||||
KEYS_MASKED = ["t5_attn_mask", "apply_t5_attn_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_to_disk: bool,
|
||||
batch_size: int,
|
||||
skip_disk_cache_validity_check: bool,
|
||||
max_token_length: int,
|
||||
masked: bool,
|
||||
is_partial: bool = False,
|
||||
apply_t5_attn_mask: bool = False,
|
||||
) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
||||
self.apply_t5_attn_mask = apply_t5_attn_mask
|
||||
super().__init__(
|
||||
FluxLatentsCachingStrategy.ARCHITECTURE,
|
||||
cache_to_disk,
|
||||
batch_size,
|
||||
skip_disk_cache_validity_check,
|
||||
max_token_length,
|
||||
masked,
|
||||
is_partial,
|
||||
)
|
||||
|
||||
self.warn_fp8_weights = False
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
def is_disk_cached_outputs_expected(
|
||||
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
|
||||
):
|
||||
keys = FluxTextEncoderOutputsCachingStrategy.KEYS
|
||||
if self.masked:
|
||||
keys += FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
|
||||
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, keys, preferred_dtype)
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str):
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(npz_path):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "l_pooled" not in npz:
|
||||
return False
|
||||
if "t5_out" not in npz:
|
||||
return False
|
||||
if "txt_ids" not in npz:
|
||||
return False
|
||||
if "t5_attn_mask" not in npz:
|
||||
return False
|
||||
if "apply_t5_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
|
||||
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
data = np.load(npz_path)
|
||||
l_pooled = data["l_pooled"]
|
||||
t5_out = data["t5_out"]
|
||||
txt_ids = data["txt_ids"]
|
||||
t5_attn_mask = data["t5_attn_mask"]
|
||||
# apply_t5_attn_mask should be same as self.apply_t5_attn_mask
|
||||
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
|
||||
l_pooled, t5_out, txt_ids = self.load_from_disk_for_keys(
|
||||
cache_path, caption_index, FluxTextEncoderOutputsCachingStrategy.KEYS
|
||||
)
|
||||
if self.masked:
|
||||
t5_attn_mask = self.load_from_disk_for_keys(
|
||||
cache_path, caption_index, FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
|
||||
)[0]
|
||||
else:
|
||||
t5_attn_mask = None
|
||||
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
|
||||
|
||||
def cache_batch_outputs(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
text_encoding_strategy: TextEncodingStrategy,
|
||||
batch: list[tuple[utils.ImageInfo, int, str]],
|
||||
):
|
||||
if not self.warn_fp8_weights:
|
||||
if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
|
||||
@@ -154,80 +145,67 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
self.warn_fp8_weights = True
|
||||
|
||||
flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
|
||||
captions = [info.caption for info in infos]
|
||||
captions = [caption for _, _, caption in batch]
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
# attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
|
||||
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
|
||||
|
||||
if l_pooled.dtype == torch.bfloat16:
|
||||
l_pooled = l_pooled.float()
|
||||
if t5_out.dtype == torch.bfloat16:
|
||||
t5_out = t5_out.float()
|
||||
if txt_ids.dtype == torch.bfloat16:
|
||||
txt_ids = txt_ids.float()
|
||||
l_pooled = l_pooled.cpu()
|
||||
t5_out = t5_out.cpu()
|
||||
txt_ids = txt_ids.cpu()
|
||||
t5_attn_mask = tokens_and_masks[2].cpu()
|
||||
|
||||
l_pooled = l_pooled.cpu().numpy()
|
||||
t5_out = t5_out.cpu().numpy()
|
||||
txt_ids = txt_ids.cpu().numpy()
|
||||
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
|
||||
keys = FluxTextEncoderOutputsCachingStrategy.KEYS
|
||||
if self.masked:
|
||||
keys += FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
for i, (info, caption_index, caption) in enumerate(batch):
|
||||
l_pooled_i = l_pooled[i]
|
||||
t5_out_i = t5_out[i]
|
||||
txt_ids_i = txt_ids[i]
|
||||
t5_attn_mask_i = t5_attn_mask[i]
|
||||
apply_t5_attn_mask_i = self.apply_t5_attn_mask
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
l_pooled=l_pooled_i,
|
||||
t5_out=t5_out_i,
|
||||
txt_ids=txt_ids_i,
|
||||
t5_attn_mask=t5_attn_mask_i,
|
||||
apply_t5_attn_mask=apply_t5_attn_mask_i,
|
||||
)
|
||||
outputs = [l_pooled_i, t5_out_i, txt_ids_i]
|
||||
if self.masked:
|
||||
outputs += [t5_attn_mask_i]
|
||||
self.save_outputs_to_disk(info.text_encoder_outputs_cache_path, caption_index, caption, keys, outputs)
|
||||
else:
|
||||
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
|
||||
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
|
||||
while len(info.text_encoder_outputs) <= caption_index:
|
||||
info.text_encoder_outputs.append(None)
|
||||
info.text_encoder_outputs[caption_index] = [l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i]
|
||||
|
||||
|
||||
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"
|
||||
ARCHITECTURE = "flux"
|
||||
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
super().__init__(FluxLatentsCachingStrategy.ARCHITECTURE, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
|
||||
@property
|
||||
def cache_suffix(self) -> str:
|
||||
return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
return (
|
||||
os.path.splitext(absolute_path)[0]
|
||||
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
|
||||
+ FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
|
||||
)
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
|
||||
def is_disk_cached_latents_expected(
|
||||
self,
|
||||
bucket_reso: Tuple[int, int],
|
||||
cache_path: str,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
preferred_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)
|
||||
|
||||
def load_latents_from_disk(
|
||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
|
||||
self, cache_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
return self._default_load_latents_from_disk(cache_path, bucket_reso)
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
|
||||
vae_device = vae.device
|
||||
vae_dtype = vae.dtype
|
||||
|
||||
self._default_cache_batch_latents(
|
||||
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
|
||||
)
|
||||
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
|
||||
|
||||
if not train_util.HIGH_VRAM:
|
||||
train_util.clean_memory_on_device(vae.device)
|
||||
|
||||
@@ -4,8 +4,6 @@ from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTokenizer
|
||||
from library import train_util
|
||||
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -13,6 +11,8 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import train_util, utils
|
||||
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
|
||||
|
||||
TOKENIZER_ID = "openai/clip-vit-large-patch14"
|
||||
V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
|
||||
@@ -134,33 +134,30 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
|
||||
# and we keep the old npz for the backward compatibility.
|
||||
|
||||
SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
|
||||
SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
|
||||
SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"
|
||||
ARCHITECTURE_SD = "sd"
|
||||
ARCHITECTURE_SDXL = "sdxl"
|
||||
|
||||
def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
arch = SdSdxlLatentsCachingStrategy.ARCHITECTURE_SD if sd else SdSdxlLatentsCachingStrategy.ARCHITECTURE_SDXL
|
||||
super().__init__(arch, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
self.sd = sd
|
||||
self.suffix = (
|
||||
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
|
||||
)
|
||||
|
||||
@property
|
||||
def cache_suffix(self) -> str:
|
||||
return self.suffix
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
# support old .npz
|
||||
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
|
||||
if os.path.exists(old_npz_file):
|
||||
return old_npz_file
|
||||
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
|
||||
def is_disk_cached_latents_expected(
|
||||
self,
|
||||
bucket_reso: Tuple[int, int],
|
||||
cache_path: str,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
preferred_dtype: Optional[torch.dtype] = None,
|
||||
) -> bool:
|
||||
return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
|
||||
def load_latents_from_disk(
|
||||
self, cache_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
return self._default_load_latents_from_disk(cache_path, bucket_reso)
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample()
|
||||
vae_device = vae.device
|
||||
vae_dtype = vae.dtype
|
||||
|
||||
@@ -6,10 +6,6 @@ import torch
|
||||
import numpy as np
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel
|
||||
|
||||
from library import sd3_utils, train_util
|
||||
from library import sd3_models
|
||||
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -17,6 +13,9 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import train_util, utils
|
||||
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
||||
|
||||
|
||||
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
|
||||
CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
@@ -254,7 +253,8 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
|
||||
|
||||
|
||||
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
|
||||
KEYS = ["lg_out", "t5_out", "lg_pooled"]
|
||||
KEYS_MASKED = ["clip_l_attn_mask", "clip_g_attn_mask", "t5_attn_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -262,70 +262,51 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
batch_size: int,
|
||||
skip_disk_cache_validity_check: bool,
|
||||
is_partial: bool = False,
|
||||
apply_lg_attn_mask: bool = False,
|
||||
apply_t5_attn_mask: bool = False,
|
||||
max_token_length: int = 256,
|
||||
masked: bool = False,
|
||||
) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
||||
self.apply_lg_attn_mask = apply_lg_attn_mask
|
||||
self.apply_t5_attn_mask = apply_t5_attn_mask
|
||||
"""
|
||||
apply_lg_attn_mask and apply_t5_attn_mask must be same
|
||||
"""
|
||||
super().__init__(
|
||||
Sd3LatentsCachingStrategy.ARCHITECTURE_SD3,
|
||||
cache_to_disk,
|
||||
batch_size,
|
||||
skip_disk_cache_validity_check,
|
||||
max_token_length,
|
||||
masked=masked,
|
||||
is_partial=is_partial,
|
||||
)
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
def is_disk_cached_outputs_expected(
|
||||
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
|
||||
) -> bool:
|
||||
keys = Sd3TextEncoderOutputsCachingStrategy.KEYS
|
||||
if self.masked:
|
||||
keys += Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
|
||||
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, keys, preferred_dtype)
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str):
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(npz_path):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "lg_out" not in npz:
|
||||
return False
|
||||
if "lg_pooled" not in npz:
|
||||
return False
|
||||
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
|
||||
return False
|
||||
if "apply_lg_attn_mask" not in npz:
|
||||
return False
|
||||
if "t5_out" not in npz:
|
||||
return False
|
||||
if "t5_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
|
||||
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
|
||||
return False
|
||||
if "apply_t5_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
|
||||
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
data = np.load(npz_path)
|
||||
lg_out = data["lg_out"]
|
||||
lg_pooled = data["lg_pooled"]
|
||||
t5_out = data["t5_out"]
|
||||
|
||||
l_attn_mask = data["clip_l_attn_mask"]
|
||||
g_attn_mask = data["clip_g_attn_mask"]
|
||||
t5_attn_mask = data["t5_attn_mask"]
|
||||
|
||||
# apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask
|
||||
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
|
||||
lg_out, lg_pooled, t5_out = self.load_from_disk_for_keys(
|
||||
cache_path, caption_index, Sd3TextEncoderOutputsCachingStrategy.KEYS
|
||||
)
|
||||
if self.masked:
|
||||
l_attn_mask, g_attn_mask, t5_attn_mask = self.load_from_disk_for_keys(
|
||||
cache_path, caption_index, Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
|
||||
)
|
||||
else:
|
||||
l_attn_mask = g_attn_mask = t5_attn_mask = None
|
||||
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
|
||||
|
||||
def cache_batch_outputs(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
text_encoding_strategy: TextEncodingStrategy,
|
||||
batch: list[tuple[utils.ImageInfo, int, str]],
|
||||
):
|
||||
sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy
|
||||
captions = [info.caption for info in infos]
|
||||
captions = [caption for _, _, caption in batch]
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
@@ -334,87 +315,76 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
tokenize_strategy,
|
||||
models,
|
||||
tokens_and_masks,
|
||||
apply_lg_attn_mask=self.apply_lg_attn_mask,
|
||||
apply_t5_attn_mask=self.apply_t5_attn_mask,
|
||||
apply_lg_attn_mask=self.masked,
|
||||
apply_t5_attn_mask=self.masked,
|
||||
enable_dropout=False,
|
||||
)
|
||||
|
||||
if lg_out.dtype == torch.bfloat16:
|
||||
lg_out = lg_out.float()
|
||||
if lg_pooled.dtype == torch.bfloat16:
|
||||
lg_pooled = lg_pooled.float()
|
||||
if t5_out.dtype == torch.bfloat16:
|
||||
t5_out = t5_out.float()
|
||||
lg_out = lg_out.cpu()
|
||||
lg_pooled = lg_pooled.cpu()
|
||||
t5_out = t5_out.cpu()
|
||||
|
||||
lg_out = lg_out.cpu().numpy()
|
||||
lg_pooled = lg_pooled.cpu().numpy()
|
||||
t5_out = t5_out.cpu().numpy()
|
||||
l_attn_mask = tokens_and_masks[3].cpu()
|
||||
g_attn_mask = tokens_and_masks[4].cpu()
|
||||
t5_attn_mask = tokens_and_masks[5].cpu()
|
||||
|
||||
l_attn_mask = tokens_and_masks[3].cpu().numpy()
|
||||
g_attn_mask = tokens_and_masks[4].cpu().numpy()
|
||||
t5_attn_mask = tokens_and_masks[5].cpu().numpy()
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
keys = Sd3TextEncoderOutputsCachingStrategy.KEYS
|
||||
if self.masked:
|
||||
keys += Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
|
||||
for i, (info, caption_index, caption) in enumerate(batch):
|
||||
lg_out_i = lg_out[i]
|
||||
t5_out_i = t5_out[i]
|
||||
lg_pooled_i = lg_pooled[i]
|
||||
l_attn_mask_i = l_attn_mask[i]
|
||||
g_attn_mask_i = g_attn_mask[i]
|
||||
t5_attn_mask_i = t5_attn_mask[i]
|
||||
apply_lg_attn_mask = self.apply_lg_attn_mask
|
||||
apply_t5_attn_mask = self.apply_t5_attn_mask
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
lg_out=lg_out_i,
|
||||
lg_pooled=lg_pooled_i,
|
||||
t5_out=t5_out_i,
|
||||
clip_l_attn_mask=l_attn_mask_i,
|
||||
clip_g_attn_mask=g_attn_mask_i,
|
||||
t5_attn_mask=t5_attn_mask_i,
|
||||
apply_lg_attn_mask=apply_lg_attn_mask,
|
||||
apply_t5_attn_mask=apply_t5_attn_mask,
|
||||
)
|
||||
outputs = [lg_out_i, t5_out_i, lg_pooled_i]
|
||||
if self.masked:
|
||||
outputs += [l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i]
|
||||
self.save_outputs_to_disk(info.text_encoder_outputs_cache_path, caption_index, caption, keys, outputs)
|
||||
else:
|
||||
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
|
||||
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
|
||||
while len(info.text_encoder_outputs) <= caption_index:
|
||||
info.text_encoder_outputs.append(None)
|
||||
info.text_encoder_outputs[caption_index] = [
|
||||
lg_out_i,
|
||||
t5_out_i,
|
||||
lg_pooled_i,
|
||||
l_attn_mask_i,
|
||||
g_attn_mask_i,
|
||||
t5_attn_mask_i,
|
||||
]
|
||||
|
||||
|
||||
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
|
||||
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
|
||||
ARCHITECTURE_SD3 = "sd3"
|
||||
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
super().__init__(Sd3LatentsCachingStrategy.ARCHITECTURE_SD3, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
|
||||
@property
|
||||
def cache_suffix(self) -> str:
|
||||
return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
return (
|
||||
os.path.splitext(absolute_path)[0]
|
||||
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
|
||||
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
|
||||
)
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
|
||||
def is_disk_cached_latents_expected(
|
||||
self,
|
||||
bucket_reso: Tuple[int, int],
|
||||
cache_path: str,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
preferred_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)
|
||||
|
||||
def load_latents_from_disk(
|
||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
|
||||
self, cache_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
return self._default_load_latents_from_disk(cache_path, bucket_reso)
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
|
||||
vae_device = vae.device
|
||||
vae_dtype = vae.dtype
|
||||
|
||||
self._default_cache_batch_latents(
|
||||
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
|
||||
)
|
||||
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
|
||||
|
||||
if not train_util.HIGH_VRAM:
|
||||
train_util.clean_memory_on_device(vae.device)
|
||||
|
||||
@@ -4,8 +4,6 @@ from typing import Any, List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
||||
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
|
||||
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
@@ -14,6 +12,8 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
|
||||
from library import utils
|
||||
|
||||
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
|
||||
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
@@ -21,6 +21,9 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
|
||||
class SdxlTokenizeStrategy(TokenizeStrategy):
|
||||
def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
|
||||
"""
|
||||
max_length: maximum length of the input text, **excluding** the special tokens. None or 150 or 225
|
||||
"""
|
||||
self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
|
||||
self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
|
||||
self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2
|
||||
@@ -220,51 +223,51 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
|
||||
|
||||
|
||||
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
|
||||
ARCHITECTURE_SDXL = "sdxl"
|
||||
KEYS = ["hidden_state1", "hidden_state2", "pool2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_to_disk: bool,
|
||||
batch_size: int,
|
||||
batch_size: Optional[int],
|
||||
skip_disk_cache_validity_check: bool,
|
||||
max_token_length: Optional[int] = None,
|
||||
is_partial: bool = False,
|
||||
is_weighted: bool = False,
|
||||
) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted)
|
||||
"""
|
||||
max_token_length: maximum length of the input text, **excluding** the special tokens. None or 150 or 225
|
||||
"""
|
||||
max_token_length = max_token_length or 75
|
||||
super().__init__(
|
||||
SdxlTextEncoderOutputsCachingStrategy.ARCHITECTURE_SDXL,
|
||||
cache_to_disk,
|
||||
batch_size,
|
||||
skip_disk_cache_validity_check,
|
||||
is_partial,
|
||||
is_weighted,
|
||||
max_token_length=max_token_length,
|
||||
)
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
def is_disk_cached_outputs_expected(
|
||||
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
|
||||
) -> bool:
|
||||
# SDXL does not support attn mask
|
||||
base_keys = SdxlTextEncoderOutputsCachingStrategy.KEYS
|
||||
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, base_keys, preferred_dtype)
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str):
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(npz_path):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
data = np.load(npz_path)
|
||||
hidden_state1 = data["hidden_state1"]
|
||||
hidden_state2 = data["hidden_state2"]
|
||||
pool2 = data["pool2"]
|
||||
return [hidden_state1, hidden_state2, pool2]
|
||||
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
|
||||
return self.load_from_disk_for_keys(cache_path, caption_index, SdxlTextEncoderOutputsCachingStrategy.KEYS)
|
||||
|
||||
def cache_batch_outputs(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
text_encoding_strategy: TextEncodingStrategy,
|
||||
batch: list[tuple[utils.ImageInfo, int, str]],
|
||||
):
|
||||
sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy
|
||||
captions = [info.caption for info in infos]
|
||||
captions = [caption for _, _, caption in batch]
|
||||
|
||||
if self.is_weighted:
|
||||
tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions)
|
||||
@@ -279,28 +282,24 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
tokenize_strategy, models, [tokens1, tokens2]
|
||||
)
|
||||
|
||||
if hidden_state1.dtype == torch.bfloat16:
|
||||
hidden_state1 = hidden_state1.float()
|
||||
if hidden_state2.dtype == torch.bfloat16:
|
||||
hidden_state2 = hidden_state2.float()
|
||||
if pool2.dtype == torch.bfloat16:
|
||||
pool2 = pool2.float()
|
||||
hidden_state1 = hidden_state1.cpu()
|
||||
hidden_state2 = hidden_state2.cpu()
|
||||
pool2 = pool2.cpu()
|
||||
|
||||
hidden_state1 = hidden_state1.cpu().numpy()
|
||||
hidden_state2 = hidden_state2.cpu().numpy()
|
||||
pool2 = pool2.cpu().numpy()
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
for i, (info, caption_index, caption) in enumerate(batch):
|
||||
hidden_state1_i = hidden_state1[i]
|
||||
hidden_state2_i = hidden_state2[i]
|
||||
pool2_i = pool2[i]
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
hidden_state1=hidden_state1_i,
|
||||
hidden_state2=hidden_state2_i,
|
||||
pool2=pool2_i,
|
||||
self.save_outputs_to_disk(
|
||||
info.text_encoder_outputs_cache_path,
|
||||
caption_index,
|
||||
caption,
|
||||
SdxlTextEncoderOutputsCachingStrategy.KEYS,
|
||||
[hidden_state1_i, hidden_state2_i, pool2_i],
|
||||
)
|
||||
else:
|
||||
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]
|
||||
while len(info.text_encoder_outputs) <= caption_index:
|
||||
info.text_encoder_outputs.append(None)
|
||||
info.text_encoder_outputs[caption_index] = [hidden_state1_i, hidden_state2_i, pool2_i]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -21,6 +21,62 @@ def fire_in_thread(f, *args, **kwargs):
|
||||
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
||||
|
||||
|
||||
class ImageInfo:
|
||||
def __init__(self, image_key: str, num_repeats: int, is_reg: bool, absolute_path: str) -> None:
|
||||
self.image_key: str = image_key
|
||||
self.num_repeats: int = num_repeats
|
||||
self.captions: Optional[list[str]] = None
|
||||
self.caption_weights: Optional[list[float]] = None # weights for each caption in sampling
|
||||
self.list_of_tags: Optional[list[str]] = None
|
||||
self.tags_weights: Optional[list[float]] = None
|
||||
self.is_reg: bool = is_reg
|
||||
self.absolute_path: str = absolute_path
|
||||
self.latents_cache_dir: Optional[str] = None
|
||||
self.image_size: Tuple[int, int] = None
|
||||
self.resized_size: Tuple[int, int] = None
|
||||
self.bucket_reso: Tuple[int, int] = None
|
||||
self.latents: Optional[torch.Tensor] = None
|
||||
self.latents_flipped: Optional[torch.Tensor] = None
|
||||
self.latents_cache_path: Optional[str] = None # set in cache_latents
|
||||
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
|
||||
# crop left top right bottom in original pixel size, not latents size
|
||||
self.latents_crop_ltrb: Optional[Tuple[int, int]] = None
|
||||
self.cond_img_path: Optional[str] = None
|
||||
self.image: Optional[Image.Image] = None # optional, original PIL Image. None if not the latents is cached
|
||||
self.text_encoder_outputs_cache_path: Optional[str] = None # set in cache_text_encoder_outputs
|
||||
|
||||
# new
|
||||
self.text_encoder_outputs: Optional[list[list[torch.Tensor]]] = None
|
||||
# old
|
||||
self.text_encoder_outputs1: Optional[torch.Tensor] = None
|
||||
self.text_encoder_outputs2: Optional[torch.Tensor] = None
|
||||
self.text_encoder_pool2: Optional[torch.Tensor] = None
|
||||
|
||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ImageInfo(image_key={self.image_key}, num_repeats={self.num_repeats}, captions={self.captions}, is_reg={self.is_reg}, absolute_path={self.absolute_path})"
|
||||
|
||||
def set_dreambooth_info(self, list_of_tags: list[str]) -> None:
|
||||
self.list_of_tags = list_of_tags
|
||||
|
||||
def set_fine_tuning_info(
|
||||
self,
|
||||
captions: Optional[list[str]],
|
||||
caption_weights: Optional[list[float]],
|
||||
list_of_tags: Optional[list[str]],
|
||||
tags_weights: Optional[list[float]],
|
||||
image_size: Tuple[int, int],
|
||||
latents_cache_dir: Optional[str],
|
||||
):
|
||||
self.captions = captions
|
||||
self.caption_weights = caption_weights
|
||||
self.list_of_tags = list_of_tags
|
||||
self.tags_weights = tags_weights
|
||||
self.image_size = image_size
|
||||
self.latents_cache_dir = latents_cache_dir
|
||||
|
||||
|
||||
# region Logging
|
||||
|
||||
|
||||
@@ -189,6 +245,15 @@ def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None)
|
||||
raise ValueError(f"Unsupported dtype: {s}")
|
||||
|
||||
|
||||
def dtype_to_normalized_str(dtype: Union[str, torch.dtype]) -> str:
|
||||
dtype = str_to_dtype(dtype) if isinstance(dtype, str) else dtype
|
||||
|
||||
# get name of the dtype
|
||||
dtype_name = str(dtype).split(".")[-1]
|
||||
|
||||
return dtype_name
|
||||
|
||||
|
||||
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
|
||||
"""
|
||||
memory efficient save file
|
||||
@@ -264,8 +329,8 @@ class MemoryEfficientSafeOpen:
|
||||
# does not support metadata loading
|
||||
def __init__(self, filename):
|
||||
self.filename = filename
|
||||
self.header, self.header_size = self._read_header()
|
||||
self.file = open(filename, "rb")
|
||||
self.header, self.header_size = self._read_header()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
@@ -276,6 +341,9 @@ class MemoryEfficientSafeOpen:
|
||||
def keys(self):
|
||||
return [k for k in self.header.keys() if k != "__metadata__"]
|
||||
|
||||
def metadata(self) -> Dict[str, str]:
|
||||
return self.header.get("__metadata__", {})
|
||||
|
||||
def get_tensor(self, key):
|
||||
if key not in self.header:
|
||||
raise KeyError(f"Tensor '{key}' not found in the file")
|
||||
@@ -293,10 +361,9 @@ class MemoryEfficientSafeOpen:
|
||||
return self._deserialize_tensor(tensor_bytes, metadata)
|
||||
|
||||
def _read_header(self):
|
||||
with open(self.filename, "rb") as f:
|
||||
header_size = struct.unpack("<Q", f.read(8))[0]
|
||||
header_json = f.read(header_size).decode("utf-8")
|
||||
return json.loads(header_json), header_size
|
||||
header_size = struct.unpack("<Q", self.file.read(8))[0]
|
||||
header_json = self.file.read(header_size).decode("utf-8")
|
||||
return json.loads(header_json), header_size
|
||||
|
||||
def _deserialize_tensor(self, tensor_bytes, metadata):
|
||||
dtype = self._get_torch_dtype(metadata["dtype"])
|
||||
|
||||
14
sd3_train.py
14
sd3_train.py
@@ -75,6 +75,12 @@ def train(args):
|
||||
)
|
||||
args.cache_text_encoder_outputs = True
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert args.apply_lg_attn_mask == args.apply_t5_attn_mask, (
|
||||
"apply_lg_attn_mask and apply_t5_attn_mask must be the same when caching text encoder outputs"
|
||||
" / text encoderの出力をキャッシュするときにはapply_lg_attn_maskとapply_t5_attn_maskは同じである必要があります"
|
||||
)
|
||||
|
||||
assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), (
|
||||
"when training text encoder, text encoder outputs must not be cached (except for T5XXL)"
|
||||
+ " / text encoderの学習時はtext encoderの出力はキャッシュできません(t5xxlのみキャッシュすることは可能です)"
|
||||
@@ -169,8 +175,8 @@ def train(args):
|
||||
args.text_encoder_batch_size,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
args.t5xxl_max_token_length,
|
||||
args.apply_lg_attn_mask,
|
||||
)
|
||||
)
|
||||
train_dataset_group.set_current_strategies()
|
||||
@@ -279,8 +285,8 @@ def train(args):
|
||||
args.text_encoder_batch_size,
|
||||
args.skip_cache_check,
|
||||
train_clip or args.use_t5xxl_cache_only, # if clip is trained or t5xxl is cached, caching is partial
|
||||
args.t5xxl_max_token_length,
|
||||
args.apply_lg_attn_mask,
|
||||
args.apply_t5_attn_mask,
|
||||
)
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
|
||||
|
||||
@@ -331,7 +337,7 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
|
||||
vae.to("cpu") # if no sampling, vae can be deleted
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -43,6 +43,10 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
assert (
|
||||
train_dataset_group.is_text_encoder_output_cacheable()
|
||||
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
assert args.apply_lg_attn_mask == args.apply_t5_attn_mask, (
|
||||
"apply_lg_attn_mask and apply_t5_attn_mask must be the same when caching text encoder outputs"
|
||||
" / text encoderの出力をキャッシュするときにはapply_lg_attn_maskとapply_t5_attn_maskは同じである必要があります"
|
||||
)
|
||||
|
||||
# prepare CLIP-L/CLIP-G/T5XXL training flags
|
||||
self.train_clip = not args.network_train_unet_only
|
||||
@@ -188,8 +192,8 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
args.text_encoder_batch_size,
|
||||
args.skip_cache_check,
|
||||
is_partial=self.train_clip or self.train_t5xxl,
|
||||
max_token_length=args.t5xxl_max_token_length,
|
||||
apply_lg_attn_mask=args.apply_lg_attn_mask,
|
||||
apply_t5_attn_mask=args.apply_t5_attn_mask,
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -273,7 +273,7 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
@@ -322,7 +322,11 @@ def train(args):
|
||||
if args.cache_text_encoder_outputs:
|
||||
# Text Encodes are eval and no grad
|
||||
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
None,
|
||||
args.skip_cache_check,
|
||||
args.max_token_length,
|
||||
is_weighted=args.weighted_captions,
|
||||
)
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
|
||||
|
||||
|
||||
@@ -209,7 +209,7 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
@@ -223,7 +223,11 @@ def train(args):
|
||||
if args.cache_text_encoder_outputs:
|
||||
# Text Encodes are eval and no grad
|
||||
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, None, False
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
None,
|
||||
args.skip_cache_check,
|
||||
args.max_token_length,
|
||||
is_weighted=args.weighted_captions,
|
||||
)
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
|
||||
|
||||
|
||||
@@ -181,7 +181,7 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
@@ -195,7 +195,11 @@ def train(args):
|
||||
if args.cache_text_encoder_outputs:
|
||||
# Text Encodes are eval and no grad
|
||||
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, None, False
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
None,
|
||||
args.skip_cache_check,
|
||||
args.max_token_length,
|
||||
is_weighted=args.weighted_captions,
|
||||
)
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
|
||||
|
||||
|
||||
@@ -83,7 +83,11 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||
if args.cache_text_encoder_outputs:
|
||||
return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
None,
|
||||
args.skip_cache_check,
|
||||
args.max_token_length,
|
||||
is_weighted=args.weighted_captions,
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -150,7 +150,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
|
||||
# cache latents with dataset
|
||||
# TODO use DataLoader to speed up
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.print(f"Finished caching latents to disk.")
|
||||
|
||||
@@ -157,7 +157,7 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -559,9 +559,9 @@ class NetworkTrainer:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.new_cache_latents(vae, accelerator)
|
||||
val_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -382,7 +382,7 @@ class TextualInversionTrainer:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
Reference in New Issue
Block a user