This commit is contained in:
Kohya S.
2026-03-29 18:47:19 +09:00
committed by GitHub
11 changed files with 932 additions and 267 deletions

View File

@@ -155,6 +155,7 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
"""
ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_anima_te.npz"
ANIMA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_anima_te.safetensors"
def __init__(
self,
@@ -166,7 +167,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
suffix = self.ANIMA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
if not self.cache_to_disk:
@@ -177,17 +179,34 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
try:
npz = np.load(npz_path)
if "prompt_embeds" not in npz:
return False
if "attn_mask" not in npz:
return False
if "t5_input_ids" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
if "caption_dropout_rate" not in npz:
return False
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "prompt_embeds"):
return False
if "attn_mask" not in keys:
return False
if "t5_input_ids" not in keys:
return False
if "t5_attn_mask" not in keys:
return False
if "caption_dropout_rate" not in keys:
return False
else:
npz = np.load(npz_path)
if "prompt_embeds" not in npz:
return False
if "attn_mask" not in npz:
return False
if "t5_input_ids" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
if "caption_dropout_rate" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -195,6 +214,19 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
prompt_embeds = f.get_tensor(_find_tensor_by_prefix(keys, "prompt_embeds")).numpy()
attn_mask = f.get_tensor("attn_mask").numpy()
t5_input_ids = f.get_tensor("t5_input_ids").numpy()
t5_attn_mask = f.get_tensor("t5_attn_mask").numpy()
caption_dropout_rate = f.get_tensor("caption_dropout_rate").numpy()
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate]
data = np.load(npz_path)
prompt_embeds = data["prompt_embeds"]
attn_mask = data["attn_mask"]
@@ -219,32 +251,75 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokenize_strategy, models, tokens_and_masks
)
# Convert to numpy for caching
if prompt_embeds.dtype == torch.bfloat16:
prompt_embeds = prompt_embeds.float()
prompt_embeds = prompt_embeds.cpu().numpy()
attn_mask = attn_mask.cpu().numpy()
t5_input_ids = t5_input_ids.cpu().numpy().astype(np.int32)
t5_attn_mask = t5_attn_mask.cpu().numpy().astype(np.int32)
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, infos)
else:
# Convert to numpy for caching
if prompt_embeds.dtype == torch.bfloat16:
prompt_embeds = prompt_embeds.float()
prompt_embeds = prompt_embeds.cpu().numpy()
attn_mask = attn_mask.cpu().numpy()
t5_input_ids = t5_input_ids.cpu().numpy().astype(np.int32)
t5_attn_mask = t5_attn_mask.cpu().numpy().astype(np.int32)
for i, info in enumerate(infos):
prompt_embeds_i = prompt_embeds[i]
attn_mask_i = attn_mask[i]
t5_input_ids_i = t5_input_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
prompt_embeds=prompt_embeds_i,
attn_mask=attn_mask_i,
t5_input_ids=t5_input_ids_i,
t5_attn_mask=t5_attn_mask_i,
caption_dropout_rate=caption_dropout_rate,
)
else:
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i, caption_dropout_rate)
def _cache_batch_outputs_safetensors(self, prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, infos):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
prompt_embeds = prompt_embeds.cpu()
attn_mask = attn_mask.cpu()
t5_input_ids = t5_input_ids.cpu().to(torch.int32)
t5_attn_mask = t5_attn_mask.cpu().to(torch.int32)
for i, info in enumerate(infos):
prompt_embeds_i = prompt_embeds[i]
attn_mask_i = attn_mask[i]
t5_input_ids_i = t5_input_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
prompt_embeds=prompt_embeds_i,
attn_mask=attn_mask_i,
t5_input_ids=t5_input_ids_i,
t5_attn_mask=t5_attn_mask_i,
caption_dropout_rate=caption_dropout_rate,
)
tensors = {}
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
pe = prompt_embeds[i]
tensors[f"prompt_embeds_{_dtype_to_str(pe.dtype)}"] = pe
tensors["attn_mask"] = attn_mask[i]
tensors["t5_input_ids"] = t5_input_ids[i]
tensors["t5_attn_mask"] = t5_attn_mask[i]
tensors["caption_dropout_rate"] = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
metadata = {
"architecture": "anima",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else:
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i, caption_dropout_rate)
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
info.text_encoder_outputs = (
prompt_embeds[i].numpy(),
attn_mask[i].numpy(),
t5_input_ids[i].numpy(),
t5_attn_mask[i].numpy(),
caption_dropout_rate,
)
class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
@@ -255,16 +330,20 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
"""
ANIMA_LATENTS_NPZ_SUFFIX = "_anima.npz"
ANIMA_LATENTS_ST_SUFFIX = "_anima.safetensors"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return self.ANIMA_LATENTS_NPZ_SUFFIX
return self.ANIMA_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.ANIMA_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.ANIMA_LATENTS_NPZ_SUFFIX
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
def _get_architecture_name(self) -> str:
return "anima"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -2,7 +2,7 @@
import os
import re
from typing import Any, List, Optional, Tuple, Union, Callable
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
import numpy as np
import torch
@@ -19,6 +19,48 @@ import logging
logger = logging.getLogger(__name__)
LATENTS_CACHE_FORMAT_VERSION = "1.0.1"
TE_OUTPUTS_CACHE_FORMAT_VERSION = "1.0.1"
# global cache format setting: "npz" or "safetensors"
_cache_format: str = "npz"
def set_cache_format(cache_format: str) -> None:
global _cache_format
_cache_format = cache_format
def get_cache_format() -> str:
return _cache_format
_TORCH_DTYPE_TO_STR = {
torch.float64: "float64",
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.int64: "int64",
torch.int32: "int32",
torch.int16: "int16",
torch.int8: "int8",
torch.uint8: "uint8",
torch.bool: "bool",
}
_FLOAT_DTYPES = {torch.float64, torch.float32, torch.float16, torch.bfloat16}
def _dtype_to_str(dtype: torch.dtype) -> str:
return _TORCH_DTYPE_TO_STR.get(dtype, str(dtype).replace("torch.", ""))
def _find_tensor_by_prefix(tensors_keys: List[str], prefix: str) -> Optional[str]:
"""Find a tensor key that starts with the given prefix. Returns the first match or None."""
for key in tensors_keys:
if key.startswith(prefix) or key == prefix:
return key
return None
class TokenizeStrategy:
_strategy = None # strategy instance: actual strategy class
@@ -362,6 +404,10 @@ class TextEncoderOutputsCachingStrategy:
def is_weighted(self):
return self._is_weighted
@property
def cache_format(self) -> str:
return get_cache_format()
def get_outputs_npz_path(self, image_abs_path: str) -> str:
raise NotImplementedError
@@ -407,6 +453,10 @@ class LatentsCachingStrategy:
def batch_size(self):
return self._batch_size
@property
def cache_format(self) -> str:
return get_cache_format()
@property
def cache_suffix(self):
raise NotImplementedError
@@ -439,7 +489,7 @@ class LatentsCachingStrategy:
Args:
latents_stride: stride of latents
bucket_reso: resolution of the bucket
npz_path: path to the npz file
npz_path: path to the npz/safetensors file
flip_aug: whether to flip images
apply_alpha_mask: whether to apply alpha mask
multi_resolution: whether to use multi-resolution latents
@@ -454,6 +504,11 @@ class LatentsCachingStrategy:
if self.skip_disk_cache_validity_check:
return True
if npz_path.endswith(".safetensors"):
return self._is_disk_cached_latents_expected_safetensors(
latents_stride, bucket_reso, npz_path, flip_aug, apply_alpha_mask, multi_resolution
)
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
# e.g. "_32x64", HxW
@@ -476,6 +531,40 @@ class LatentsCachingStrategy:
return True
def _is_disk_cached_latents_expected_safetensors(
self,
latents_stride: int,
bucket_reso: Tuple[int, int],
st_path: str,
flip_aug: bool,
apply_alpha_mask: bool,
multi_resolution: bool = False,
) -> bool:
from library.safetensors_utils import MemoryEfficientSafeOpen
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # (H, W)
reso_tag = f"1x{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else "1x"
try:
with MemoryEfficientSafeOpen(st_path) as f:
keys = f.keys()
latents_prefix = f"latents_{reso_tag}"
if not any(k.startswith(latents_prefix) for k in keys):
return False
if flip_aug:
flipped_prefix = f"latents_flipped_{reso_tag}"
if not any(k.startswith(flipped_prefix) for k in keys):
return False
if apply_alpha_mask:
mask_prefix = f"alpha_mask_{reso_tag}"
if not any(k.startswith(mask_prefix) for k in keys):
return False
except Exception as e:
logger.error(f"Error loading file: {st_path}")
raise e
return True
# TODO remove circular dependency for ImageInfo
def _default_cache_batch_latents(
self,
@@ -571,7 +660,7 @@ class LatentsCachingStrategy:
"""
Args:
latents_stride (Optional[int]): Stride for latents. If None, load all latents.
npz_path (str): Path to the npz file.
npz_path (str): Path to the npz/safetensors file.
bucket_reso (Tuple[int, int]): The resolution of the bucket.
Returns:
@@ -583,6 +672,9 @@ class LatentsCachingStrategy:
Optional[np.ndarray]
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
"""
if npz_path.endswith(".safetensors"):
return self._load_latents_from_disk_safetensors(latents_stride, npz_path, bucket_reso)
if latents_stride is None:
key_reso_suffix = ""
else:
@@ -609,6 +701,39 @@ class LatentsCachingStrategy:
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def _load_latents_from_disk_safetensors(
self, latents_stride: Optional[int], st_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
from library.safetensors_utils import MemoryEfficientSafeOpen
if latents_stride is None:
reso_tag = "1x"
else:
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)
reso_tag = f"1x{latents_size[0]}x{latents_size[1]}"
with MemoryEfficientSafeOpen(st_path) as f:
keys = f.keys()
latents_key = _find_tensor_by_prefix(keys, f"latents_{reso_tag}")
if latents_key is None:
raise ValueError(f"latents with prefix 'latents_{reso_tag}' not found in {st_path}")
latents = f.get_tensor(latents_key).numpy()
original_size_key = _find_tensor_by_prefix(keys, f"original_size_{reso_tag}")
original_size = f.get_tensor(original_size_key).numpy().tolist() if original_size_key else [0, 0]
crop_ltrb_key = _find_tensor_by_prefix(keys, f"crop_ltrb_{reso_tag}")
crop_ltrb = f.get_tensor(crop_ltrb_key).numpy().tolist() if crop_ltrb_key else [0, 0, 0, 0]
flipped_key = _find_tensor_by_prefix(keys, f"latents_flipped_{reso_tag}")
flipped_latents = f.get_tensor(flipped_key).numpy() if flipped_key else None
alpha_mask_key = _find_tensor_by_prefix(keys, f"alpha_mask_{reso_tag}")
alpha_mask = f.get_tensor(alpha_mask_key).numpy() if alpha_mask_key else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk(
self,
npz_path,
@@ -621,17 +746,23 @@ class LatentsCachingStrategy:
):
"""
Args:
npz_path (str): Path to the npz file.
npz_path (str): Path to the npz/safetensors file.
latents_tensor (torch.Tensor): Latent tensor
original_size (List[int]): Original size of the image
crop_ltrb (List[int]): Crop left top right bottom
flipped_latents_tensor (Optional[torch.Tensor]): Flipped latent tensor
alpha_mask (Optional[torch.Tensor]): Alpha mask
key_reso_suffix (str): Key resolution suffix
key_reso_suffix (str): Key resolution suffix (e.g. "_32x64" for multi-resolution npz)
Returns:
None
"""
if npz_path.endswith(".safetensors"):
self._save_latents_to_disk_safetensors(
npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor, alpha_mask, key_reso_suffix
)
return
kwargs = {}
if os.path.exists(npz_path):
@@ -640,7 +771,7 @@ class LatentsCachingStrategy:
for key in npz.files:
kwargs[key] = npz[key]
# TODO float() is needed if vae is in bfloat16. Remove it if vae is float16.
# float() is needed because npz doesn't support bfloat16
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
@@ -649,3 +780,59 @@ class LatentsCachingStrategy:
if alpha_mask is not None:
kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
np.savez(npz_path, **kwargs)
def _save_latents_to_disk_safetensors(
self,
st_path,
latents_tensor,
original_size,
crop_ltrb,
flipped_latents_tensor=None,
alpha_mask=None,
key_reso_suffix="",
):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
latents_tensor = latents_tensor.cpu()
latents_size = latents_tensor.shape[-2:] # H, W
reso_tag = f"1x{latents_size[0]}x{latents_size[1]}"
dtype_str = _dtype_to_str(latents_tensor.dtype)
# NaN check and zero replacement
if torch.isnan(latents_tensor).any():
latents_tensor = torch.nan_to_num(latents_tensor, nan=0.0)
tensors: Dict[str, torch.Tensor] = {}
# load existing file and merge (for multi-resolution)
if os.path.exists(st_path):
with MemoryEfficientSafeOpen(st_path) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
tensors[f"latents_{reso_tag}_{dtype_str}"] = latents_tensor
tensors[f"original_size_{reso_tag}_int32"] = torch.tensor(original_size, dtype=torch.int32)
tensors[f"crop_ltrb_{reso_tag}_int32"] = torch.tensor(crop_ltrb, dtype=torch.int32)
if flipped_latents_tensor is not None:
flipped_latents_tensor = flipped_latents_tensor.cpu()
if torch.isnan(flipped_latents_tensor).any():
flipped_latents_tensor = torch.nan_to_num(flipped_latents_tensor, nan=0.0)
tensors[f"latents_flipped_{reso_tag}_{dtype_str}"] = flipped_latents_tensor
if alpha_mask is not None:
alpha_mask_tensor = alpha_mask.cpu() if isinstance(alpha_mask, torch.Tensor) else torch.tensor(alpha_mask)
tensors[f"alpha_mask_{reso_tag}"] = alpha_mask_tensor
metadata = {
"architecture": self._get_architecture_name(),
"width": str(latents_size[1]),
"height": str(latents_size[0]),
"format_version": LATENTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, st_path, metadata=metadata)
def _get_architecture_name(self) -> str:
"""Override in subclasses to return the architecture name for safetensors metadata."""
return "unknown"

View File

@@ -87,6 +87,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"
FLUX_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_flux_te.safetensors"
def __init__(
self,
@@ -102,7 +103,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
self.warn_fp8_weights = False
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
suffix = self.FLUX_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
@@ -113,20 +115,40 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
try:
npz = np.load(npz_path)
if "l_pooled" not in npz:
return False
if "t5_out" not in npz:
return False
if "txt_ids" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "l_pooled"):
return False
if not _find_tensor_by_prefix(keys, "t5_out"):
return False
if not _find_tensor_by_prefix(keys, "txt_ids"):
return False
if "t5_attn_mask" not in keys:
return False
if "apply_t5_attn_mask" not in keys:
return False
apply_t5 = f.get_tensor("apply_t5_attn_mask").item()
if bool(apply_t5) != self.apply_t5_attn_mask:
return False
else:
npz = np.load(npz_path)
if "l_pooled" not in npz:
return False
if "t5_out" not in npz:
return False
if "txt_ids" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -134,6 +156,18 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
l_pooled = f.get_tensor(_find_tensor_by_prefix(keys, "l_pooled")).numpy()
t5_out = f.get_tensor(_find_tensor_by_prefix(keys, "t5_out")).numpy()
txt_ids = f.get_tensor(_find_tensor_by_prefix(keys, "txt_ids")).numpy()
t5_attn_mask = f.get_tensor("t5_attn_mask").numpy()
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
data = np.load(npz_path)
l_pooled = data["l_pooled"]
t5_out = data["t5_out"]
@@ -161,56 +195,100 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
# attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
if l_pooled.dtype == torch.bfloat16:
l_pooled = l_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
if txt_ids.dtype == torch.bfloat16:
txt_ids = txt_ids.float()
t5_attn_mask_tokens = tokens_and_masks[2]
l_pooled = l_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
txt_ids = txt_ids.cpu().numpy()
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(l_pooled, t5_out, txt_ids, t5_attn_mask_tokens, infos)
else:
if l_pooled.dtype == torch.bfloat16:
l_pooled = l_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
if txt_ids.dtype == torch.bfloat16:
txt_ids = txt_ids.float()
l_pooled = l_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
txt_ids = txt_ids.cpu().numpy()
t5_attn_mask = t5_attn_mask_tokens.cpu().numpy()
for i, info in enumerate(infos):
l_pooled_i = l_pooled[i]
t5_out_i = t5_out[i]
txt_ids_i = txt_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_t5_attn_mask_i = self.apply_t5_attn_mask
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
l_pooled=l_pooled_i,
t5_out=t5_out_i,
txt_ids=txt_ids_i,
t5_attn_mask=t5_attn_mask_i,
apply_t5_attn_mask=apply_t5_attn_mask_i,
)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
def _cache_batch_outputs_safetensors(self, l_pooled, t5_out, txt_ids, t5_attn_mask_tokens, infos):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
l_pooled = l_pooled.cpu()
t5_out = t5_out.cpu()
txt_ids = txt_ids.cpu()
t5_attn_mask = t5_attn_mask_tokens.cpu()
for i, info in enumerate(infos):
l_pooled_i = l_pooled[i]
t5_out_i = t5_out[i]
txt_ids_i = txt_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_t5_attn_mask_i = self.apply_t5_attn_mask
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
l_pooled=l_pooled_i,
t5_out=t5_out_i,
txt_ids=txt_ids_i,
t5_attn_mask=t5_attn_mask_i,
apply_t5_attn_mask=apply_t5_attn_mask_i,
)
tensors = {}
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
lp = l_pooled[i]
to = t5_out[i]
ti = txt_ids[i]
tensors[f"l_pooled_{_dtype_to_str(lp.dtype)}"] = lp
tensors[f"t5_out_{_dtype_to_str(to.dtype)}"] = to
tensors[f"txt_ids_{_dtype_to_str(ti.dtype)}"] = ti
tensors["t5_attn_mask"] = t5_attn_mask[i]
tensors["apply_t5_attn_mask"] = torch.tensor(self.apply_t5_attn_mask, dtype=torch.bool)
metadata = {
"architecture": "flux",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
info.text_encoder_outputs = (l_pooled[i].numpy(), t5_out[i].numpy(), txt_ids[i].numpy(), t5_attn_mask[i].numpy())
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"
FLUX_LATENTS_ST_SUFFIX = "_flux.safetensors"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
return self.FLUX_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.FLUX_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
+ self.cache_suffix
)
def _get_architecture_name(self) -> str:
return "flux"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -81,16 +81,17 @@ class HunyuanImageTextEncodingStrategy(TextEncodingStrategy):
class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_hi_te.npz"
HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_hi_te.safetensors"
def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
suffix = self.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return (
os.path.splitext(image_abs_path)[0]
+ HunyuanImageTextEncoderOutputsCachingStrategy.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
os.path.splitext(image_abs_path)[0] + suffix
)
def is_disk_cached_outputs_expected(self, npz_path: str):
@@ -102,17 +103,34 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr
return True
try:
npz = np.load(npz_path)
if "vlm_embed" not in npz:
return False
if "vlm_mask" not in npz:
return False
if "byt5_embed" not in npz:
return False
if "byt5_mask" not in npz:
return False
if "ocr_mask" not in npz:
return False
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "vlm_embed"):
return False
if "vlm_mask" not in keys:
return False
if not _find_tensor_by_prefix(keys, "byt5_embed"):
return False
if "byt5_mask" not in keys:
return False
if "ocr_mask" not in keys:
return False
else:
npz = np.load(npz_path)
if "vlm_embed" not in npz:
return False
if "vlm_mask" not in npz:
return False
if "byt5_embed" not in npz:
return False
if "byt5_mask" not in npz:
return False
if "ocr_mask" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -120,6 +138,19 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
vlm_embed = f.get_tensor(_find_tensor_by_prefix(keys, "vlm_embed")).numpy()
vlm_mask = f.get_tensor("vlm_mask").numpy()
byt5_embed = f.get_tensor(_find_tensor_by_prefix(keys, "byt5_embed")).numpy()
byt5_mask = f.get_tensor("byt5_mask").numpy()
ocr_mask = f.get_tensor("ocr_mask").numpy()
return [vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask]
data = np.load(npz_path)
vln_embed = data["vlm_embed"]
vlm_mask = data["vlm_mask"]
@@ -140,54 +171,102 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr
tokenize_strategy, models, tokens_and_masks
)
if vlm_embed.dtype == torch.bfloat16:
vlm_embed = vlm_embed.float()
if byt5_embed.dtype == torch.bfloat16:
byt5_embed = byt5_embed.float()
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask, infos)
else:
if vlm_embed.dtype == torch.bfloat16:
vlm_embed = vlm_embed.float()
if byt5_embed.dtype == torch.bfloat16:
byt5_embed = byt5_embed.float()
vlm_embed = vlm_embed.cpu().numpy()
vlm_mask = vlm_mask.cpu().numpy()
byt5_embed = byt5_embed.cpu().numpy()
byt5_mask = byt5_mask.cpu().numpy()
ocr_mask = ocr_mask.cpu().numpy()
vlm_embed = vlm_embed.cpu().numpy()
vlm_mask = vlm_mask.cpu().numpy()
byt5_embed = byt5_embed.cpu().numpy()
byt5_mask = byt5_mask.cpu().numpy()
ocr_mask = ocr_mask.cpu().numpy()
for i, info in enumerate(infos):
vlm_embed_i = vlm_embed[i]
vlm_mask_i = vlm_mask[i]
byt5_embed_i = byt5_embed[i]
byt5_mask_i = byt5_mask[i]
ocr_mask_i = ocr_mask[i]
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
vlm_embed=vlm_embed_i,
vlm_mask=vlm_mask_i,
byt5_embed=byt5_embed_i,
byt5_mask=byt5_mask_i,
ocr_mask=ocr_mask_i,
)
else:
info.text_encoder_outputs = (vlm_embed_i, vlm_mask_i, byt5_embed_i, byt5_mask_i, ocr_mask_i)
def _cache_batch_outputs_safetensors(self, vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask, infos):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
vlm_embed = vlm_embed.cpu()
vlm_mask = vlm_mask.cpu()
byt5_embed = byt5_embed.cpu()
byt5_mask = byt5_mask.cpu()
ocr_mask = ocr_mask.cpu()
for i, info in enumerate(infos):
vlm_embed_i = vlm_embed[i]
vlm_mask_i = vlm_mask[i]
byt5_embed_i = byt5_embed[i]
byt5_mask_i = byt5_mask[i]
ocr_mask_i = ocr_mask[i]
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
vlm_embed=vlm_embed_i,
vlm_mask=vlm_mask_i,
byt5_embed=byt5_embed_i,
byt5_mask=byt5_mask_i,
ocr_mask=ocr_mask_i,
)
tensors = {}
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
ve = vlm_embed[i]
be = byt5_embed[i]
tensors[f"vlm_embed_{_dtype_to_str(ve.dtype)}"] = ve
tensors["vlm_mask"] = vlm_mask[i]
tensors[f"byt5_embed_{_dtype_to_str(be.dtype)}"] = be
tensors["byt5_mask"] = byt5_mask[i]
tensors["ocr_mask"] = ocr_mask[i]
metadata = {
"architecture": "hunyuan_image",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else:
info.text_encoder_outputs = (vlm_embed_i, vlm_mask_i, byt5_embed_i, byt5_mask_i, ocr_mask_i)
info.text_encoder_outputs = (
vlm_embed[i].numpy(),
vlm_mask[i].numpy(),
byt5_embed[i].numpy(),
byt5_mask[i].numpy(),
ocr_mask[i].numpy(),
)
class HunyuanImageLatentsCachingStrategy(LatentsCachingStrategy):
HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX = "_hi.npz"
HUNYUAN_IMAGE_LATENTS_ST_SUFFIX = "_hi.safetensors"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return HunyuanImageLatentsCachingStrategy.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX
return self.HUNYUAN_IMAGE_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ HunyuanImageLatentsCachingStrategy.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX
+ self.cache_suffix
)
def _get_architecture_name(self) -> str:
return "hunyuan_image"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(32, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -146,6 +146,7 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy):
class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_lumina_te.npz"
LUMINA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_lumina_te.safetensors"
def __init__(
self,
@@ -162,19 +163,10 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return (
os.path.splitext(image_abs_path)[0]
+ LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
)
suffix = self.LUMINA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
"""
Args:
npz_path (str): Path to the npz file.
Returns:
bool: True if the npz file is expected to be cached.
"""
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
@@ -183,13 +175,26 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
return True
try:
npz = np.load(npz_path)
if "hidden_state" not in npz:
return False
if "attention_mask" not in npz:
return False
if "input_ids" not in npz:
return False
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "hidden_state"):
return False
if "attention_mask" not in keys:
return False
if "input_ids" not in keys:
return False
else:
npz = np.load(npz_path)
if "hidden_state" not in npz:
return False
if "attention_mask" not in npz:
return False
if "input_ids" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -198,11 +203,22 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
"""
Load outputs from a npz file
Load outputs from a npz/safetensors file
Returns:
List[np.ndarray]: hidden_state, input_ids, attention_mask
"""
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
hidden_state = f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state")).numpy()
attention_mask = f.get_tensor("attention_mask").numpy()
input_ids = f.get_tensor("input_ids").numpy()
return [hidden_state, input_ids, attention_mask]
data = np.load(npz_path)
hidden_state = data["hidden_state"]
attention_mask = data["attention_mask"]
@@ -217,16 +233,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
text_encoding_strategy: TextEncodingStrategy,
batch: List[train_util.ImageInfo],
) -> None:
"""
Args:
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
models (List[Any]): Text encoders
text_encoding_strategy (LuminaTextEncodingStrategy):
infos (List): List of ImageInfo
Returns:
None
"""
assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy)
assert isinstance(tokenize_strategy, LuminaTokenizeStrategy)
@@ -252,37 +258,75 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
)
)
if hidden_state.dtype != torch.float32:
hidden_state = hidden_state.float()
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(hidden_state, input_ids, attention_masks, batch)
else:
if hidden_state.dtype != torch.float32:
hidden_state = hidden_state.float()
hidden_state = hidden_state.cpu().numpy()
attention_mask = attention_masks.cpu().numpy() # (B, S)
input_ids = input_ids.cpu().numpy() # (B, S)
hidden_state = hidden_state.cpu().numpy()
attention_mask = attention_masks.cpu().numpy()
input_ids_np = input_ids.cpu().numpy()
for i, info in enumerate(batch):
hidden_state_i = hidden_state[i]
attention_mask_i = attention_mask[i]
input_ids_i = input_ids_np[i]
if self.cache_to_disk:
assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
np.savez(
info.text_encoder_outputs_npz,
hidden_state=hidden_state_i,
attention_mask=attention_mask_i,
input_ids=input_ids_i,
)
else:
info.text_encoder_outputs = [
hidden_state_i,
input_ids_i,
attention_mask_i,
]
def _cache_batch_outputs_safetensors(self, hidden_state, input_ids, attention_masks, batch):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
hidden_state = hidden_state.cpu()
input_ids = input_ids.cpu()
attention_mask = attention_masks.cpu()
for i, info in enumerate(batch):
hidden_state_i = hidden_state[i]
attention_mask_i = attention_mask[i]
input_ids_i = input_ids[i]
if self.cache_to_disk:
assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
np.savez(
info.text_encoder_outputs_npz,
hidden_state=hidden_state_i,
attention_mask=attention_mask_i,
input_ids=input_ids_i,
)
tensors = {}
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
hs = hidden_state[i]
tensors[f"hidden_state_{_dtype_to_str(hs.dtype)}"] = hs
tensors["attention_mask"] = attention_mask[i]
tensors["input_ids"] = input_ids[i]
metadata = {
"architecture": "lumina",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else:
info.text_encoder_outputs = [
hidden_state_i,
input_ids_i,
attention_mask_i,
hidden_state[i].numpy(),
input_ids[i].numpy(),
attention_mask[i].numpy(),
]
class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz"
LUMINA_LATENTS_ST_SUFFIX = "_lumina.safetensors"
def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
@@ -291,7 +335,7 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
@property
def cache_suffix(self) -> str:
return LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX
return self.LUMINA_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.LUMINA_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(
self, absolute_path: str, image_size: Tuple[int, int]
@@ -299,9 +343,12 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX
+ self.cache_suffix
)
def _get_architecture_name(self) -> str:
return "lumina"
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],

