Compare commits

...

10 Commits

Author SHA1 Message Date
Kohya S
f3a85060ef Merge branch 'sd3' into new_cache 2025-02-19 21:13:08 +09:00
Kohya S
f2322a23e2 feat: update fine tuning dataset 2024-12-09 20:52:18 +09:00
Kohya S
70423ec61d Merge branch 'sd3' into new_cache 2024-12-09 18:36:10 +09:00
Kohya S
28e9352cc5 feat: Florence-2 captioninig (WIP) 2024-12-05 22:04:37 +09:00
Kohya S
b72b9eaf11 Merge branch 'sd3' into new_cache 2024-12-04 20:44:42 +09:00
Kohya S
744cf03136 fix to work 2024-11-29 21:59:25 +09:00
Kohya S
2238b94e7b support new metadata in wd14tagger (WIP), fix typo 2024-11-28 21:05:17 +09:00
Kohya S
665c04e649 Merge branch 'sd3' into new_cache 2024-11-27 12:57:32 +09:00
Kohya S
3677094256 Text Encoder cache (WIP) 2024-11-27 12:57:04 +09:00
Kohya S
bdac55ebbc feat: refactor latent cache format 2024-11-15 21:16:49 +09:00
25 changed files with 1629 additions and 1326 deletions

View File

@@ -178,7 +178,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
vae.to("cpu")
clean_memory_on_device(accelerator.device)

View File

@@ -0,0 +1,232 @@
# add caption to images by Florence-2
import argparse
import json
import os
import glob
from pathlib import Path
from typing import Any, Optional
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForCausalLM
from library import device_utils, train_util, dataset_metadata_utils
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
import tagger_utils
TASK_PROMPT = "<MORE_DETAILED_CAPTION>"
def main(args):
assert args.load_archive == (
args.metadata is not None
), "load_archive must be used with metadata / load_archiveはmetadataと一緒に使う必要があります"
device = args.device if args.device is not None else device_utils.get_preferred_device()
if type(device) is str:
device = torch.device(device)
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
logger.info(f"device: {device}, dtype: {torch_dtype}")
logger.info("Loading Florence-2-large model / Florence-2-largeモデルをロード中")
support_flash_attn = False
try:
import flash_attn
support_flash_attn = True
except ImportError:
pass
if support_flash_attn:
model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True
).to(device)
else:
logger.info(
"flash_attn is not available. Trying to load without it / flash_attnが利用できません。flash_attnを使わずにロードを試みます"
)
# https://github.com/huggingface/transformers/issues/31793#issuecomment-2295797330
# Removing the unnecessary flash_attn import which causes issues on CPU or MPS backends
from transformers.dynamic_module_utils import get_imports
from unittest.mock import patch
def fixed_get_imports(filename) -> list[str]:
if not str(filename).endswith("modeling_florence2.py"):
return get_imports(filename)
imports = get_imports(filename)
imports.remove("flash_attn")
return imports
# workaround for unnecessary flash_attn requirement
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True
).to(device)
model.eval()
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
# 画像を読み込む
if not args.load_archive:
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
else:
archive_files = glob.glob(os.path.join(args.train_data_dir, "*.zip")) + glob.glob(
os.path.join(args.train_data_dir, "*.tar")
)
image_paths = [Path(archive_file) for archive_file in archive_files]
# load metadata if needed
if args.metadata is not None:
metadata = dataset_metadata_utils.load_metadata(args.metadata, create_new=True)
images_metadata = metadata["images"]
else:
images_metadata = metadata = None
# define preprocess_image function
def preprocess_image(image: Image.Image):
inputs = processor(text=TASK_PROMPT, images=image, return_tensors="pt").to(device, torch_dtype)
return inputs
# prepare DataLoader or something similar :)
# Loader returns: list of (image_path, processed_image_or_something, image_size)
if args.load_archive:
loader = tagger_utils.ArchiveImageLoader([str(p) for p in image_paths], args.batch_size, preprocess_image, args.debug)
else:
# we cannot use DataLoader with ImageLoadingPrepDataset because processor is not pickleable
loader = tagger_utils.ImageLoader(image_paths, args.batch_size, preprocess_image, args.debug)
def run_batch(
list_of_path_inputs_size: list[tuple[str, dict[str, torch.Tensor], tuple[int, int]]],
images_metadata: Optional[dict[str, Any]],
caption_index: Optional[int] = None,
):
input_ids = torch.cat([inputs["input_ids"] for _, inputs, _ in list_of_path_inputs_size])
pixel_values = torch.cat([inputs["pixel_values"] for _, inputs, _ in list_of_path_inputs_size])
if args.debug:
logger.info(f"input_ids: {input_ids.shape}, pixel_values: {pixel_values.shape}")
with torch.no_grad():
generated_ids = model.generate(
input_ids=input_ids,
pixel_values=pixel_values,
max_new_tokens=args.max_new_tokens,
num_beams=args.num_beams,
)
if args.debug:
logger.info(f"generate done: {generated_ids.shape}")
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=False)
if args.debug:
logger.info(f"decode done: {len(generated_texts)}")
for generated_text, (image_path, _, image_size) in zip(generated_texts, list_of_path_inputs_size):
parsed_answer = processor.post_process_generation(generated_text, task=TASK_PROMPT, image_size=image_size)
caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"]
caption_text = caption_text.strip().replace("<pad>", "")
original_caption_text = caption_text
if args.remove_mood:
p = caption_text.find("The overall ")
if p != -1:
caption_text = caption_text[:p].strip()
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
if images_metadata is None:
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(caption_text + "\n")
else:
image_md = images_metadata.get(image_path, None)
if image_md is None:
image_md = {"image_size": list(image_size)}
images_metadata[image_path] = image_md
if "caption" not in image_md:
image_md["caption"] = []
if caption_index is None:
image_md["caption"].append(caption_text)
else:
while len(image_md["caption"]) <= caption_index:
image_md["caption"].append("")
image_md["caption"][caption_index] = caption_text
if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tCaption: {caption_text}")
if args.remove_mood and original_caption_text != caption_text:
logger.info(f"\tCaption (prior to removing mood): {original_caption_text}")
for data_entry in tqdm(loader, smoothing=0.0):
b_imgs = data_entry
b_imgs = [(str(image_path), image, size) for image_path, image, size in b_imgs] # Convert image_path to string
run_batch(b_imgs, images_metadata, args.caption_index)
if args.metadata is not None:
logger.info(f"saving metadata file: {args.metadata}")
with open(args.metadata, "wt", encoding="utf-8") as f:
json.dump(metadata, f, ensure_ascii=False, indent=2)
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子"
)
parser.add_argument("--recursive", action="store_true", help="search images recursively / 画像を再帰的に検索する")
parser.add_argument(
"--remove_mood", action="store_true", help="remove mood from the caption / キャプションからムードを削除する"
)
parser.add_argument(
"--max_new_tokens",
type=int,
default=1024,
help="maximum number of tokens to generate. default is 1024 / 生成するトークンの最大数。デフォルトは1024",
)
parser.add_argument(
"--num_beams",
type=int,
default=3,
help="number of beams for beam search. default is 3 / ビームサーチのビーム数。デフォルトは3",
)
parser.add_argument(
"--device",
type=str,
default=None,
help="device for model. default is None, which means using an appropriate device / モデルのデバイス。デフォルトはNoneで、適切なデバイスを使用する",
)
parser.add_argument(
"--caption_index",
type=int,
default=None,
help="index of the caption in the metadata file. default is None, which means adding caption to the existing captions. 0>= to replace the caption"
" / メタデータファイル内のキャプションのインデックス。デフォルトはNoneで、新しく追加する。0以上でキャプションを置き換える",
)
parser.add_argument("--debug", action="store_true", help="debug mode")
tagger_utils.add_archive_arguments(parser)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
main(args)

View File

@@ -180,7 +180,7 @@ def main(args):
# バッチへ追加
image_info = train_util.ImageInfo(image_key, 1, "", False, image_path)
image_info.latents_npz = npz_file_name
image_info.latents_cache_path = npz_file_name
image_info.bucket_reso = reso
image_info.resized_size = resized_size
image_info.image = image

View File

@@ -1,7 +1,10 @@
import argparse
import csv
import glob
import json
import os
from pathlib import Path
from typing import Any, Optional
import cv2
import numpy as np
@@ -10,14 +13,18 @@ from huggingface_hub import hf_hub_download
from PIL import Image
from tqdm import tqdm
import library.train_util as train_util
from library.utils import setup_logging, pil_resize
from library import dataset_metadata_utils
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.train_util as train_util
from library.utils import pil_resize
import tagger_utils
# from wd14 tagger
IMAGE_SIZE = 448
@@ -63,13 +70,14 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
try:
image = Image.open(img_path).convert("RGB")
size = image.size
image = preprocess_image(image)
# tensor = torch.tensor(image) # これ Tensor に変換する必要ないな……(;・∀・)
except Exception as e:
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (image, img_path)
return (image, img_path, size)
def collate_fn_remove_corrupted(batch):
@@ -83,6 +91,10 @@ def collate_fn_remove_corrupted(batch):
def main(args):
assert args.load_archive == (
args.metadata is not None
), "load_archive must be used with metadata / load_archiveはmetadataと一緒に使う必要があります"
# model location is model_dir + repo_id
# repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash
model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_"))
@@ -149,15 +161,19 @@ def main(args):
ort_sess = ort.InferenceSession(
onnx_path,
providers=(["OpenVINOExecutionProvider"]),
provider_options=[{'device_type' : "GPU_FP32"}],
provider_options=[{"device_type": "GPU_FP32"}],
)
else:
ort_sess = ort.InferenceSession(
onnx_path,
providers=(
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
["CPUExecutionProvider"]
["CUDAExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers()
else (
["ROCMExecutionProvider"]
if "ROCMExecutionProvider" in ort.get_available_providers()
else ["CPUExecutionProvider"]
)
),
)
else:
@@ -203,7 +219,9 @@ def main(args):
tag_replacements = escaped_tag_replacements.split(";")
for tag_replacement in tag_replacements:
tags = tag_replacement.split(",") # source, target
assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
assert (
len(tags) == 2
), f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags]
logger.info(f"replacing tag: {source} -> {target}")
@@ -216,9 +234,15 @@ def main(args):
rating_tags[rating_tags.index(source)] = target
# 画像を読み込む
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
if not args.load_archive:
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
else:
archive_files = glob.glob(os.path.join(args.train_data_dir, "*.zip")) + glob.glob(
os.path.join(args.train_data_dir, "*.tar")
)
image_paths = [Path(archive_file) for archive_file in archive_files]
tag_freq = {}
@@ -231,19 +255,23 @@ def main(args):
if args.always_first_tags is not None:
always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""]
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
def run_batch(
list_of_path_img_size: list[tuple[str, np.ndarray, tuple[int, int]]],
images_metadata: Optional[dict[str, Any]],
tags_index: Optional[int] = None,
):
imgs = np.array([im for _, im, _ in list_of_path_img_size])
if args.onnx:
# if len(imgs) < args.batch_size:
# imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
probs = probs[: len(path_imgs)]
probs = probs[: len(list_of_path_img_size)]
else:
probs = model(imgs, training=False)
probs = probs.numpy()
for (image_path, _), prob in zip(path_imgs, probs):
for (image_path, _, image_size), prob in zip(list_of_path_img_size, probs):
combined_tags = []
rating_tag_text = ""
character_tag_text = ""
@@ -265,7 +293,7 @@ def main(args):
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += caption_separator + tag_name
if args.character_tags_first: # insert to the beginning
if args.character_tags_first: # insert to the beginning
combined_tags.insert(0, tag_name)
else:
combined_tags.append(tag_name)
@@ -281,7 +309,7 @@ def main(args):
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
rating_tag_text = found_rating
if args.use_rating_tags:
combined_tags.insert(0, found_rating) # insert to the beginning
combined_tags.insert(0, found_rating) # insert to the beginning
else:
combined_tags.append(found_rating)
@@ -304,12 +332,24 @@ def main(args):
tag_text = caption_separator.join(combined_tags)
if args.append_tags:
# Check if file exists
if os.path.exists(caption_file):
with open(caption_file, "rt", encoding="utf-8") as f:
# Read file and remove new lines
existing_content = f.read().strip("\n") # Remove newlines
existing_content = None
if images_metadata is None:
# Check if file exists
if os.path.exists(caption_file):
with open(caption_file, "rt", encoding="utf-8") as f:
# Read file and remove new lines
existing_content = f.read().strip("\n") # Remove newlines
else:
image_md = images_metadata.get(image_path, None)
if image_md is not None:
tags = image_md.get("tags", None)
if tags is not None:
if tags_index is None and len(tags) > 0:
existing_content = tags[-1]
elif tags_index is not None and tags_index < len(tags):
existing_content = tags[tags_index]
if existing_content is not None:
# Split the content into tags and store them in a list
existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()]
@@ -319,19 +359,46 @@ def main(args):
# Create new tag_text
tag_text = caption_separator.join(existing_tags + new_tags)
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")
if images_metadata is None:
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
else:
image_md = images_metadata.get(image_path, None)
if image_md is None:
image_md = {"image_size": list(image_size)}
images_metadata[image_path] = image_md
if "tags" not in image_md:
image_md["tags"] = []
if tags_index is None:
image_md["tags"].append(tag_text)
else:
while len(image_md["tags"]) <= tags_index:
image_md["tags"].append("")
image_md["tags"][tags_index] = tag_text
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")
# load metadata if needed
if args.metadata is not None:
metadata = dataset_metadata_utils.load_metadata(args.metadata, create_new=True)
images_metadata = metadata["images"]
else:
images_metadata = metadata = None
# prepare DataLoader or something similar :)
use_loader = False
if args.load_archive:
loader = tagger_utils.ArchiveImageLoader([str(p) for p in image_paths], args.batch_size, preprocess_image, args.debug)
use_loader = True
elif args.max_data_loader_n_workers is not None:
# 読み込みの高速化のためにDataLoaderを使うオプション
dataset = ImageLoadingPrepDataset(image_paths)
data = torch.utils.data.DataLoader(
loader = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=False,
@@ -339,35 +406,37 @@ def main(args):
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
)
use_loader = True
else:
data = [[(None, ip)] for ip in image_paths]
# make batch of image paths
loader = []
for i in range(0, len(image_paths), args.batch_size):
loader.append(image_paths[i : i + args.batch_size])
b_imgs = []
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
image, image_path = data
if image is None:
for data_entry in tqdm(loader, smoothing=0.0):
if use_loader:
b_imgs = data_entry
else:
b_imgs = []
for image_path in data_entry:
try:
image = Image.open(image_path)
if image.mode != "RGB":
image = image.convert("RGB")
size = image.size
image = preprocess_image(image)
except Exception as e:
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
b_imgs.append((image_path, image, size))
if len(b_imgs) >= args.batch_size:
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
run_batch(b_imgs)
b_imgs.clear()
b_imgs = [(str(image_path), image, size) for image_path, image, size in b_imgs] # Convert image_path to string
run_batch(b_imgs, images_metadata, args.tags_index)
if len(b_imgs) > 0:
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
run_batch(b_imgs)
if args.metadata is not None:
logger.info(f"saving metadata file: {args.metadata}")
with open(args.metadata, "wt", encoding="utf-8") as f:
json.dump(metadata, f, ensure_ascii=False, indent=2)
if args.frequency_tags:
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
@@ -380,9 +449,7 @@ def main(args):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
)
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument(
"--repo_id",
type=str,
@@ -400,9 +467,7 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
)
parser.add_argument(
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
type=int,
@@ -441,9 +506,7 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
)
parser.add_argument(
"--debug", action="store_true", help="debug mode"
)
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument(
"--undesired_tags",
type=str,
@@ -453,20 +516,24 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する"
)
parser.add_argument(
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
)
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument(
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
)
parser.add_argument(
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
"--use_rating_tags",
action="store_true",
help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
)
parser.add_argument(
"--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
"--use_rating_tags_as_last_tag",
action="store_true",
help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
)
parser.add_argument(
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
"--character_tags_first",
action="store_true",
help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
)
parser.add_argument(
"--always_first_tags",
@@ -495,6 +562,15 @@ def setup_parser() -> argparse.ArgumentParser:
+ " / キャラクタタグの末尾の括弧を別のタグに展開する。`chara_name_(series)` は `chara_name, series` になる",
)
parser.add_argument(
"--tags_index",
type=int,
default=None,
help="index of the tags in the metadata file. default is None, which means adding tags to the existing tags. 0>= to replace the tags"
" / メタデータファイル内のタグのインデックス。デフォルトはNoneで、既存のタグにタグを追加する。0以上でタグを置き換える",
)
tagger_utils.add_archive_arguments(parser)
return parser