View File

@@ -138,24 +138,32 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"
SD_LATENTS_ST_SUFFIX = "_sd.safetensors"
SDXL_LATENTS_ST_SUFFIX = "_sdxl.safetensors"
def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
def __init__(
self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
self.sd = sd
self.suffix = (
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
)
@property
def cache_suffix(self) -> str:
return self.suffix
if self.cache_format == "safetensors":
return self.SD_LATENTS_ST_SUFFIX if self.sd else self.SDXL_LATENTS_ST_SUFFIX
else:
return self.SD_LATENTS_NPZ_SUFFIX if self.sd else self.SDXL_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
# support old .npz
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
if os.path.exists(old_npz_file):
return old_npz_file
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
if self.cache_format != "safetensors":
# support old .npz
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
if os.path.exists(old_npz_file):
return old_npz_file
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
def _get_architecture_name(self) -> str:
return "sd" if self.sd else "sdxl"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -255,6 +255,7 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
SD3_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_sd3_te.safetensors"
def __init__(
self,
@@ -270,7 +271,8 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
self.apply_t5_attn_mask = apply_t5_attn_mask
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
suffix = self.SD3_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
@@ -281,27 +283,54 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
try:
npz = np.load(npz_path)
if "lg_out" not in npz:
return False
if "lg_pooled" not in npz:
return False
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
return False
if "apply_lg_attn_mask" not in npz:
return False
if "t5_out" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "lg_out"):
return False
if not _find_tensor_by_prefix(keys, "lg_pooled"):
return False
if "clip_l_attn_mask" not in keys or "clip_g_attn_mask" not in keys:
return False
if not _find_tensor_by_prefix(keys, "t5_out"):
return False
if "t5_attn_mask" not in keys:
return False
if "apply_lg_attn_mask" not in keys:
return False
apply_lg = f.get_tensor("apply_lg_attn_mask").item()
if bool(apply_lg) != self.apply_lg_attn_mask:
return False
if "apply_t5_attn_mask" not in keys:
return False
apply_t5 = f.get_tensor("apply_t5_attn_mask").item()
if bool(apply_t5) != self.apply_t5_attn_mask:
return False
else:
npz = np.load(npz_path)
if "lg_out" not in npz:
return False
if "lg_pooled" not in npz:
return False
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz:
return False
if "apply_lg_attn_mask" not in npz:
return False
if "t5_out" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -309,6 +338,20 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
lg_out = f.get_tensor(_find_tensor_by_prefix(keys, "lg_out")).numpy()
lg_pooled = f.get_tensor(_find_tensor_by_prefix(keys, "lg_pooled")).numpy()
t5_out = f.get_tensor(_find_tensor_by_prefix(keys, "t5_out")).numpy()
l_attn_mask = f.get_tensor("clip_l_attn_mask").numpy()
g_attn_mask = f.get_tensor("clip_g_attn_mask").numpy()
t5_attn_mask = f.get_tensor("t5_attn_mask").numpy()
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
data = np.load(npz_path)
lg_out = data["lg_out"]
lg_pooled = data["lg_pooled"]
@@ -339,65 +382,127 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
enable_dropout=False,
)
if lg_out.dtype == torch.bfloat16:
lg_out = lg_out.float()
if lg_pooled.dtype == torch.bfloat16:
lg_pooled = lg_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
l_attn_mask_tokens = tokens_and_masks[3]
g_attn_mask_tokens = tokens_and_masks[4]
t5_attn_mask_tokens = tokens_and_masks[5]
lg_out = lg_out.cpu().numpy()
lg_pooled = lg_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(
lg_out, t5_out, lg_pooled, l_attn_mask_tokens, g_attn_mask_tokens, t5_attn_mask_tokens, infos
)
else:
if lg_out.dtype == torch.bfloat16:
lg_out = lg_out.float()
if lg_pooled.dtype == torch.bfloat16:
lg_pooled = lg_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
l_attn_mask = tokens_and_masks[3].cpu().numpy()
g_attn_mask = tokens_and_masks[4].cpu().numpy()
t5_attn_mask = tokens_and_masks[5].cpu().numpy()
lg_out = lg_out.cpu().numpy()
lg_pooled = lg_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
l_attn_mask = l_attn_mask_tokens.cpu().numpy()
g_attn_mask = g_attn_mask_tokens.cpu().numpy()
t5_attn_mask = t5_attn_mask_tokens.cpu().numpy()
for i, info in enumerate(infos):
lg_out_i = lg_out[i]
t5_out_i = t5_out[i]
lg_pooled_i = lg_pooled[i]
l_attn_mask_i = l_attn_mask[i]
g_attn_mask_i = g_attn_mask[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_lg_attn_mask = self.apply_lg_attn_mask
apply_t5_attn_mask = self.apply_t5_attn_mask
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
lg_out=lg_out_i,
lg_pooled=lg_pooled_i,
t5_out=t5_out_i,
clip_l_attn_mask=l_attn_mask_i,
clip_g_attn_mask=g_attn_mask_i,
t5_attn_mask=t5_attn_mask_i,
apply_lg_attn_mask=apply_lg_attn_mask,
apply_t5_attn_mask=apply_t5_attn_mask,
)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
def _cache_batch_outputs_safetensors(
self, lg_out, t5_out, lg_pooled, l_attn_mask_tokens, g_attn_mask_tokens, t5_attn_mask_tokens, infos
):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
lg_out = lg_out.cpu()
t5_out = t5_out.cpu()
lg_pooled = lg_pooled.cpu()
l_attn_mask = l_attn_mask_tokens.cpu()
g_attn_mask = g_attn_mask_tokens.cpu()
t5_attn_mask = t5_attn_mask_tokens.cpu()
for i, info in enumerate(infos):
lg_out_i = lg_out[i]
t5_out_i = t5_out[i]
lg_pooled_i = lg_pooled[i]
l_attn_mask_i = l_attn_mask[i]
g_attn_mask_i = g_attn_mask[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_lg_attn_mask = self.apply_lg_attn_mask
apply_t5_attn_mask = self.apply_t5_attn_mask
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
lg_out=lg_out_i,
lg_pooled=lg_pooled_i,
t5_out=t5_out_i,
clip_l_attn_mask=l_attn_mask_i,
clip_g_attn_mask=g_attn_mask_i,
t5_attn_mask=t5_attn_mask_i,
apply_lg_attn_mask=apply_lg_attn_mask,
apply_t5_attn_mask=apply_t5_attn_mask,
)
tensors = {}
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
lg_out_i = lg_out[i]
t5_out_i = t5_out[i]
lg_pooled_i = lg_pooled[i]
tensors[f"lg_out_{_dtype_to_str(lg_out_i.dtype)}"] = lg_out_i
tensors[f"t5_out_{_dtype_to_str(t5_out_i.dtype)}"] = t5_out_i
tensors[f"lg_pooled_{_dtype_to_str(lg_pooled_i.dtype)}"] = lg_pooled_i
tensors["clip_l_attn_mask"] = l_attn_mask[i]
tensors["clip_g_attn_mask"] = g_attn_mask[i]
tensors["t5_attn_mask"] = t5_attn_mask[i]
tensors["apply_lg_attn_mask"] = torch.tensor(self.apply_lg_attn_mask, dtype=torch.bool)
tensors["apply_t5_attn_mask"] = torch.tensor(self.apply_t5_attn_mask, dtype=torch.bool)
metadata = {
"architecture": "sd3",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
info.text_encoder_outputs = (
lg_out[i].numpy(),
t5_out[i].numpy(),
lg_pooled[i].numpy(),
l_attn_mask[i].numpy(),
g_attn_mask[i].numpy(),
t5_attn_mask[i].numpy(),
)
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
SD3_LATENTS_ST_SUFFIX = "_sd3.safetensors"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
return self.SD3_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.SD3_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
+ self.cache_suffix
)
def _get_architecture_name(self) -> str:
return "sd3"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -221,6 +221,7 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
SDXL_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_te_outputs.safetensors"
def __init__(
self,
@@ -233,7 +234,8 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
suffix = self.SDXL_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
@@ -244,9 +246,22 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
try:
npz = np.load(npz_path)
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
return False
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "hidden_state1"):
return False
if not _find_tensor_by_prefix(keys, "hidden_state2"):
return False
if not _find_tensor_by_prefix(keys, "pool2"):
return False
else:
npz = np.load(npz_path)
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -254,6 +269,17 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
hidden_state1 = f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state1")).numpy()
hidden_state2 = f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state2")).numpy()
pool2 = f.get_tensor(_find_tensor_by_prefix(keys, "pool2")).numpy()
return [hidden_state1, hidden_state2, pool2]
data = np.load(npz_path)
hidden_state1 = data["hidden_state1"]
hidden_state2 = data["hidden_state2"]
@@ -279,28 +305,68 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokenize_strategy, models, [tokens1, tokens2]
)
if hidden_state1.dtype == torch.bfloat16:
hidden_state1 = hidden_state1.float()
if hidden_state2.dtype == torch.bfloat16:
hidden_state2 = hidden_state2.float()
if pool2.dtype == torch.bfloat16:
pool2 = pool2.float()
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(hidden_state1, hidden_state2, pool2, infos)
else:
if hidden_state1.dtype == torch.bfloat16:
hidden_state1 = hidden_state1.float()
if hidden_state2.dtype == torch.bfloat16:
hidden_state2 = hidden_state2.float()
if pool2.dtype == torch.bfloat16:
pool2 = pool2.float()
hidden_state1 = hidden_state1.cpu().numpy()
hidden_state2 = hidden_state2.cpu().numpy()
pool2 = pool2.cpu().numpy()
hidden_state1 = hidden_state1.cpu().numpy()
hidden_state2 = hidden_state2.cpu().numpy()
pool2 = pool2.cpu().numpy()
for i, info in enumerate(infos):
hidden_state1_i = hidden_state1[i]
hidden_state2_i = hidden_state2[i]
pool2_i = pool2[i]
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
hidden_state1=hidden_state1_i,
hidden_state2=hidden_state2_i,
pool2=pool2_i,
)
else:
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]
def _cache_batch_outputs_safetensors(self, hidden_state1, hidden_state2, pool2, infos):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
hidden_state1 = hidden_state1.cpu()
hidden_state2 = hidden_state2.cpu()
pool2 = pool2.cpu()
for i, info in enumerate(infos):
hidden_state1_i = hidden_state1[i]
hidden_state2_i = hidden_state2[i]
pool2_i = pool2[i]
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
hidden_state1=hidden_state1_i,
hidden_state2=hidden_state2_i,
pool2=pool2_i,
)
tensors = {}
# merge existing file if partial
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
hs1 = hidden_state1[i]
hs2 = hidden_state2[i]
p2 = pool2[i]
tensors[f"hidden_state1_{_dtype_to_str(hs1.dtype)}"] = hs1
tensors[f"hidden_state2_{_dtype_to_str(hs2.dtype)}"] = hs2
tensors[f"pool2_{_dtype_to_str(p2.dtype)}"] = p2
metadata = {
"architecture": "sdxl",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else:
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]
info.text_encoder_outputs = [
hidden_state1[i].numpy(),
hidden_state2[i].numpy(),
pool2[i].numpy(),
]