150
finetune/tagger_utils.py Normal file
View File

@@ -0,0 +1,150 @@
import argparse
import json
import math
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Union
import zipfile
import tarfile
from PIL import Image
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import dataset_metadata_utils, train_util
class ArchiveImageLoader:
def __init__(self, archive_paths: list[str], batch_size: int, preprocess: Callable, debug: bool = False):
self.archive_paths = archive_paths
self.batch_size = batch_size
self.preprocess = preprocess
self.debug = debug
self.current_archive = None
self.archive_index = 0
self.image_index = 0
self.files = None
self.executor = ThreadPoolExecutor()
self.image_exts = set(train_util.IMAGE_EXTENSIONS)
def __iter__(self):
return self
def __next__(self):
images = []
while len(images) < self.batch_size:
if self.current_archive is None:
if self.archive_index >= len(self.archive_paths):
if len(images) == 0:
raise StopIteration
else:
break # return the remaining images
if self.debug:
logger.info(f"loading archive: {self.archive_paths[self.archive_index]}")
current_archive_path = self.archive_paths[self.archive_index]
if current_archive_path.endswith(".zip"):
self.current_archive = zipfile.ZipFile(current_archive_path)
self.files = self.current_archive.namelist()
elif current_archive_path.endswith(".tar"):
self.current_archive = tarfile.open(current_archive_path, "r")
self.files = self.current_archive.getnames()
else:
raise ValueError(f"unsupported archive file: {self.current_archive_path}")
self.image_index = 0
# filter by image extensions
self.files = [file for file in self.files if os.path.splitext(file)[1].lower() in self.image_exts]
if self.debug:
logger.info(f"found {len(self.files)} images in the archive")
new_images = []
while len(images) + len(new_images) < self.batch_size:
if self.image_index >= len(self.files):
break
file = self.files[self.image_index]
archive_and_image_path = (
f"{self.archive_paths[self.archive_index]}{dataset_metadata_utils.ARCHIVE_PATH_SEPARATOR}{file}"
)
self.image_index += 1
def load_image(file, archive: Union[zipfile.ZipFile, tarfile.TarFile]):
with archive.open(file) as f:
image = Image.open(f).convert("RGB")
size = image.size
image = self.preprocess(image)
return image, size
new_images.append((archive_and_image_path, self.executor.submit(load_image, file, self.current_archive)))
# wait for all new_images to load to close the archive
new_images = [(image_path, future.result()) for image_path, future in new_images]
if self.image_index >= len(self.files):
self.current_archive.close()
self.current_archive = None
self.archive_index += 1
images.extend(new_images)
return [(image_path, image, size) for image_path, (image, size) in images]
class ImageLoader:
def __init__(self, image_paths: list[str], batch_size: int, preprocess: Callable, debug: bool = False):
self.image_paths = image_paths
self.batch_size = batch_size
self.preprocess = preprocess
self.debug = debug
self.image_index = 0
self.executor = ThreadPoolExecutor()
def __len__(self):
return math.ceil(len(self.image_paths) / self.batch_size)
def __iter__(self):
return self
def __next__(self):
if self.image_index >= len(self.image_paths):
raise StopIteration
images = []
while len(images) < self.batch_size and self.image_index < len(self.image_paths):
def load_image(file):
image = Image.open(file).convert("RGB")
size = image.size
image = self.preprocess(image)
return image, size
image_path = self.image_paths[self.image_index]
images.append((image_path, self.executor.submit(load_image, image_path)))
self.image_index += 1
images = [(image_path, future.result()) for image_path, future in images]
return [(image_path, image, size) for image_path, (image, size) in images]
def add_archive_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--metadata",
type=str,
default=None,
help="metadata file for the dataset. write tags to this file instead of the caption file / データセットのメタデータファイル。キャプションファイルの代わりにこのファイルにタグを書き込む",
)
parser.add_argument(
"--load_archive",
action="store_true",
help="load archive file such as .zip instead of image files. currently .zip and .tar are supported. must be used with --metadata"
" / 画像ファイルではなく.zipなどのアーカイブファイルを読み込む。現在.zipと.tarをサポート。--metadataと一緒に使う必要があります",
)

View File

@@ -152,15 +152,20 @@ def train(args):
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
if args.debug_dataset:
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
)
)
t5xxl_max_token_length = (
args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
)
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
t5xxl_max_token_length,
args.apply_t5_attn_mask,
False,
)
)
strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))
train_dataset_group.set_current_strategies()
@@ -199,7 +204,7 @@ def train(args):
ae.requires_grad_(False)
ae.eval()
train_dataset_group.new_cache_latents(ae, accelerator)
train_dataset_group.new_cache_latents(ae, accelerator, args.force_cache_precision)
ae.to("cpu") # if no sampling, vae can be deleted
clean_memory_on_device(accelerator.device)
@@ -237,7 +242,12 @@ def train(args):
t5xxl.to(accelerator.device)
text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
t5xxl_max_token_length,
args.apply_t5_attn_mask,
False,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)

View File

@@ -11,16 +11,6 @@ from library.device_utils import clean_memory_on_device, init_ipex
init_ipex()
import train_network
from library import (
flux_models,
flux_train_utils,
flux_utils,
sd3_train_utils,
strategy_base,
strategy_flux,
train_util,
)
from library.utils import setup_logging
setup_logging()
@@ -28,6 +18,9 @@ import logging
logger = logging.getLogger(__name__)
from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
import train_network
class FluxNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
@@ -185,13 +178,17 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
fluxTokenizeStrategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
t5xxl_max_token_length = fluxTokenizeStrategy.t5xxl_max_length
# if the text encoders is trained, we need tokenization, so is_partial is True
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
t5xxl_max_token_length,
args.apply_t5_attn_mask,
is_partial=self.train_clip_l or self.train_t5xxl,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
else:
return None

View File

@@ -0,0 +1,58 @@
import os
import json
from typing import Any, Optional
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
METADATA_VERSION = [1, 0, 0]
VERSION_STRING = ".".join(str(v) for v in METADATA_VERSION)
ARCHIVE_PATH_SEPARATOR = "////"
def load_metadata(metadata_file: str, create_new: bool = False) -> Optional[dict[str, Any]]:
if os.path.exists(metadata_file):
logger.info(f"loading metadata file: {metadata_file}")
with open(metadata_file, "rt", encoding="utf-8") as f:
metadata = json.load(f)
# version check
major, minor, patch = metadata.get("format_version", "0.0.0").split(".")
major, minor, patch = int(major), int(minor), int(patch)
if major > METADATA_VERSION[0] or (major == METADATA_VERSION[0] and minor > METADATA_VERSION[1]):
logger.warning(
f"metadata format version {major}.{minor}.{patch} is higher than supported version {VERSION_STRING}. Some features may not work."
)
if "images" not in metadata:
metadata["images"] = {}
else:
if not create_new:
return None
logger.info(f"metadata file not found: {metadata_file}, creating new metadata")
metadata = {"format_version": VERSION_STRING, "images": {}}
return metadata
def is_archive_path(archive_and_image_path: str) -> bool:
return archive_and_image_path.count(ARCHIVE_PATH_SEPARATOR) == 1
def get_inner_path(archive_and_image_path: str) -> str:
return archive_and_image_path.split(ARCHIVE_PATH_SEPARATOR, 1)[1]
def get_archive_digest(archive_and_image_path: str) -> str:
"""
calculate a 8-digits hex digest for the archive path to avoid collisions for different archives with the same name.
"""
archive_path = archive_and_image_path.split(ARCHIVE_PATH_SEPARATOR, 1)[0]
return f"{hash(archive_path) & 0xFFFFFFFF:08x}"

View File

@@ -2,16 +2,14 @@
import os
import re
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from safetensors.torch import safe_open, save_file
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
# TODO remove circular import by moving ImageInfo to a separate file
# from library.train_util import ImageInfo
from library.utils import setup_logging
setup_logging()
@@ -19,6 +17,81 @@ import logging
logger = logging.getLogger(__name__)
from library import dataset_metadata_utils, utils
def get_compatible_dtypes(dtype: Optional[Union[str, torch.dtype]]) -> List[torch.dtype]:
if dtype is None:
# all dtypes are acceptable
return get_available_dtypes()
dtype = utils.str_to_dtype(dtype) if isinstance(dtype, str) else dtype
compatible_dtypes = [torch.float32]
if dtype.itemsize == 1: # fp8
compatible_dtypes.append(torch.bfloat16)
compatible_dtypes.append(torch.float16)
compatible_dtypes.append(dtype) # add the specified: bf16, fp16, one of fp8
return compatible_dtypes
def get_available_dtypes() -> List[torch.dtype]:
"""
Returns the list of available dtypes for latents caching. Higher precision is preferred.
"""
return [torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn, torch.float8_e5m2]
def remove_lower_precision_values(tensor_dict: Dict[str, torch.Tensor], keys_without_dtype: list[str]) -> None:
"""
Removes lower precision values from tensor_dict.
"""
available_dtypes = get_available_dtypes()
available_dtype_suffixes = [f"_{utils.dtype_to_normalized_str(dtype)}" for dtype in available_dtypes]
for key_without_dtype in keys_without_dtype:
available_itemsize = None
for dtype, dtype_suffix in zip(available_dtypes, available_dtype_suffixes):
key = key_without_dtype + dtype_suffix
if key in tensor_dict:
if available_itemsize is None:
available_itemsize = dtype.itemsize
elif available_itemsize > dtype.itemsize:
# if higher precision latents are already cached, remove lower precision latents
del tensor_dict[key]
def get_compatible_dtype_keys(
dict_keys: set[str], keys_without_dtype: list[str], dtype: Optional[Union[str, torch.dtype]]
) -> list[Optional[str]]:
"""
Returns the list of keys with the specified dtype or higher precision dtype. If the specified dtype is None, any dtype is acceptable.
If the key is not found, it returns None.
If the key in dict_keys doesn't have dtype suffix, it is acceptable, because it it long tensor.
:param dict_keys: set of keys in the dictionary
:param keys_without_dtype: list of keys without dtype suffix to check
:param dtype: dtype to check, or None for any dtype
:return: list of keys with the specified dtype or higher precision dtype. If the key is not found, it returns None for that key.
"""
compatible_dtypes = get_compatible_dtypes(dtype)
dtype_suffixes = [f"_{utils.dtype_to_normalized_str(dt)}" for dt in compatible_dtypes]
available_keys = []
for key_without_dtype in keys_without_dtype:
available_key = None
if key_without_dtype in dict_keys:
available_key = key_without_dtype
else:
for dtype_suffix in dtype_suffixes:
key = key_without_dtype + dtype_suffix
if key in dict_keys:
available_key = key
break
available_keys.append(available_key)
return available_keys
class TokenizeStrategy:
_strategy = None # strategy instance: actual strategy class
@@ -324,17 +397,26 @@ class TextEncoderOutputsCachingStrategy:
def __init__(
self,
architecture: str,
cache_to_disk: bool,
batch_size: Optional[int],
skip_disk_cache_validity_check: bool,
max_token_length: int,
masked: bool = False,
is_partial: bool = False,
is_weighted: bool = False,
) -> None:
"""
max_token_length: maximum token length for the model. Including/excluding starting and ending tokens depends on the model.
"""
self._architecture = architecture
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
self._max_token_length = max_token_length
self._masked = masked
self._is_partial = is_partial
self._is_weighted = is_weighted
self._is_weighted = is_weighted # enable weighting by `()` or `[]` in the prompt
@classmethod
def set_strategy(cls, strategy):
@@ -346,6 +428,18 @@ class TextEncoderOutputsCachingStrategy:
def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
return cls._strategy
@property
def architecture(self):
return self._architecture
@property
def max_token_length(self):
return self._max_token_length
@property
def masked(self):
return self._masked
@property
def cache_to_disk(self):
return self._cache_to_disk
@@ -354,6 +448,11 @@ class TextEncoderOutputsCachingStrategy:
def batch_size(self):
return self._batch_size
@property
def cache_suffix(self):
suffix_masked = "_m" if self.masked else ""
return f"_{self.architecture.lower()}_{self.max_token_length}{suffix_masked}_te.safetensors"
@property
def is_partial(self):
return self._is_partial
@@ -362,31 +461,159 @@ class TextEncoderOutputsCachingStrategy:
def is_weighted(self):
return self._is_weighted
def get_outputs_npz_path(self, image_abs_path: str) -> str:
def get_cache_path(self, absolute_path: str) -> str:
return os.path.splitext(absolute_path)[0] + self.cache_suffix
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
raise NotImplementedError
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
def load_from_disk_for_keys(self, cache_path: str, caption_index: int, base_keys: list[str]) -> list[Optional[torch.Tensor]]:
"""
get tensors for keys_without_dtype, without dtype suffix. if the key is not found, it returns None.
all dtype tensors are returned, because cache validation is done in advance.
"""
with safe_open(cache_path, framework="pt") as f:
metadata = f.metadata()
version = metadata.get("format_version", "0.0.0")
major, minor, patch = map(int, version.split("."))
if major > 1: # or (major == 1 and minor > 0):
if not self.load_version_warning_printed:
self.load_version_warning_printed = True
logger.warning(
f"Existing latents cache file has a higher version {version} for {cache_path}. This may cause issues."
)
dict_keys = f.keys()
results = []
compatible_keys = self.get_compatible_output_keys(dict_keys, caption_index, base_keys, None)
for key in compatible_keys:
results.append(f.get_tensor(key) if key is not None else None)
return results
def is_disk_cached_outputs_expected(
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
) -> bool:
raise NotImplementedError
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
raise NotImplementedError
def get_key_suffix(self, prompt_id: int, dtype: Optional[Union[str, torch.dtype]] = None) -> str:
"""
masked: may be False even if self.masked is True. It is False for some outputs.
"""
key_suffix = f"_{prompt_id}"
if dtype is not None and dtype.is_floating_point: # float tensor only
key_suffix += "_" + utils.dtype_to_normalized_str(dtype)
return key_suffix
def get_compatible_output_keys(
self, dict_keys: set[str], caption_index: int, base_keys: list[str], dtype: Optional[Union[str, torch.dtype]]
) -> list[Optional[str], Optional[str]]:
"""
returns the list of keys with the specified dtype or higher precision dtype. If the specified dtype is None, any dtype is acceptable.
"""
key_suffix = self.get_key_suffix(caption_index, None)
keys_without_dtype = [k + key_suffix for k in base_keys]
return get_compatible_dtype_keys(dict_keys, keys_without_dtype, dtype)
def _default_is_disk_cached_outputs_expected(
self,
cache_path: str,
captions: list[str],
base_keys: list[tuple[str, bool]],
preferred_dtype: Optional[Union[str, torch.dtype]],
):
if not self.cache_to_disk:
return False
if not os.path.exists(cache_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
with utils.MemoryEfficientSafeOpen(cache_path) as f:
keys = f.keys()
metadata = f.metadata()
# check captions in metadata
for i, caption in enumerate(captions):
if metadata.get(f"caption{i+1}") != caption:
return False
compatible_keys = self.get_compatible_output_keys(keys, i, base_keys, preferred_dtype)
if any(key is None for key in compatible_keys):
return False
except Exception as e:
logger.error(f"Error loading file: {cache_path}")
raise e
return True
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List
self,
tokenize_strategy: TokenizeStrategy,
models: list[Any],
text_encoding_strategy: TextEncodingStrategy,
batch: list[tuple[utils.ImageInfo, int, str]],
):
raise NotImplementedError
def save_outputs_to_disk(self, cache_path: str, caption_index: int, caption: str, keys: list[str], outputs: list[torch.Tensor]):
tensor_dict = {}
overwrite = False
if os.path.exists(cache_path):
# load existing safetensors and update it
overwrite = True
with utils.MemoryEfficientSafeOpen(cache_path) as f:
metadata = f.metadata()
keys = f.keys()
for key in keys:
tensor_dict[key] = f.get_tensor(key)
assert metadata["architecture"] == self.architecture
file_version = metadata.get("format_version", "0.0.0")
major, minor, patch = map(int, file_version.split("."))
if major > 1 or (major == 1 and minor > 0):
self.save_version_warning_printed = True
logger.warning(
f"Existing latents cache file has a higher version {file_version} for {cache_path}. This may cause issues."
)
else:
metadata = {}
metadata["architecture"] = self.architecture
metadata["format_version"] = "1.0.0"
metadata[f"caption{caption_index+1}"] = caption
for key, output in zip(keys, outputs):
dtype = output.dtype # long or one of float
key_suffix = self.get_key_suffix(caption_index, dtype)
tensor_dict[key + key_suffix] = output
# remove lower precision latents if higher precision latents are already cached
if overwrite:
suffix_without_dtype = self.get_key_suffix(caption_index, None)
remove_lower_precision_values(tensor_dict, [key + suffix_without_dtype])
save_file(tensor_dict, cache_path, metadata=metadata)
class LatentsCachingStrategy:
# TODO commonize utillity functions to this class, such as npz handling etc.
_strategy = None # strategy instance: actual strategy class
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
def __init__(
self, architecture: str, latents_stride: int, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
) -> None:
self._architecture = architecture
self._latents_stride = latents_stride
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
self.load_version_warning_printed = False
self.save_version_warning_printed = False
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
@@ -397,6 +624,14 @@ class LatentsCachingStrategy:
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
return cls._strategy
@property
def architecture(self):
return self._architecture
@property
def latents_stride(self):
return self._latents_stride
@property
def cache_to_disk(self):
return self._cache_to_disk
@@ -407,54 +642,126 @@ class LatentsCachingStrategy:
@property
def cache_suffix(self):
raise NotImplementedError
return f"_{self.architecture.lower()}.safetensors"
def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]:
w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x")
def get_image_size_from_disk_cache_path(self, absolute_path: str, cache_path: str) -> Tuple[Optional[int], Optional[int]]:
w, h = os.path.splitext(cache_path)[0].rsplit("_", 2)[-2].split("x")
return int(w), int(h)
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
raise NotImplementedError
def get_latents_cache_path_from_info(self, info: utils.ImageInfo) -> str:
return self.get_latents_cache_path(info.absolute_path, info.image_size, info.latents_cache_dir)
def get_latents_cache_path(
self, absolute_path_or_archive_img_path: str, image_size: Tuple[int, int], cache_dir: Optional[str] = None
) -> str:
if cache_dir is not None:
if dataset_metadata_utils.is_archive_path(absolute_path_or_archive_img_path):
inner_path = dataset_metadata_utils.get_inner_path(absolute_path_or_archive_img_path)
archive_digest = dataset_metadata_utils.get_archive_digest(absolute_path_or_archive_img_path)
cache_file_base = os.path.join(cache_dir, f"{archive_digest}_{inner_path}")
else:
cache_file_base = os.path.join(cache_dir, os.path.basename(absolute_path_or_archive_img_path))
else:
cache_file_base = absolute_path_or_archive_img_path
return os.path.splitext(cache_file_base)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
def is_disk_cached_latents_expected(
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
self,
bucket_reso: Tuple[int, int],
cache_path: str,
flip_aug: bool,
alpha_mask: bool,
preferred_dtype: Optional[Union[str, torch.dtype]],
) -> bool:
raise NotImplementedError
def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
raise NotImplementedError
def get_key_suffix(
self,
bucket_reso: Optional[Tuple[int, int]] = None,
latents_size: Optional[Tuple[int, int]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
) -> str:
"""
if dtype is None, it returns "_32x64" for example.
"""
if latents_size is not None:
expected_latents_size = latents_size # H, W
else:
# bucket_reso is (W, H)
expected_latents_size = (bucket_reso[1] // self.latents_stride, bucket_reso[0] // self.latents_stride) # H, W
if dtype is None:
dtype_suffix = ""
else:
dtype_suffix = "_" + utils.dtype_to_normalized_str(dtype)
# e.g. "_32x64_float16", HxW, dtype
key_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}{dtype_suffix}"
return key_suffix
def get_compatible_latents_keys(
self,
keys: set[str],
dtype: Optional[Union[str, torch.dtype]],
flip_aug: bool,
bucket_reso: Optional[Tuple[int, int]] = None,
latents_size: Optional[Tuple[int, int]] = None,
) -> list[Optional[str], Optional[str]]:
"""
bucket_reso is (W, H), latents_size is (H, W)
"""
key_suffix = self.get_key_suffix(bucket_reso, latents_size, None)
keys_without_dtype = ["latents" + key_suffix]
if flip_aug:
keys_without_dtype.append("latents_flipped" + key_suffix)
compatible_keys = get_compatible_dtype_keys(keys, keys_without_dtype, dtype)
return compatible_keys if flip_aug else compatible_keys[0] + [None]
def _default_is_disk_cached_latents_expected(
self,
latents_stride: int,
bucket_reso: Tuple[int, int],
npz_path: str,
latents_cache_path: str,
flip_aug: bool,
alpha_mask: bool,
multi_resolution: bool = False,
preferred_dtype: Optional[Union[str, torch.dtype]],
):
# multi_resolution is always enabled for any strategy
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
if not os.path.exists(latents_cache_path):
return False
if self.skip_disk_cache_validity_check:
return True
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
# e.g. "_32x64", HxW
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
key_suffix_without_dtype = self.get_key_suffix(bucket_reso=bucket_reso, dtype=None)
try:
npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
# safe_open locks the file, so we cannot use it for checking keys
# with safe_open(latents_cache_path, framework="pt") as f:
# keys = f.keys()
with utils.MemoryEfficientSafeOpen(latents_cache_path) as f:
keys = f.keys()
if alpha_mask and "alpha_mask" + key_suffix_without_dtype not in keys:
# print(f"alpha_mask not found: {latents_cache_path}")
return False
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
return False
if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
# preferred_dtype is None if any dtype is acceptable
latents_key, flipped_latents_key = self.get_compatible_latents_keys(
keys, preferred_dtype, flip_aug, bucket_reso=bucket_reso
)
if latents_key is None or (flip_aug and flipped_latents_key is None):
# print(f"Precise dtype not found: {latents_cache_path}")
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
logger.error(f"Error loading file: {latents_cache_path}")
raise e
return True
@@ -465,11 +772,10 @@ class LatentsCachingStrategy:
encode_by_vae,
vae_device,
vae_dtype,
image_infos: List,
image_infos: List[utils.ImageInfo],
flip_aug: bool,
alpha_mask: bool,
random_crop: bool,
multi_resolution: bool = False,
):
"""
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
@@ -499,13 +805,8 @@ class LatentsCachingStrategy:
original_size = original_sizes[i]
crop_ltrb = crop_ltrbs[i]
latents_size = latents.shape[1:3] # H, W
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW
if self.cache_to_disk:
self.save_latents_to_disk(
info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix
)
self.save_latents_to_disk(info.latents_cache_path, latents, original_size, crop_ltrb, flipped_latent, alpha_mask)
else:
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
@@ -515,56 +816,96 @@ class LatentsCachingStrategy:
info.alpha_mask = alpha_mask
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
for SD/SDXL
"""
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
raise NotImplementedError
def _default_load_latents_from_disk(
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
if latents_stride is None:
key_reso_suffix = ""
else:
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
with safe_open(cache_path, framework="pt") as f:
metadata = f.metadata()
version = metadata.get("format_version", "0.0.0")
major, minor, patch = map(int, version.split("."))
if major > 1: # or (major == 1 and minor > 0):
if not self.load_version_warning_printed:
self.load_version_warning_printed = True
logger.warning(
f"Existing latents cache file has a higher version {version} for {cache_path}. This may cause issues."
)
npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
keys = f.keys()
latents_key, flipped_latents_key = self.get_compatible_latents_keys(keys, None, flip_aug=True, bucket_reso=bucket_reso)
key_suffix_without_dtype = self.get_key_suffix(bucket_reso=bucket_reso, dtype=None)
alpha_mask_key = "alpha_mask" + key_suffix_without_dtype
latents = f.get_tensor(latents_key)
flipped_latents = f.get_tensor(flipped_latents_key) if flipped_latents_key is not None else None
alpha_mask = f.get_tensor(alpha_mask_key) if alpha_mask_key in keys else None
original_size = [int(metadata["width"]), int(metadata["height"])]
crop_ltrb = metadata[f"crop_ltrb" + key_suffix_without_dtype]
crop_ltrb = list(map(int, crop_ltrb.split(",")))
latents = npz["latents" + key_reso_suffix]
original_size = npz["original_size" + key_reso_suffix].tolist()
crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk(
self,
npz_path,
latents_tensor,
original_size,
crop_ltrb,
flipped_latents_tensor=None,
alpha_mask=None,
key_reso_suffix="",
cache_path: str,
latents_tensor: torch.Tensor,
original_size: Tuple[int, int],
crop_ltrb: List[int],
flipped_latents_tensor: Optional[torch.Tensor] = None,
alpha_mask: Optional[torch.Tensor] = None,
):
kwargs = {}
dtype = latents_tensor.dtype
latents_size = latents_tensor.shape[1:3] # H, W
tensor_dict = {}
if os.path.exists(npz_path):
# load existing npz and update it
npz = np.load(npz_path)
for key in npz.files:
kwargs[key] = npz[key]
overwrite = False
if os.path.exists(cache_path):
# load existing safetensors and update it
overwrite = True
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
# we cannot use safe_open here because it locks the file
# with safe_open(cache_path, framework="pt") as f:
with utils.MemoryEfficientSafeOpen(cache_path) as f:
metadata = f.metadata()
keys = f.keys()
for key in keys:
tensor_dict[key] = f.get_tensor(key)
assert metadata["architecture"] == self.architecture
file_version = metadata.get("format_version", "0.0.0")
major, minor, patch = map(int, file_version.split("."))
if major > 1 or (major == 1 and minor > 0):
self.save_version_warning_printed = True
logger.warning(
f"Existing latents cache file has a higher version {file_version} for {cache_path}. This may cause issues."
)
else:
metadata = {}
metadata["architecture"] = self.architecture
metadata["width"] = f"{original_size[0]}"
metadata["height"] = f"{original_size[1]}"
metadata["format_version"] = "1.0.0"
metadata[f"crop_ltrb_{latents_size[0]}x{latents_size[1]}"] = ",".join(map(str, crop_ltrb))
key_suffix = self.get_key_suffix(latents_size=latents_size, dtype=dtype)
if latents_tensor is not None:
tensor_dict["latents" + key_suffix] = latents_tensor
if flipped_latents_tensor is not None:
kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy()
tensor_dict["latents_flipped" + key_suffix] = flipped_latents_tensor
if alpha_mask is not None:
kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
np.savez(npz_path, **kwargs)
key_suffix_without_dtype = self.get_key_suffix(latents_size=latents_size, dtype=None)
tensor_dict["alpha_mask" + key_suffix_without_dtype] = alpha_mask
# remove lower precision latents if higher precision latents are already cached
if overwrite:
suffix_without_dtype = self.get_key_suffix(latents_size=latents_size, dtype=None)
remove_lower_precision_values(tensor_dict, ["latents" + suffix_without_dtype, "latents_flipped" + suffix_without_dtype])
save_file(tensor_dict, cache_path, metadata=metadata)

View File

@@ -5,9 +5,6 @@ import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from library import flux_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
@@ -15,6 +12,8 @@ import logging
logger = logging.getLogger(__name__)
from library import flux_utils, train_util, utils
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
@@ -86,64 +85,56 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"
KEYS = ["l_pooled", "t5_out", "txt_ids"]
KEYS_MASKED = ["t5_attn_mask", "apply_t5_attn_mask"]
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
max_token_length: int,
masked: bool,
is_partial: bool = False,
apply_t5_attn_mask: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_t5_attn_mask = apply_t5_attn_mask
super().__init__(
FluxLatentsCachingStrategy.ARCHITECTURE,
cache_to_disk,
batch_size,
skip_disk_cache_validity_check,
max_token_length,
masked,
is_partial,
)
self.warn_fp8_weights = False
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
):
keys = FluxTextEncoderOutputsCachingStrategy.KEYS
if self.masked:
keys += FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, keys, preferred_dtype)
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "l_pooled" not in npz:
return False
if "t5_out" not in npz:
return False
if "txt_ids" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
l_pooled = data["l_pooled"]
t5_out = data["t5_out"]
txt_ids = data["txt_ids"]
t5_attn_mask = data["t5_attn_mask"]
# apply_t5_attn_mask should be same as self.apply_t5_attn_mask
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
l_pooled, t5_out, txt_ids = self.load_from_disk_for_keys(
cache_path, caption_index, FluxTextEncoderOutputsCachingStrategy.KEYS
)
if self.masked:
t5_attn_mask = self.load_from_disk_for_keys(
cache_path, caption_index, FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
)[0]
else:
t5_attn_mask = None
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
text_encoding_strategy: TextEncodingStrategy,
batch: list[tuple[utils.ImageInfo, int, str]],
):
if not self.warn_fp8_weights:
if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
@@ -154,80 +145,67 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
self.warn_fp8_weights = True
flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
captions = [info.caption for info in infos]
captions = [caption for _, _, caption in batch]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
# attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
if l_pooled.dtype == torch.bfloat16:
l_pooled = l_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
if txt_ids.dtype == torch.bfloat16:
txt_ids = txt_ids.float()
l_pooled = l_pooled.cpu()
t5_out = t5_out.cpu()
txt_ids = txt_ids.cpu()
t5_attn_mask = tokens_and_masks[2].cpu()
l_pooled = l_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
txt_ids = txt_ids.cpu().numpy()
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
keys = FluxTextEncoderOutputsCachingStrategy.KEYS
if self.masked:
keys += FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
for i, info in enumerate(infos):
for i, (info, caption_index, caption) in enumerate(batch):
l_pooled_i = l_pooled[i]
t5_out_i = t5_out[i]
txt_ids_i = txt_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_t5_attn_mask_i = self.apply_t5_attn_mask
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
l_pooled=l_pooled_i,
t5_out=t5_out_i,
txt_ids=txt_ids_i,
t5_attn_mask=t5_attn_mask_i,
apply_t5_attn_mask=apply_t5_attn_mask_i,
)
outputs = [l_pooled_i, t5_out_i, txt_ids_i]
if self.masked:
outputs += [t5_attn_mask_i]
self.save_outputs_to_disk(info.text_encoder_outputs_cache_path, caption_index, caption, keys, outputs)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
while len(info.text_encoder_outputs) <= caption_index:
info.text_encoder_outputs.append(None)
info.text_encoder_outputs[caption_index] = [l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i]
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"
ARCHITECTURE = "flux"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
super().__init__(FluxLatentsCachingStrategy.ARCHITECTURE, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
cache_path: str,
flip_aug: bool,
alpha_mask: bool,
preferred_dtype: Optional[torch.dtype] = None,
):
return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
return self._default_load_latents_from_disk(cache_path, bucket_reso)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)

View File

@@ -4,8 +4,6 @@ from typing import Any, List, Optional, Tuple, Union
import torch
from transformers import CLIPTokenizer
from library import train_util
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
from library.utils import setup_logging
setup_logging()
@@ -13,6 +11,8 @@ import logging
logger = logging.getLogger(__name__)
from library import train_util, utils
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
TOKENIZER_ID = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
@@ -134,33 +134,30 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
# and we keep the old npz for the backward compatibility.
SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"
ARCHITECTURE_SD = "sd"
ARCHITECTURE_SDXL = "sdxl"
def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
arch = SdSdxlLatentsCachingStrategy.ARCHITECTURE_SD if sd else SdSdxlLatentsCachingStrategy.ARCHITECTURE_SDXL
super().__init__(arch, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
self.sd = sd
self.suffix = (
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
)
@property
def cache_suffix(self) -> str:
return self.suffix
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
# support old .npz
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
if os.path.exists(old_npz_file):
return old_npz_file
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
cache_path: str,
flip_aug: bool,
alpha_mask: bool,
preferred_dtype: Optional[torch.dtype] = None,
) -> bool:
return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
def load_latents_from_disk(
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
return self._default_load_latents_from_disk(cache_path, bucket_reso)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample()
vae_device = vae.device
vae_dtype = vae.dtype

View File

@@ -6,10 +6,6 @@ import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel
from library import sd3_utils, train_util
from library import sd3_models
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
@@ -17,6 +13,9 @@ import logging
logger = logging.getLogger(__name__)
from library import train_util, utils
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
@@ -254,7 +253,8 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
KEYS = ["lg_out", "t5_out", "lg_pooled"]
KEYS_MASKED = ["clip_l_attn_mask", "clip_g_attn_mask", "t5_attn_mask"]
def __init__(
self,
@@ -262,70 +262,51 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
apply_lg_attn_mask: bool = False,
apply_t5_attn_mask: bool = False,
max_token_length: int = 256,
masked: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_lg_attn_mask = apply_lg_attn_mask
self.apply_t5_attn_mask = apply_t5_attn_mask
"""
apply_lg_attn_mask and apply_t5_attn_mask must be same
"""
super().__init__(
Sd3LatentsCachingStrategy.ARCHITECTURE_SD3,
cache_to_disk,
batch_size,
skip_disk_cache_validity_check,
max_token_length,
masked=masked,
is_partial=is_partial,
)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
) -> bool:
keys = Sd3TextEncoderOutputsCachingStrategy.KEYS
if self.masked:
keys += Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, keys, preferred_dtype)
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "lg_out" not in npz:
return False
if "lg_pooled" not in npz:
return False
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
return False
if "apply_lg_attn_mask" not in npz:
return False
if "t5_out" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
lg_out = data["lg_out"]
lg_pooled = data["lg_pooled"]
t5_out = data["t5_out"]
l_attn_mask = data["clip_l_attn_mask"]
g_attn_mask = data["clip_g_attn_mask"]
t5_attn_mask = data["t5_attn_mask"]
# apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
lg_out, lg_pooled, t5_out = self.load_from_disk_for_keys(
cache_path, caption_index, Sd3TextEncoderOutputsCachingStrategy.KEYS
)
if self.masked:
l_attn_mask, g_attn_mask, t5_attn_mask = self.load_from_disk_for_keys(
cache_path, caption_index, Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
)
else:
l_attn_mask = g_attn_mask = t5_attn_mask = None
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
text_encoding_strategy: TextEncodingStrategy,
batch: list[tuple[utils.ImageInfo, int, str]],
):
sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy
captions = [info.caption for info in infos]
captions = [caption for _, _, caption in batch]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
@@ -334,87 +315,76 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokenize_strategy,
models,
tokens_and_masks,
apply_lg_attn_mask=self.apply_lg_attn_mask,
apply_t5_attn_mask=self.apply_t5_attn_mask,
apply_lg_attn_mask=self.masked,
apply_t5_attn_mask=self.masked,
enable_dropout=False,
)
if lg_out.dtype == torch.bfloat16:
lg_out = lg_out.float()
if lg_pooled.dtype == torch.bfloat16:
lg_pooled = lg_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
lg_out = lg_out.cpu()
lg_pooled = lg_pooled.cpu()
t5_out = t5_out.cpu()
lg_out = lg_out.cpu().numpy()
lg_pooled = lg_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
l_attn_mask = tokens_and_masks[3].cpu()
g_attn_mask = tokens_and_masks[4].cpu()
t5_attn_mask = tokens_and_masks[5].cpu()
l_attn_mask = tokens_and_masks[3].cpu().numpy()
g_attn_mask = tokens_and_masks[4].cpu().numpy()
t5_attn_mask = tokens_and_masks[5].cpu().numpy()
for i, info in enumerate(infos):
keys = Sd3TextEncoderOutputsCachingStrategy.KEYS
if self.masked:
keys += Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
for i, (info, caption_index, caption) in enumerate(batch):
lg_out_i = lg_out[i]
t5_out_i = t5_out[i]
lg_pooled_i = lg_pooled[i]
l_attn_mask_i = l_attn_mask[i]
g_attn_mask_i = g_attn_mask[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_lg_attn_mask = self.apply_lg_attn_mask
apply_t5_attn_mask = self.apply_t5_attn_mask
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
lg_out=lg_out_i,
lg_pooled=lg_pooled_i,
t5_out=t5_out_i,
clip_l_attn_mask=l_attn_mask_i,
clip_g_attn_mask=g_attn_mask_i,
t5_attn_mask=t5_attn_mask_i,
apply_lg_attn_mask=apply_lg_attn_mask,
apply_t5_attn_mask=apply_t5_attn_mask,
)
outputs = [lg_out_i, t5_out_i, lg_pooled_i]
if self.masked:
outputs += [l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i]
self.save_outputs_to_disk(info.text_encoder_outputs_cache_path, caption_index, caption, keys, outputs)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
while len(info.text_encoder_outputs) <= caption_index:
info.text_encoder_outputs.append(None)
info.text_encoder_outputs[caption_index] = [
lg_out_i,
t5_out_i,
lg_pooled_i,
l_attn_mask_i,
g_attn_mask_i,
t5_attn_mask_i,
]
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
ARCHITECTURE_SD3 = "sd3"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
super().__init__(Sd3LatentsCachingStrategy.ARCHITECTURE_SD3, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
cache_path: str,
flip_aug: bool,
alpha_mask: bool,
preferred_dtype: Optional[torch.dtype] = None,
):
return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
return self._default_load_latents_from_disk(cache_path, bucket_reso)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)

View File

@@ -4,8 +4,6 @@ from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
@@ -14,6 +12,8 @@ import logging
logger = logging.getLogger(__name__)
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
from library import utils
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
@@ -21,6 +21,9 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
class SdxlTokenizeStrategy(TokenizeStrategy):
def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
"""
max_length: maximum length of the input text, **excluding** the special tokens. None or 150 or 225
"""
self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2
@@ -220,51 +223,51 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
ARCHITECTURE_SDXL = "sdxl"
KEYS = ["hidden_state1", "hidden_state2", "pool2"]
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
batch_size: Optional[int],
skip_disk_cache_validity_check: bool,
max_token_length: Optional[int] = None,
is_partial: bool = False,
is_weighted: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted)
"""
max_token_length: maximum length of the input text, **excluding** the special tokens. None or 150 or 225
"""
max_token_length = max_token_length or 75
super().__init__(
SdxlTextEncoderOutputsCachingStrategy.ARCHITECTURE_SDXL,
cache_to_disk,
batch_size,
skip_disk_cache_validity_check,
is_partial,
is_weighted,
max_token_length=max_token_length,
)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
) -> bool:
# SDXL does not support attn mask
base_keys = SdxlTextEncoderOutputsCachingStrategy.KEYS
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, base_keys, preferred_dtype)
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
hidden_state1 = data["hidden_state1"]
hidden_state2 = data["hidden_state2"]
pool2 = data["pool2"]
return [hidden_state1, hidden_state2, pool2]
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
return self.load_from_disk_for_keys(cache_path, caption_index, SdxlTextEncoderOutputsCachingStrategy.KEYS)
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
text_encoding_strategy: TextEncodingStrategy,
batch: list[tuple[utils.ImageInfo, int, str]],
):
sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy
captions = [info.caption for info in infos]
captions = [caption for _, _, caption in batch]
if self.is_weighted:
tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions)
@@ -279,28 +282,24 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokenize_strategy, models, [tokens1, tokens2]
)
if hidden_state1.dtype == torch.bfloat16:
hidden_state1 = hidden_state1.float()
if hidden_state2.dtype == torch.bfloat16:
hidden_state2 = hidden_state2.float()
if pool2.dtype == torch.bfloat16:
pool2 = pool2.float()
hidden_state1 = hidden_state1.cpu()
hidden_state2 = hidden_state2.cpu()
pool2 = pool2.cpu()
hidden_state1 = hidden_state1.cpu().numpy()
hidden_state2 = hidden_state2.cpu().numpy()
pool2 = pool2.cpu().numpy()
for i, info in enumerate(infos):
for i, (info, caption_index, caption) in enumerate(batch):
hidden_state1_i = hidden_state1[i]
hidden_state2_i = hidden_state2[i]
pool2_i = pool2[i]
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
hidden_state1=hidden_state1_i,
hidden_state2=hidden_state2_i,
pool2=pool2_i,
self.save_outputs_to_disk(
info.text_encoder_outputs_cache_path,
caption_index,
caption,
SdxlTextEncoderOutputsCachingStrategy.KEYS,
[hidden_state1_i, hidden_state2_i, pool2_i],
)
else:
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]
while len(info.text_encoder_outputs) <= caption_index:
info.text_encoder_outputs.append(None)
info.text_encoder_outputs[caption_index] = [hidden_state1_i, hidden_state2_i, pool2_i]

File diff suppressed because it is too large Load Diff

View File

@@ -21,6 +21,62 @@ def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, is_reg: bool, absolute_path: str) -> None:
self.image_key: str = image_key
self.num_repeats: int = num_repeats
self.captions: Optional[list[str]] = None
self.caption_weights: Optional[list[float]] = None # weights for each caption in sampling
self.list_of_tags: Optional[list[str]] = None
self.tags_weights: Optional[list[float]] = None
self.is_reg: bool = is_reg
self.absolute_path: str = absolute_path
self.latents_cache_dir: Optional[str] = None
self.image_size: Tuple[int, int] = None
self.resized_size: Tuple[int, int] = None
self.bucket_reso: Tuple[int, int] = None
self.latents: Optional[torch.Tensor] = None
self.latents_flipped: Optional[torch.Tensor] = None
self.latents_cache_path: Optional[str] = None # set in cache_latents
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
# crop left top right bottom in original pixel size, not latents size
self.latents_crop_ltrb: Optional[Tuple[int, int]] = None
self.cond_img_path: Optional[str] = None
self.image: Optional[Image.Image] = None # optional, original PIL Image. None if not the latents is cached
self.text_encoder_outputs_cache_path: Optional[str] = None # set in cache_text_encoder_outputs
# new
self.text_encoder_outputs: Optional[list[list[torch.Tensor]]] = None
# old
self.text_encoder_outputs1: Optional[torch.Tensor] = None
self.text_encoder_outputs2: Optional[torch.Tensor] = None
self.text_encoder_pool2: Optional[torch.Tensor] = None
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
def __str__(self) -> str:
return f"ImageInfo(image_key={self.image_key}, num_repeats={self.num_repeats}, captions={self.captions}, is_reg={self.is_reg}, absolute_path={self.absolute_path})"
def set_dreambooth_info(self, list_of_tags: list[str]) -> None:
self.list_of_tags = list_of_tags
def set_fine_tuning_info(
self,
captions: Optional[list[str]],
caption_weights: Optional[list[float]],
list_of_tags: Optional[list[str]],
tags_weights: Optional[list[float]],
image_size: Tuple[int, int],
latents_cache_dir: Optional[str],
):
self.captions = captions
self.caption_weights = caption_weights
self.list_of_tags = list_of_tags
self.tags_weights = tags_weights
self.image_size = image_size
self.latents_cache_dir = latents_cache_dir
# region Logging
@@ -189,6 +245,15 @@ def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None)
raise ValueError(f"Unsupported dtype: {s}")
def dtype_to_normalized_str(dtype: Union[str, torch.dtype]) -> str:
dtype = str_to_dtype(dtype) if isinstance(dtype, str) else dtype
# get name of the dtype
dtype_name = str(dtype).split(".")[-1]
return dtype_name
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
"""
memory efficient save file
@@ -264,8 +329,8 @@ class MemoryEfficientSafeOpen:
# does not support metadata loading
def __init__(self, filename):
self.filename = filename
self.header, self.header_size = self._read_header()
self.file = open(filename, "rb")
self.header, self.header_size = self._read_header()
def __enter__(self):
return self
@@ -276,6 +341,9 @@ class MemoryEfficientSafeOpen:
def keys(self):
return [k for k in self.header.keys() if k != "__metadata__"]
def metadata(self) -> Dict[str, str]:
return self.header.get("__metadata__", {})
def get_tensor(self, key):
if key not in self.header:
raise KeyError(f"Tensor '{key}' not found in the file")
@@ -293,10 +361,9 @@ class MemoryEfficientSafeOpen:
return self._deserialize_tensor(tensor_bytes, metadata)
def _read_header(self):
with open(self.filename, "rb") as f:
header_size = struct.unpack("<Q", f.read(8))[0]
header_json = f.read(header_size).decode("utf-8")
return json.loads(header_json), header_size
header_size = struct.unpack("<Q", self.file.read(8))[0]
header_json = self.file.read(header_size).decode("utf-8")
return json.loads(header_json), header_size
def _deserialize_tensor(self, tensor_bytes, metadata):
dtype = self._get_torch_dtype(metadata["dtype"])

View File

@@ -75,6 +75,12 @@ def train(args):
)
args.cache_text_encoder_outputs = True
if args.cache_text_encoder_outputs:
assert args.apply_lg_attn_mask == args.apply_t5_attn_mask, (
"apply_lg_attn_mask and apply_t5_attn_mask must be the same when caching text encoder outputs"
" / text encoderの出力をキャッシュするときにはapply_lg_attn_maskとapply_t5_attn_maskは同じである必要があります"
)
assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), (
"when training text encoder, text encoder outputs must not be cached (except for T5XXL)"
+ " / text encoderの学習時はtext encoderの出力はキャッシュできませんt5xxlのみキャッシュすることは可能です"
@@ -169,8 +175,8 @@ def train(args):
args.text_encoder_batch_size,
False,
False,
False,
False,
args.t5xxl_max_token_length,
args.apply_lg_attn_mask,
)
)
train_dataset_group.set_current_strategies()
@@ -279,8 +285,8 @@ def train(args):
args.text_encoder_batch_size,
args.skip_cache_check,
train_clip or args.use_t5xxl_cache_only, # if clip is trained or t5xxl is cached, caching is partial
args.t5xxl_max_token_length,
args.apply_lg_attn_mask,
args.apply_t5_attn_mask,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
@@ -331,7 +337,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
vae.to("cpu") # if no sampling, vae can be deleted
clean_memory_on_device(accelerator.device)

View File

@@ -43,6 +43,10 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
assert args.apply_lg_attn_mask == args.apply_t5_attn_mask, (
"apply_lg_attn_mask and apply_t5_attn_mask must be the same when caching text encoder outputs"
" / text encoderの出力をキャッシュするときにはapply_lg_attn_maskとapply_t5_attn_maskは同じである必要があります"
)
# prepare CLIP-L/CLIP-G/T5XXL training flags
self.train_clip = not args.network_train_unet_only
@@ -188,8 +192,8 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=self.train_clip or self.train_t5xxl,
max_token_length=args.t5xxl_max_token_length,
apply_lg_attn_mask=args.apply_lg_attn_mask,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
else:
return None

View File

@@ -273,7 +273,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -322,7 +322,11 @@ def train(args):
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions
args.cache_text_encoder_outputs_to_disk,
None,
args.skip_cache_check,
args.max_token_length,
is_weighted=args.weighted_captions,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)

View File

@@ -209,7 +209,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -223,7 +223,11 @@ def train(args):
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, False
args.cache_text_encoder_outputs_to_disk,
None,
args.skip_cache_check,
args.max_token_length,
is_weighted=args.weighted_captions,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)

View File

@@ -181,7 +181,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -195,7 +195,11 @@ def train(args):
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, False
args.cache_text_encoder_outputs_to_disk,
None,
args.skip_cache_check,
args.max_token_length,
is_weighted=args.weighted_captions,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)

View File

@@ -83,7 +83,11 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions
args.cache_text_encoder_outputs_to_disk,
None,
args.skip_cache_check,
args.max_token_length,
is_weighted=args.weighted_captions,
)
else:
return None

View File

@@ -150,7 +150,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
# cache latents with dataset
# TODO use DataLoader to speed up
train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents to disk.")

View File

@@ -157,7 +157,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
vae.to("cpu")
clean_memory_on_device(accelerator.device)

View File

@@ -559,9 +559,9 @@ class NetworkTrainer:
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
if val_dataset_group is not None:
val_dataset_group.new_cache_latents(vae, accelerator)
val_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
vae.to("cpu")
clean_memory_on_device(accelerator.device)

View File

@@ -382,7 +382,7 @@ class TextualInversionTrainer:
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()