View File

@@ -1106,7 +1106,8 @@ class BaseDataset(torch.utils.data.Dataset):
return all(
[
not (
subset.caption_dropout_rate > 0 and not cache_supports_dropout
subset.caption_dropout_rate > 0
and not cache_supports_dropout
or subset.shuffle_caption
or subset.token_warmup_step > 0
or subset.caption_tag_dropout_rate > 0
@@ -4471,7 +4472,10 @@ def verify_training_args(args: argparse.Namespace):
Verify training arguments. Also reflect highvram option to global variable
学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する
"""
from library.strategy_base import set_cache_format
enable_high_vram(args)
set_cache_format(args.cache_format)
if args.v2 and args.clip_skip is not None:
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
@@ -4637,6 +4641,14 @@ def add_dataset_arguments(
help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist"
" / cacheの内容の検証をスキップするlatentとテキストエンコーダの出力。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる",
)
parser.add_argument(
"--cache_format",
type=str,
default="npz",
choices=["npz", "safetensors"],
help="format for latent and text encoder output caches (default: npz). safetensors saves in native dtype (e.g. bf16) for smaller files and faster I/O"
" / latentおよびtext encoder出力キャッシュの保存形式デフォルト: npz。safetensorsはネイティブdtype例: bf16で保存し、ファイルサイズ削減と高速化が可能",
)
parser.add_argument(
"--enable_bucket",
action="store_true",

View File

@@ -69,6 +69,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
set_tokenize_strategy(is_sd, is_sdxl, is_flux, args)
strategy_base.set_cache_format(args.cache_format)
if is_sd or is_sdxl:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(is_sd, True, args.vae_batch_size, args.skip_cache_check)
else:

View File

@@ -156,6 +156,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
text_encoder.eval()
# build text encoder outputs caching strategy
strategy_base.set_cache_format(args.cache_format)
if is_sdxl:
text_encoder_outputs_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions