mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Merge a437949d47 into 5f793fb0f4
This commit is contained in:
@@ -155,6 +155,7 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
"""
|
||||
|
||||
ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_anima_te.npz"
|
||||
ANIMA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_anima_te.safetensors"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -166,7 +167,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
    """Return the text-encoder output cache path for an image.

    The image extension is stripped and replaced by the Anima suffix that
    matches ``self.cache_format``: the safetensors suffix when the format is
    ``"safetensors"``, the npz suffix for any other value.

    Args:
        image_abs_path: Path of the source image.

    Returns:
        Path of the cache file next to the image.
    """
    # The earlier unconditional npz return is dropped so that the
    # format-aware selection below is actually reached.
    if self.cache_format == "safetensors":
        suffix = self.ANIMA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX
    else:
        suffix = self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
    return os.path.splitext(image_abs_path)[0] + suffix
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
|
||||
if not self.cache_to_disk:
|
||||
@@ -177,17 +179,34 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
return True
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "prompt_embeds" not in npz:
|
||||
return False
|
||||
if "attn_mask" not in npz:
|
||||
return False
|
||||
if "t5_input_ids" not in npz:
|
||||
return False
|
||||
if "t5_attn_mask" not in npz:
|
||||
return False
|
||||
if "caption_dropout_rate" not in npz:
|
||||
return False
|
||||
if npz_path.endswith(".safetensors"):
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
from library.strategy_base import _find_tensor_by_prefix
|
||||
|
||||
with MemoryEfficientSafeOpen(npz_path) as f:
|
||||
keys = f.keys()
|
||||
if not _find_tensor_by_prefix(keys, "prompt_embeds"):
|
||||
return False
|
||||
if "attn_mask" not in keys:
|
||||
return False
|
||||
if "t5_input_ids" not in keys:
|
||||
return False
|
||||
if "t5_attn_mask" not in keys:
|
||||
return False
|
||||
if "caption_dropout_rate" not in keys:
|
||||
return False
|
||||
else:
|
||||
npz = np.load(npz_path)
|
||||
if "prompt_embeds" not in npz:
|
||||
return False
|
||||
if "attn_mask" not in npz:
|
||||
return False
|
||||
if "t5_input_ids" not in npz:
|
||||
return False
|
||||
if "t5_attn_mask" not in npz:
|
||||
return False
|
||||
if "caption_dropout_rate" not in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
@@ -195,6 +214,19 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
return True
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
if npz_path.endswith(".safetensors"):
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
from library.strategy_base import _find_tensor_by_prefix
|
||||
|
||||
with MemoryEfficientSafeOpen(npz_path) as f:
|
||||
keys = f.keys()
|
||||
prompt_embeds = f.get_tensor(_find_tensor_by_prefix(keys, "prompt_embeds")).numpy()
|
||||
attn_mask = f.get_tensor("attn_mask").numpy()
|
||||
t5_input_ids = f.get_tensor("t5_input_ids").numpy()
|
||||
t5_attn_mask = f.get_tensor("t5_attn_mask").numpy()
|
||||
caption_dropout_rate = f.get_tensor("caption_dropout_rate").numpy()
|
||||
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate]
|
||||
|
||||
data = np.load(npz_path)
|
||||
prompt_embeds = data["prompt_embeds"]
|
||||
attn_mask = data["attn_mask"]
|
||||
@@ -219,32 +251,75 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
tokenize_strategy, models, tokens_and_masks
|
||||
)
|
||||
|
||||
# Convert to numpy for caching
|
||||
if prompt_embeds.dtype == torch.bfloat16:
|
||||
prompt_embeds = prompt_embeds.float()
|
||||
prompt_embeds = prompt_embeds.cpu().numpy()
|
||||
attn_mask = attn_mask.cpu().numpy()
|
||||
t5_input_ids = t5_input_ids.cpu().numpy().astype(np.int32)
|
||||
t5_attn_mask = t5_attn_mask.cpu().numpy().astype(np.int32)
|
||||
if self.cache_format == "safetensors":
|
||||
self._cache_batch_outputs_safetensors(prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, infos)
|
||||
else:
|
||||
# Convert to numpy for caching
|
||||
if prompt_embeds.dtype == torch.bfloat16:
|
||||
prompt_embeds = prompt_embeds.float()
|
||||
prompt_embeds = prompt_embeds.cpu().numpy()
|
||||
attn_mask = attn_mask.cpu().numpy()
|
||||
t5_input_ids = t5_input_ids.cpu().numpy().astype(np.int32)
|
||||
t5_attn_mask = t5_attn_mask.cpu().numpy().astype(np.int32)
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
prompt_embeds_i = prompt_embeds[i]
|
||||
attn_mask_i = attn_mask[i]
|
||||
t5_input_ids_i = t5_input_ids[i]
|
||||
t5_attn_mask_i = t5_attn_mask[i]
|
||||
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
prompt_embeds=prompt_embeds_i,
|
||||
attn_mask=attn_mask_i,
|
||||
t5_input_ids=t5_input_ids_i,
|
||||
t5_attn_mask=t5_attn_mask_i,
|
||||
caption_dropout_rate=caption_dropout_rate,
|
||||
)
|
||||
else:
|
||||
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i, caption_dropout_rate)
|
||||
|
||||
def _cache_batch_outputs_safetensors(self, prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, infos):
    """Cache one batch of Anima text-encoder outputs using the safetensors format.

    All inputs are moved to the CPU; the T5 ids and mask are cast to int32.
    For every entry of ``infos``:

    * when ``self.cache_to_disk`` is set, the sample is written to
      ``info.text_encoder_outputs_npz`` with ``mem_eff_save_file``. If
      ``self.is_partial`` is set and the file already exists, its tensors are
      loaded first and kept unless overwritten by the keys written here.
    * otherwise numpy copies are stored on ``info.text_encoder_outputs``
      together with the caption dropout rate (kept as a torch tensor).

    Args:
        prompt_embeds: Batched prompt embeddings (indexed by sample).
        attn_mask: Batched attention mask for the prompt embeddings.
        t5_input_ids: Batched T5 token ids.
        t5_attn_mask: Batched T5 attention mask.
        infos: Per-sample info objects; ``caption``, ``caption_dropout_rate``
            and ``text_encoder_outputs_npz`` are read from them.
    """
    from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
    from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION

    prompt_embeds = prompt_embeds.cpu()
    attn_mask = attn_mask.cpu()
    t5_input_ids = t5_input_ids.cpu().to(torch.int32)
    t5_attn_mask = t5_attn_mask.cpu().to(torch.int32)

    def as_numpy(tensor):
        # numpy has no bfloat16 dtype, so Tensor.numpy() raises for it;
        # widen to float32 first, as the npz code path does.
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        return tensor.numpy()

    for i, info in enumerate(infos):
        caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)

        if not self.cache_to_disk:
            info.text_encoder_outputs = (
                as_numpy(prompt_embeds[i]),
                as_numpy(attn_mask[i]),
                t5_input_ids[i].numpy(),
                t5_attn_mask[i].numpy(),
                caption_dropout_rate,
            )
            continue

        tensors = {}
        if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
            # Partial caching: start from what is already stored in the file.
            with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
                for key in f.keys():
                    tensors[key] = f.get_tensor(key)

        pe = prompt_embeds[i]
        # The embedding key carries its dtype; readers look it up by prefix.
        tensors[f"prompt_embeds_{_dtype_to_str(pe.dtype)}"] = pe
        tensors["attn_mask"] = attn_mask[i]
        tensors["t5_input_ids"] = t5_input_ids[i]
        tensors["t5_attn_mask"] = t5_attn_mask[i]
        tensors["caption_dropout_rate"] = caption_dropout_rate

        metadata = {
            "architecture": "anima",
            "caption1": info.caption,  # NOTE(review): assumes caption is a str - confirm against caller
            "format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
        }
        mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
|
||||
|
||||
|
||||
class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
@@ -255,16 +330,20 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
"""
|
||||
|
||||
ANIMA_LATENTS_NPZ_SUFFIX = "_anima.npz"
|
||||
ANIMA_LATENTS_ST_SUFFIX = "_anima.safetensors"
|
||||
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
|
||||
@property
def cache_suffix(self) -> str:
    """Latents cache file suffix for the active cache format.

    Returns the safetensors suffix when ``self.cache_format`` is
    ``"safetensors"`` and the npz suffix for any other value.
    """
    # The earlier unconditional npz return is dropped; it made the
    # format-aware selection unreachable.
    if self.cache_format == "safetensors":
        return self.ANIMA_LATENTS_ST_SUFFIX
    return self.ANIMA_LATENTS_NPZ_SUFFIX
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
    """Return the latents cache path for an image at a given size.

    The path is the image path without its extension, followed by
    ``_{image_size[0]:04d}x{image_size[1]:04d}`` and ``self.cache_suffix``.

    Args:
        absolute_path: Path of the source image.
        image_size: Two integers formatted into the file name in the
            order given.

    Returns:
        Path of the cache file next to the image.
    """
    # Use cache_suffix rather than the hard-coded npz suffix so the path
    # follows the configured cache format.
    stem = os.path.splitext(absolute_path)[0]
    size_tag = f"_{image_size[0]:04d}x{image_size[1]:04d}"
    return stem + size_tag + self.cache_suffix
|
||||
|
||||
def _get_architecture_name(self) -> str:
|
||||
return "anima"
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Any, List, Optional, Tuple, Union, Callable
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -19,6 +19,48 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LATENTS_CACHE_FORMAT_VERSION = "1.0.1"
|
||||
TE_OUTPUTS_CACHE_FORMAT_VERSION = "1.0.1"
|
||||
|
||||
# global cache format setting: "npz" or "safetensors"
|
||||
_cache_format: str = "npz"
|
||||
|
||||
|
||||
def set_cache_format(cache_format: str) -> None:
    """Store the cache format in the module-level ``_cache_format`` setting.

    The value is stored as given, without validation; the setting's own
    comment lists ``"npz"`` and ``"safetensors"`` as the formats.
    """
    global _cache_format
    _cache_format = cache_format
|
||||
|
||||
|
||||
def get_cache_format() -> str:
    """Return the current value of the module-level ``_cache_format`` setting."""
    return _cache_format
|
||||
|
||||
_TORCH_DTYPE_TO_STR = {
|
||||
torch.float64: "float64",
|
||||
torch.float32: "float32",
|
||||
torch.float16: "float16",
|
||||
torch.bfloat16: "bfloat16",
|
||||
torch.int64: "int64",
|
||||
torch.int32: "int32",
|
||||
torch.int16: "int16",
|
||||
torch.int8: "int8",
|
||||
torch.uint8: "uint8",
|
||||
torch.bool: "bool",
|
||||
}
|
||||
|
||||
_FLOAT_DTYPES = {torch.float64, torch.float32, torch.float16, torch.bfloat16}
|
||||
|
||||
|
||||
def _dtype_to_str(dtype: torch.dtype) -> str:
    """Return the short name used for a torch dtype in cache tensor keys.

    Dtypes listed in ``_TORCH_DTYPE_TO_STR`` use the table entry; any other
    dtype falls back to ``str(dtype)`` with the ``"torch."`` text removed.
    """
    try:
        return _TORCH_DTYPE_TO_STR[dtype]
    except KeyError:
        return str(dtype).replace("torch.", "")
|
||||
|
||||
|
||||
def _find_tensor_by_prefix(tensors_keys: List[str], prefix: str) -> Optional[str]:
|
||||
"""Find a tensor key that starts with the given prefix. Returns the first match or None."""
|
||||
for key in tensors_keys:
|
||||
if key.startswith(prefix) or key == prefix:
|
||||
return key
|
||||
return None
|
||||
|
||||
|
||||
class TokenizeStrategy:
|
||||
_strategy = None # strategy instance: actual strategy class
|
||||
@@ -362,6 +404,10 @@ class TextEncoderOutputsCachingStrategy:
|
||||
def is_weighted(self):
|
||||
return self._is_weighted
|
||||
|
||||
@property
|
||||
def cache_format(self) -> str:
|
||||
return get_cache_format()
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -407,6 +453,10 @@ class LatentsCachingStrategy:
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
@property
|
||||
def cache_format(self) -> str:
|
||||
return get_cache_format()
|
||||
|
||||
@property
|
||||
def cache_suffix(self):
|
||||
raise NotImplementedError
|
||||
@@ -439,7 +489,7 @@ class LatentsCachingStrategy:
|
||||
Args:
|
||||
latents_stride: stride of latents
|
||||
bucket_reso: resolution of the bucket
|
||||
npz_path: path to the npz file
|
||||
npz_path: path to the npz/safetensors file
|
||||
flip_aug: whether to flip images
|
||||
apply_alpha_mask: whether to apply alpha mask
|
||||
multi_resolution: whether to use multi-resolution latents
|
||||
@@ -454,6 +504,11 @@ class LatentsCachingStrategy:
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
if npz_path.endswith(".safetensors"):
|
||||
return self._is_disk_cached_latents_expected_safetensors(
|
||||
latents_stride, bucket_reso, npz_path, flip_aug, apply_alpha_mask, multi_resolution
|
||||
)
|
||||
|
||||
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
|
||||
|
||||
# e.g. "_32x64", HxW
|
||||
@@ -476,6 +531,40 @@ class LatentsCachingStrategy:
|
||||
|
||||
return True
|
||||
|
||||
def _is_disk_cached_latents_expected_safetensors(
    self,
    latents_stride: int,
    bucket_reso: Tuple[int, int],
    st_path: str,
    flip_aug: bool,
    apply_alpha_mask: bool,
    multi_resolution: bool = False,
) -> bool:
    """Check that a safetensors latents cache holds the required tensors.

    Args:
        latents_stride: Divisor applied to the bucket resolution to get the
            expected latent size.
        bucket_reso: Bucket resolution as (W, H).
        st_path: Path to the safetensors cache file.
        flip_aug: Also require a ``latents_flipped_*`` entry.
        apply_alpha_mask: Also require an ``alpha_mask_*`` entry.
        multi_resolution: When True, entries must be tagged with the expected
            ``1x{H}x{W}`` size; when False any entry tagged ``1x...`` counts.

    Returns:
        True when every required entry is present, False otherwise.

    Raises:
        Exception: Whatever opening or reading the file raises, after an
            error is logged.
    """
    from library.safetensors_utils import MemoryEfficientSafeOpen

    expected_h = bucket_reso[1] // latents_stride  # bucket_reso is (W, H)
    expected_w = bucket_reso[0] // latents_stride
    reso_tag = f"1x{expected_h}x{expected_w}" if multi_resolution else "1x"

    def has_entry(keys, name: str) -> bool:
        stem = f"{name}_{reso_tag}"
        if not multi_resolution:
            # "1x" is an open-ended tag: any stored size satisfies it.
            return any(k.startswith(stem) for k in keys)
        # With a full size tag the match must end at a "_" (dtype suffix) or
        # at the end of the key, so "1x32x64" does not accept "1x32x640".
        bounded = stem + "_"
        return any(k == stem or k.startswith(bounded) for k in keys)

    required = ["latents"]
    if flip_aug:
        required.append("latents_flipped")
    if apply_alpha_mask:
        required.append("alpha_mask")

    try:
        with MemoryEfficientSafeOpen(st_path) as f:
            keys = list(f.keys())
            return all(has_entry(keys, name) for name in required)
    except Exception:
        logger.error(f"Error loading file: {st_path}")
        raise
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def _default_cache_batch_latents(
|
||||
self,
|
||||
@@ -571,7 +660,7 @@ class LatentsCachingStrategy:
|
||||
"""
|
||||
Args:
|
||||
latents_stride (Optional[int]): Stride for latents. If None, load all latents.
|
||||
npz_path (str): Path to the npz file.
|
||||
npz_path (str): Path to the npz/safetensors file.
|
||||
bucket_reso (Tuple[int, int]): The resolution of the bucket.
|
||||
|
||||
Returns:
|
||||
@@ -583,6 +672,9 @@ class LatentsCachingStrategy:
|
||||
Optional[np.ndarray]
|
||||
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
|
||||
"""
|
||||
if npz_path.endswith(".safetensors"):
|
||||
return self._load_latents_from_disk_safetensors(latents_stride, npz_path, bucket_reso)
|
||||
|
||||
if latents_stride is None:
|
||||
key_reso_suffix = ""
|
||||
else:
|
||||
@@ -609,6 +701,39 @@ class LatentsCachingStrategy:
|
||||
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
|
||||
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
|
||||
|
||||
def _load_latents_from_disk_safetensors(
    self, latents_stride: Optional[int], st_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
    """Load cached latents and related entries from a safetensors file.

    Args:
        latents_stride: Divisor applied to ``bucket_reso`` to build the
            ``1x{H}x{W}`` key tag. If None, the open-ended tag ``1x`` is used.
        st_path: Path to the safetensors cache file.
        bucket_reso: Bucket resolution as (W, H).

    Returns:
        ``(latents, original_size, crop_ltrb, flipped_latents, alpha_mask)``.
        ``original_size`` and ``crop_ltrb`` default to ``[0, 0]`` and
        ``[0, 0, 0, 0]`` when absent; ``flipped_latents`` and ``alpha_mask``
        are None when absent.

    Raises:
        ValueError: If no latents entry matches the tag.
    """
    from library.safetensors_utils import MemoryEfficientSafeOpen

    if latents_stride is None:
        reso_tag = "1x"
    else:
        latents_h = bucket_reso[1] // latents_stride
        latents_w = bucket_reso[0] // latents_stride
        reso_tag = f"1x{latents_h}x{latents_w}"

    def as_numpy(tensor):
        # The safetensors writer keeps the tensor's own dtype, and numpy has
        # no bfloat16: Tensor.numpy() would raise, so widen to float32 first.
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        return tensor.numpy()

    with MemoryEfficientSafeOpen(st_path) as f:
        keys = f.keys()

        def fetch(name: str):
            key = _find_tensor_by_prefix(keys, f"{name}_{reso_tag}")
            return f.get_tensor(key) if key else None

        latents_t = fetch("latents")
        if latents_t is None:
            raise ValueError(f"latents with prefix 'latents_{reso_tag}' not found in {st_path}")
        latents = as_numpy(latents_t)

        original_size_t = fetch("original_size")
        original_size = as_numpy(original_size_t).tolist() if original_size_t is not None else [0, 0]

        crop_ltrb_t = fetch("crop_ltrb")
        crop_ltrb = as_numpy(crop_ltrb_t).tolist() if crop_ltrb_t is not None else [0, 0, 0, 0]

        flipped_t = fetch("latents_flipped")
        flipped_latents = as_numpy(flipped_t) if flipped_t is not None else None

        alpha_mask_t = fetch("alpha_mask")
        alpha_mask = as_numpy(alpha_mask_t) if alpha_mask_t is not None else None

    return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
|
||||
|
||||
def save_latents_to_disk(
|
||||
self,
|
||||
npz_path,
|
||||
@@ -621,17 +746,23 @@ class LatentsCachingStrategy:
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
npz_path (str): Path to the npz file.
|
||||
npz_path (str): Path to the npz/safetensors file.
|
||||
latents_tensor (torch.Tensor): Latent tensor
|
||||
original_size (List[int]): Original size of the image
|
||||
crop_ltrb (List[int]): Crop left top right bottom
|
||||
flipped_latents_tensor (Optional[torch.Tensor]): Flipped latent tensor
|
||||
alpha_mask (Optional[torch.Tensor]): Alpha mask
|
||||
key_reso_suffix (str): Key resolution suffix
|
||||
key_reso_suffix (str): Key resolution suffix (e.g. "_32x64" for multi-resolution npz)
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if npz_path.endswith(".safetensors"):
|
||||
self._save_latents_to_disk_safetensors(
|
||||
npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor, alpha_mask, key_reso_suffix
|
||||
)
|
||||
return
|
||||
|
||||
kwargs = {}
|
||||
|
||||
if os.path.exists(npz_path):
|
||||
@@ -640,7 +771,7 @@ class LatentsCachingStrategy:
|
||||
for key in npz.files:
|
||||
kwargs[key] = npz[key]
|
||||
|
||||
# TODO float() is needed if vae is in bfloat16. Remove it if vae is float16.
|
||||
# float() is needed because npz doesn't support bfloat16
|
||||
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
|
||||
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
|
||||
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
|
||||
@@ -649,3 +780,59 @@ class LatentsCachingStrategy:
|
||||
if alpha_mask is not None:
|
||||
kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
|
||||
np.savez(npz_path, **kwargs)
|
||||
|
||||
def _save_latents_to_disk_safetensors(
    self,
    st_path,
    latents_tensor,
    original_size,
    crop_ltrb,
    flipped_latents_tensor=None,
    alpha_mask=None,
    key_reso_suffix="",
):
    """Write latents and related entries to a safetensors cache file.

    Keys are tagged ``1x{H}x{W}``, taken from the last two dimensions of
    ``latents_tensor``. If ``st_path`` already exists, its tensors are loaded
    first and kept unless overwritten by the keys written here, so one file
    can hold several resolutions.

    Args:
        st_path: Destination safetensors path.
        latents_tensor: Latents to store; moved to CPU.
        original_size: Stored as an int32 tensor.
        crop_ltrb: Stored as an int32 tensor.
        flipped_latents_tensor: Optional flipped latents; moved to CPU.
        alpha_mask: Optional alpha mask, a tensor or anything accepted by
            ``torch.tensor``.
        key_reso_suffix: Not used by this method.
    """
    from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen

    def scrub_nans(tensor):
        # Only rewrite the tensor when it actually contains NaN.
        if torch.isnan(tensor).any():
            return torch.nan_to_num(tensor, nan=0.0)
        return tensor

    latents_tensor = latents_tensor.cpu()
    shape = latents_tensor.shape
    height, width = shape[-2], shape[-1]
    reso_tag = f"1x{height}x{width}"
    dtype_str = _dtype_to_str(latents_tensor.dtype)
    latents_tensor = scrub_nans(latents_tensor)

    tensors: Dict[str, torch.Tensor] = {}

    # Merge with whatever the file already holds (other resolutions).
    if os.path.exists(st_path):
        with MemoryEfficientSafeOpen(st_path) as f:
            for existing_key in f.keys():
                tensors[existing_key] = f.get_tensor(existing_key)

    tensors.update(
        {
            f"latents_{reso_tag}_{dtype_str}": latents_tensor,
            f"original_size_{reso_tag}_int32": torch.tensor(original_size, dtype=torch.int32),
            f"crop_ltrb_{reso_tag}_int32": torch.tensor(crop_ltrb, dtype=torch.int32),
        }
    )

    if flipped_latents_tensor is not None:
        # The flipped entry is labelled with the main latents' dtype string.
        tensors[f"latents_flipped_{reso_tag}_{dtype_str}"] = scrub_nans(flipped_latents_tensor.cpu())

    if alpha_mask is not None:
        if isinstance(alpha_mask, torch.Tensor):
            mask_tensor = alpha_mask.cpu()
        else:
            mask_tensor = torch.tensor(alpha_mask)
        tensors[f"alpha_mask_{reso_tag}"] = mask_tensor

    mem_eff_save_file(
        tensors,
        st_path,
        metadata={
            "architecture": self._get_architecture_name(),
            "width": str(width),
            "height": str(height),
            "format_version": LATENTS_CACHE_FORMAT_VERSION,
        },
    )
|
||||
|
||||
def _get_architecture_name(self) -> str:
|
||||
"""Override in subclasses to return the architecture name for safetensors metadata."""
|
||||
return "unknown"
|
||||
|
||||
@@ -87,6 +87,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
|
||||
|
||||
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"
|
||||
FLUX_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_flux_te.safetensors"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -102,7 +103,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
self.warn_fp8_weights = False
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
    """Return the text-encoder output cache path for an image.

    The image extension is stripped and replaced by the Flux suffix that
    matches ``self.cache_format``: the safetensors suffix when the format is
    ``"safetensors"``, the npz suffix for any other value.

    Args:
        image_abs_path: Path of the source image.

    Returns:
        Path of the cache file next to the image.
    """
    # The earlier unconditional npz return is dropped so that the
    # format-aware selection below is actually reached.
    if self.cache_format == "safetensors":
        suffix = self.FLUX_TEXT_ENCODER_OUTPUTS_ST_SUFFIX
    else:
        suffix = self.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
    return os.path.splitext(image_abs_path)[0] + suffix
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str):
|
||||
if not self.cache_to_disk:
|
||||
@@ -113,20 +115,40 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
return True
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "l_pooled" not in npz:
|
||||
return False
|
||||
if "t5_out" not in npz:
|
||||
return False
|
||||
if "txt_ids" not in npz:
|
||||
return False
|
||||
if "t5_attn_mask" not in npz:
|
||||
return False
|
||||
if "apply_t5_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
|
||||
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
|
||||
return False
|
||||
if npz_path.endswith(".safetensors"):
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
from library.strategy_base import _find_tensor_by_prefix
|
||||
|
||||
with MemoryEfficientSafeOpen(npz_path) as f:
|
||||
keys = f.keys()
|
||||
if not _find_tensor_by_prefix(keys, "l_pooled"):
|
||||
return False
|
||||
if not _find_tensor_by_prefix(keys, "t5_out"):
|
||||
return False
|
||||
if not _find_tensor_by_prefix(keys, "txt_ids"):
|
||||
return False
|
||||
if "t5_attn_mask" not in keys:
|
||||
return False
|
||||
if "apply_t5_attn_mask" not in keys:
|
||||
return False
|
||||
apply_t5 = f.get_tensor("apply_t5_attn_mask").item()
|
||||
if bool(apply_t5) != self.apply_t5_attn_mask:
|
||||
return False
|
||||
else:
|
||||
npz = np.load(npz_path)
|
||||
if "l_pooled" not in npz:
|
||||
return False
|
||||
if "t5_out" not in npz:
|
||||
return False
|
||||
if "txt_ids" not in npz:
|
||||
return False
|
||||
if "t5_attn_mask" not in npz:
|
||||
return False
|
||||
if "apply_t5_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
|
||||
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
@@ -134,6 +156,18 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
return True
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
if npz_path.endswith(".safetensors"):
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
from library.strategy_base import _find_tensor_by_prefix
|
||||
|
||||
with MemoryEfficientSafeOpen(npz_path) as f:
|
||||
keys = f.keys()
|
||||
l_pooled = f.get_tensor(_find_tensor_by_prefix(keys, "l_pooled")).numpy()
|
||||
t5_out = f.get_tensor(_find_tensor_by_prefix(keys, "t5_out")).numpy()
|
||||
txt_ids = f.get_tensor(_find_tensor_by_prefix(keys, "txt_ids")).numpy()
|
||||
t5_attn_mask = f.get_tensor("t5_attn_mask").numpy()
|
||||
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
|
||||
|
||||
data = np.load(npz_path)
|
||||
l_pooled = data["l_pooled"]
|
||||
t5_out = data["t5_out"]
|
||||
@@ -161,56 +195,100 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
# attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
|
||||
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
|
||||
|
||||
if l_pooled.dtype == torch.bfloat16:
|
||||
l_pooled = l_pooled.float()
|
||||
if t5_out.dtype == torch.bfloat16:
|
||||
t5_out = t5_out.float()
|
||||
if txt_ids.dtype == torch.bfloat16:
|
||||
txt_ids = txt_ids.float()
|
||||
t5_attn_mask_tokens = tokens_and_masks[2]
|
||||
|
||||
l_pooled = l_pooled.cpu().numpy()
|
||||
t5_out = t5_out.cpu().numpy()
|
||||
txt_ids = txt_ids.cpu().numpy()
|
||||
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
|
||||
if self.cache_format == "safetensors":
|
||||
self._cache_batch_outputs_safetensors(l_pooled, t5_out, txt_ids, t5_attn_mask_tokens, infos)
|
||||
else:
|
||||
if l_pooled.dtype == torch.bfloat16:
|
||||
l_pooled = l_pooled.float()
|
||||
if t5_out.dtype == torch.bfloat16:
|
||||
t5_out = t5_out.float()
|
||||
if txt_ids.dtype == torch.bfloat16:
|
||||
txt_ids = txt_ids.float()
|
||||
|
||||
l_pooled = l_pooled.cpu().numpy()
|
||||
t5_out = t5_out.cpu().numpy()
|
||||
txt_ids = txt_ids.cpu().numpy()
|
||||
t5_attn_mask = t5_attn_mask_tokens.cpu().numpy()
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
l_pooled_i = l_pooled[i]
|
||||
t5_out_i = t5_out[i]
|
||||
txt_ids_i = txt_ids[i]
|
||||
t5_attn_mask_i = t5_attn_mask[i]
|
||||
apply_t5_attn_mask_i = self.apply_t5_attn_mask
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
l_pooled=l_pooled_i,
|
||||
t5_out=t5_out_i,
|
||||
txt_ids=txt_ids_i,
|
||||
t5_attn_mask=t5_attn_mask_i,
|
||||
apply_t5_attn_mask=apply_t5_attn_mask_i,
|
||||
)
|
||||
else:
|
||||
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
|
||||
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
|
||||
|
||||
def _cache_batch_outputs_safetensors(self, l_pooled, t5_out, txt_ids, t5_attn_mask_tokens, infos):
    """Cache one batch of Flux text-encoder outputs using the safetensors format.

    All inputs are moved to the CPU. For every entry of ``infos``:

    * when ``self.cache_to_disk`` is set, the sample is written to
      ``info.text_encoder_outputs_npz`` with ``mem_eff_save_file``. If
      ``self.is_partial`` is set and the file already exists, its tensors are
      loaded first and kept unless overwritten by the keys written here.
    * otherwise numpy copies are stored on ``info.text_encoder_outputs``.

    Args:
        l_pooled: Batched pooled output (indexed by sample).
        t5_out: Batched T5 output.
        txt_ids: Batched text ids.
        t5_attn_mask_tokens: Batched T5 attention mask.
        infos: Per-sample info objects; ``caption`` and
            ``text_encoder_outputs_npz`` are read from them.
    """
    from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
    from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION

    l_pooled = l_pooled.cpu()
    t5_out = t5_out.cpu()
    txt_ids = txt_ids.cpu()
    t5_attn_mask = t5_attn_mask_tokens.cpu()

    def as_numpy(tensor):
        # numpy has no bfloat16 dtype, so Tensor.numpy() raises for it;
        # widen to float32 first, as the npz code path does.
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        return tensor.numpy()

    for i, info in enumerate(infos):
        if not self.cache_to_disk:
            # it's fine that attn mask is not None. it's overwritten before calling the model if necessary
            info.text_encoder_outputs = (
                as_numpy(l_pooled[i]),
                as_numpy(t5_out[i]),
                as_numpy(txt_ids[i]),
                as_numpy(t5_attn_mask[i]),
            )
            continue

        tensors = {}
        if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
            # Partial caching: start from what is already stored in the file.
            with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
                for key in f.keys():
                    tensors[key] = f.get_tensor(key)

        lp = l_pooled[i]
        to = t5_out[i]
        ti = txt_ids[i]
        # These keys carry their dtype; readers look them up by prefix.
        tensors[f"l_pooled_{_dtype_to_str(lp.dtype)}"] = lp
        tensors[f"t5_out_{_dtype_to_str(to.dtype)}"] = to
        tensors[f"txt_ids_{_dtype_to_str(ti.dtype)}"] = ti
        tensors["t5_attn_mask"] = t5_attn_mask[i]
        tensors["apply_t5_attn_mask"] = torch.tensor(self.apply_t5_attn_mask, dtype=torch.bool)

        metadata = {
            "architecture": "flux",
            "caption1": info.caption,  # NOTE(review): assumes caption is a str - confirm against caller
            "format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
        }
        mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
|
||||
|
||||
|
||||
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"
|
||||
FLUX_LATENTS_ST_SUFFIX = "_flux.safetensors"
|
||||
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
|
||||
@property
def cache_suffix(self) -> str:
    """Latents cache file suffix for the active cache format.

    Returns the safetensors suffix when ``self.cache_format`` is
    ``"safetensors"`` and the npz suffix for any other value.
    """
    # The earlier unconditional npz return is dropped; it made the
    # format-aware selection unreachable.
    if self.cache_format == "safetensors":
        return self.FLUX_LATENTS_ST_SUFFIX
    return self.FLUX_LATENTS_NPZ_SUFFIX
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
    """Return the latents cache path for an image at a given size.

    The path is the image path without its extension, followed by
    ``_{image_size[0]:04d}x{image_size[1]:04d}`` and ``self.cache_suffix``.

    Args:
        absolute_path: Path of the source image.
        image_size: Two integers formatted into the file name in the
            order given.

    Returns:
        Path of the cache file next to the image.
    """
    # Exactly one suffix is appended: cache_suffix already selects between
    # the npz and safetensors variants, so the fixed npz suffix is not added.
    stem = os.path.splitext(absolute_path)[0]
    size_tag = f"_{image_size[0]:04d}x{image_size[1]:04d}"
    return stem + size_tag + self.cache_suffix
|
||||
|
||||
def _get_architecture_name(self) -> str:
|
||||
return "flux"
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
|
||||
|
||||
|
||||
@@ -81,16 +81,17 @@ class HunyuanImageTextEncodingStrategy(TextEncodingStrategy):
|
||||
|
||||
class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_hi_te.npz"
|
||||
HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_hi_te.safetensors"
|
||||
|
||||
def __init__(
|
||||
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
|
||||
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False,
|
||||
) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
    """Return the text-encoder output cache path for an image.

    The image extension is stripped and replaced by the HunyuanImage suffix
    that matches ``self.cache_format``: the safetensors suffix when the
    format is ``"safetensors"``, the npz suffix for any other value.

    Args:
        image_abs_path: Path of the source image.

    Returns:
        Path of the cache file next to the image.
    """
    # A single return expression replaces the two fused operands, which did
    # not form a valid expression.
    if self.cache_format == "safetensors":
        suffix = self.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_ST_SUFFIX
    else:
        suffix = self.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
    return os.path.splitext(image_abs_path)[0] + suffix
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str):
|
||||
@@ -102,17 +103,34 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr
|
||||
return True
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "vlm_embed" not in npz:
|
||||
return False
|
||||
if "vlm_mask" not in npz:
|
||||
return False
|
||||
if "byt5_embed" not in npz:
|
||||
return False
|
||||
if "byt5_mask" not in npz:
|
||||
return False
|
||||
if "ocr_mask" not in npz:
|
||||
return False
|
||||
if npz_path.endswith(".safetensors"):
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
from library.strategy_base import _find_tensor_by_prefix
|
||||
|
||||
with MemoryEfficientSafeOpen(npz_path) as f:
|
||||
keys = f.keys()
|
||||
if not _find_tensor_by_prefix(keys, "vlm_embed"):
|
||||
return False
|
||||
if "vlm_mask" not in keys:
|
||||
return False
|
||||
if not _find_tensor_by_prefix(keys, "byt5_embed"):
|
||||
return False
|
||||
if "byt5_mask" not in keys:
|
||||
return False
|
||||
if "ocr_mask" not in keys:
|
||||
return False
|
||||
else:
|
||||
npz = np.load(npz_path)
|
||||
if "vlm_embed" not in npz:
|
||||
return False
|
||||
if "vlm_mask" not in npz:
|
||||
return False
|
||||
if "byt5_embed" not in npz:
|
||||
return False
|
||||
if "byt5_mask" not in npz:
|
||||
return False
|
||||
if "ocr_mask" not in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
@@ -120,6 +138,19 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr
|
||||
return True
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
if npz_path.endswith(".safetensors"):
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
from library.strategy_base import _find_tensor_by_prefix
|
||||
|
||||
with MemoryEfficientSafeOpen(npz_path) as f:
|
||||
keys = f.keys()
|
||||
vlm_embed = f.get_tensor(_find_tensor_by_prefix(keys, "vlm_embed")).numpy()
|
||||
vlm_mask = f.get_tensor("vlm_mask").numpy()
|
||||
byt5_embed = f.get_tensor(_find_tensor_by_prefix(keys, "byt5_embed")).numpy()
|
||||
byt5_mask = f.get_tensor("byt5_mask").numpy()
|
||||
ocr_mask = f.get_tensor("ocr_mask").numpy()
|
||||
return [vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask]
|
||||
|
||||
data = np.load(npz_path)
|
||||
vln_embed = data["vlm_embed"]
|
||||
vlm_mask = data["vlm_mask"]
|
||||
@@ -140,54 +171,102 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr
|
||||
tokenize_strategy, models, tokens_and_masks
|
||||
)
|
||||
|
||||
if vlm_embed.dtype == torch.bfloat16:
|
||||
vlm_embed = vlm_embed.float()
|
||||
if byt5_embed.dtype == torch.bfloat16:
|
||||
byt5_embed = byt5_embed.float()
|
||||
if self.cache_format == "safetensors":
|
||||
self._cache_batch_outputs_safetensors(vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask, infos)
|
||||
else:
|
||||
if vlm_embed.dtype == torch.bfloat16:
|
||||
vlm_embed = vlm_embed.float()
|
||||
if byt5_embed.dtype == torch.bfloat16:
|
||||
byt5_embed = byt5_embed.float()
|
||||
|
||||
vlm_embed = vlm_embed.cpu().numpy()
|
||||
vlm_mask = vlm_mask.cpu().numpy()
|
||||
byt5_embed = byt5_embed.cpu().numpy()
|
||||
byt5_mask = byt5_mask.cpu().numpy()
|
||||
ocr_mask = ocr_mask.cpu().numpy()
|
||||
vlm_embed = vlm_embed.cpu().numpy()
|
||||
vlm_mask = vlm_mask.cpu().numpy()
|
||||
byt5_embed = byt5_embed.cpu().numpy()
|
||||
byt5_mask = byt5_mask.cpu().numpy()
|
||||
ocr_mask = ocr_mask.cpu().numpy()
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
vlm_embed_i = vlm_embed[i]
|
||||
vlm_mask_i = vlm_mask[i]
|
||||
byt5_embed_i = byt5_embed[i]
|
||||
byt5_mask_i = byt5_mask[i]
|
||||
ocr_mask_i = ocr_mask[i]
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
vlm_embed=vlm_embed_i,
|
||||
vlm_mask=vlm_mask_i,
|
||||
byt5_embed=byt5_embed_i,
|
||||
byt5_mask=byt5_mask_i,
|
||||
ocr_mask=ocr_mask_i,
|
||||
)
|
||||
else:
|
||||
info.text_encoder_outputs = (vlm_embed_i, vlm_mask_i, byt5_embed_i, byt5_mask_i, ocr_mask_i)
|
||||
|
||||
def _cache_batch_outputs_safetensors(self, vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask, infos):
|
||||
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
|
||||
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
|
||||
|
||||
vlm_embed = vlm_embed.cpu()
|
||||
vlm_mask = vlm_mask.cpu()
|
||||
byt5_embed = byt5_embed.cpu()
|
||||
byt5_mask = byt5_mask.cpu()
|
||||
ocr_mask = ocr_mask.cpu()
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
vlm_embed_i = vlm_embed[i]
|
||||
vlm_mask_i = vlm_mask[i]
|
||||
byt5_embed_i = byt5_embed[i]
|
||||
byt5_mask_i = byt5_mask[i]
|
||||
ocr_mask_i = ocr_mask[i]
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
vlm_embed=vlm_embed_i,
|
||||
vlm_mask=vlm_mask_i,
|
||||
byt5_embed=byt5_embed_i,
|
||||
byt5_mask=byt5_mask_i,
|
||||
ocr_mask=ocr_mask_i,
|
||||
)
|
||||
tensors = {}
|
||||
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
|
||||
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
|
||||
for key in f.keys():
|
||||
tensors[key] = f.get_tensor(key)
|
||||
|
||||
ve = vlm_embed[i]
|
||||
be = byt5_embed[i]
|
||||
tensors[f"vlm_embed_{_dtype_to_str(ve.dtype)}"] = ve
|
||||
tensors["vlm_mask"] = vlm_mask[i]
|
||||
tensors[f"byt5_embed_{_dtype_to_str(be.dtype)}"] = be
|
||||
tensors["byt5_mask"] = byt5_mask[i]
|
||||
tensors["ocr_mask"] = ocr_mask[i]
|
||||
|
||||
metadata = {
|
||||
"architecture": "hunyuan_image",
|
||||
"caption1": info.caption,
|
||||
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
|
||||
}
|
||||
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
|
||||
else:
|
||||
info.text_encoder_outputs = (vlm_embed_i, vlm_mask_i, byt5_embed_i, byt5_mask_i, ocr_mask_i)
|
||||
info.text_encoder_outputs = (
|
||||
vlm_embed[i].numpy(),
|
||||
vlm_mask[i].numpy(),
|
||||
byt5_embed[i].numpy(),
|
||||
byt5_mask[i].numpy(),
|
||||
ocr_mask[i].numpy(),
|
||||
)
|
||||
|
||||
|
||||
class HunyuanImageLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX = "_hi.npz"
|
||||
HUNYUAN_IMAGE_LATENTS_ST_SUFFIX = "_hi.safetensors"
|
||||
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
|
||||
@property
|
||||
def cache_suffix(self) -> str:
|
||||
return HunyuanImageLatentsCachingStrategy.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX
|
||||
return self.HUNYUAN_IMAGE_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
return (
|
||||
os.path.splitext(absolute_path)[0]
|
||||
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
|
||||
+ HunyuanImageLatentsCachingStrategy.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX
|
||||
+ self.cache_suffix
|
||||
)
|
||||
|
||||
def _get_architecture_name(self) -> str:
|
||||
return "hunyuan_image"
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._default_is_disk_cached_latents_expected(32, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
|
||||
|
||||
|
||||
@@ -146,6 +146,7 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy):
|
||||
|
||||
class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_lumina_te.npz"
|
||||
LUMINA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_lumina_te.safetensors"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -162,19 +163,10 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
)
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
return (
|
||||
os.path.splitext(image_abs_path)[0]
|
||||
+ LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
)
|
||||
suffix = self.LUMINA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
return os.path.splitext(image_abs_path)[0] + suffix
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
|
||||
"""
|
||||
Args:
|
||||
npz_path (str): Path to the npz file.
|
||||
|
||||
Returns:
|
||||
bool: True if the npz file is expected to be cached.
|
||||
"""
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(npz_path):
|
||||
@@ -183,13 +175,26 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
return True
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "hidden_state" not in npz:
|
||||
return False
|
||||
if "attention_mask" not in npz:
|
||||
return False
|
||||
if "input_ids" not in npz:
|
||||
return False
|
||||
if npz_path.endswith(".safetensors"):
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
from library.strategy_base import _find_tensor_by_prefix
|
||||
|
||||
with MemoryEfficientSafeOpen(npz_path) as f:
|
||||
keys = f.keys()
|
||||
if not _find_tensor_by_prefix(keys, "hidden_state"):
|
||||
return False
|
||||
if "attention_mask" not in keys:
|
||||
return False
|
||||
if "input_ids" not in keys:
|
||||
return False
|
||||
else:
|
||||
npz = np.load(npz_path)
|
||||
if "hidden_state" not in npz:
|
||||
return False
|
||||
if "attention_mask" not in npz:
|
||||
return False
|
||||
if "input_ids" not in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
@@ -198,11 +203,22 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
"""
|
||||
Load outputs from a npz file
|
||||
Load outputs from a npz/safetensors file
|
||||
|
||||
Returns:
|
||||
List[np.ndarray]: hidden_state, input_ids, attention_mask
|
||||
"""
|
||||
if npz_path.endswith(".safetensors"):
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
from library.strategy_base import _find_tensor_by_prefix
|
||||
|
||||
with MemoryEfficientSafeOpen(npz_path) as f:
|
||||
keys = f.keys()
|
||||
hidden_state = f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state")).numpy()
|
||||
attention_mask = f.get_tensor("attention_mask").numpy()
|
||||
input_ids = f.get_tensor("input_ids").numpy()
|
||||
return [hidden_state, input_ids, attention_mask]
|
||||
|
||||
data = np.load(npz_path)
|
||||
hidden_state = data["hidden_state"]
|
||||
attention_mask = data["attention_mask"]
|
||||
@@ -217,16 +233,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
text_encoding_strategy: TextEncodingStrategy,
|
||||
batch: List[train_util.ImageInfo],
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
|
||||
models (List[Any]): Text encoders
|
||||
text_encoding_strategy (LuminaTextEncodingStrategy):
|
||||
infos (List): List of ImageInfo
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy)
|
||||
assert isinstance(tokenize_strategy, LuminaTokenizeStrategy)
|
||||
|
||||
@@ -252,37 +258,75 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
)
|
||||
)
|
||||
|
||||
if hidden_state.dtype != torch.float32:
|
||||
hidden_state = hidden_state.float()
|
||||
if self.cache_format == "safetensors":
|
||||
self._cache_batch_outputs_safetensors(hidden_state, input_ids, attention_masks, batch)
|
||||
else:
|
||||
if hidden_state.dtype != torch.float32:
|
||||
hidden_state = hidden_state.float()
|
||||
|
||||
hidden_state = hidden_state.cpu().numpy()
|
||||
attention_mask = attention_masks.cpu().numpy() # (B, S)
|
||||
input_ids = input_ids.cpu().numpy() # (B, S)
|
||||
hidden_state = hidden_state.cpu().numpy()
|
||||
attention_mask = attention_masks.cpu().numpy()
|
||||
input_ids_np = input_ids.cpu().numpy()
|
||||
|
||||
for i, info in enumerate(batch):
|
||||
hidden_state_i = hidden_state[i]
|
||||
attention_mask_i = attention_mask[i]
|
||||
input_ids_i = input_ids_np[i]
|
||||
|
||||
if self.cache_to_disk:
|
||||
assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
hidden_state=hidden_state_i,
|
||||
attention_mask=attention_mask_i,
|
||||
input_ids=input_ids_i,
|
||||
)
|
||||
else:
|
||||
info.text_encoder_outputs = [
|
||||
hidden_state_i,
|
||||
input_ids_i,
|
||||
attention_mask_i,
|
||||
]
|
||||
|
||||
def _cache_batch_outputs_safetensors(self, hidden_state, input_ids, attention_masks, batch):
|
||||
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
|
||||
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
|
||||
|
||||
hidden_state = hidden_state.cpu()
|
||||
input_ids = input_ids.cpu()
|
||||
attention_mask = attention_masks.cpu()
|
||||
|
||||
for i, info in enumerate(batch):
|
||||
hidden_state_i = hidden_state[i]
|
||||
attention_mask_i = attention_mask[i]
|
||||
input_ids_i = input_ids[i]
|
||||
|
||||
if self.cache_to_disk:
|
||||
assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
hidden_state=hidden_state_i,
|
||||
attention_mask=attention_mask_i,
|
||||
input_ids=input_ids_i,
|
||||
)
|
||||
tensors = {}
|
||||
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
|
||||
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
|
||||
for key in f.keys():
|
||||
tensors[key] = f.get_tensor(key)
|
||||
|
||||
hs = hidden_state[i]
|
||||
tensors[f"hidden_state_{_dtype_to_str(hs.dtype)}"] = hs
|
||||
tensors["attention_mask"] = attention_mask[i]
|
||||
tensors["input_ids"] = input_ids[i]
|
||||
|
||||
metadata = {
|
||||
"architecture": "lumina",
|
||||
"caption1": info.caption,
|
||||
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
|
||||
}
|
||||
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
|
||||
else:
|
||||
info.text_encoder_outputs = [
|
||||
hidden_state_i,
|
||||
input_ids_i,
|
||||
attention_mask_i,
|
||||
hidden_state[i].numpy(),
|
||||
input_ids[i].numpy(),
|
||||
attention_mask[i].numpy(),
|
||||
]
|
||||
|
||||
|
||||
class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz"
|
||||
LUMINA_LATENTS_ST_SUFFIX = "_lumina.safetensors"
|
||||
|
||||
def __init__(
|
||||
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
|
||||
@@ -291,7 +335,7 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
|
||||
@property
|
||||
def cache_suffix(self) -> str:
|
||||
return LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX
|
||||
return self.LUMINA_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.LUMINA_LATENTS_NPZ_SUFFIX
|
||||
|
||||
def get_latents_npz_path(
|
||||
self, absolute_path: str, image_size: Tuple[int, int]
|
||||
@@ -299,9 +343,12 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
return (
|
||||
os.path.splitext(absolute_path)[0]
|
||||
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
|
||||
+ LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX
|
||||
+ self.cache_suffix
|
||||
)
|
||||
|
||||
def _get_architecture_name(self) -> str:
|
||||
return "lumina"
|
||||
|
||||
def is_disk_cached_latents_expected(
|
||||
self,
|
||||
bucket_reso: Tuple[int, int],
|
||||
|
||||
@@ -138,24 +138,32 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
|
||||
SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
|
||||
SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"
|
||||
SD_LATENTS_ST_SUFFIX = "_sd.safetensors"
|
||||
SDXL_LATENTS_ST_SUFFIX = "_sdxl.safetensors"
|
||||
|
||||
def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
def __init__(
|
||||
self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
|
||||
) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
self.sd = sd
|
||||
self.suffix = (
|
||||
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
|
||||
)
|
||||
|
||||
@property
|
||||
def cache_suffix(self) -> str:
|
||||
return self.suffix
|
||||
if self.cache_format == "safetensors":
|
||||
return self.SD_LATENTS_ST_SUFFIX if self.sd else self.SDXL_LATENTS_ST_SUFFIX
|
||||
else:
|
||||
return self.SD_LATENTS_NPZ_SUFFIX if self.sd else self.SDXL_LATENTS_NPZ_SUFFIX
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
# support old .npz
|
||||
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
|
||||
if os.path.exists(old_npz_file):
|
||||
return old_npz_file
|
||||
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
|
||||
if self.cache_format != "safetensors":
|
||||
# support old .npz
|
||||
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
|
||||
if os.path.exists(old_npz_file):
|
||||
return old_npz_file
|
||||
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
|
||||
|
||||
def _get_architecture_name(self) -> str:
|
||||
return "sd" if self.sd else "sdxl"
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
|
||||
|
||||
@@ -255,6 +255,7 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
|
||||
|
||||
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
|
||||
SD3_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_sd3_te.safetensors"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -270,7 +271,8 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
self.apply_t5_attn_mask = apply_t5_attn_mask
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
suffix = self.SD3_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
return os.path.splitext(image_abs_path)[0] + suffix
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str):
|
||||
if not self.cache_to_disk:
|
||||
@@ -281,27 +283,54 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
return True
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "lg_out" not in npz:
|
||||
return False
|
||||
if "lg_pooled" not in npz:
|
||||
return False
|
||||
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
|
||||
return False
|
||||
if "apply_lg_attn_mask" not in npz:
|
||||
return False
|
||||
if "t5_out" not in npz:
|
||||
return False
|
||||
if "t5_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
|
||||
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
|
||||
return False
|
||||
if "apply_t5_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
|
||||
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
|
||||
return False
|
||||
if npz_path.endswith(".safetensors"):
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
from library.strategy_base import _find_tensor_by_prefix
|
||||
|
||||
with MemoryEfficientSafeOpen(npz_path) as f:
|
||||
keys = f.keys()
|
||||
if not _find_tensor_by_prefix(keys, "lg_out"):
|
||||
return False
|
||||
if not _find_tensor_by_prefix(keys, "lg_pooled"):
|
||||
return False
|
||||
if "clip_l_attn_mask" not in keys or "clip_g_attn_mask" not in keys:
|
||||
return False
|
||||
if not _find_tensor_by_prefix(keys, "t5_out"):
|
||||
return False
|
||||
if "t5_attn_mask" not in keys:
|
||||
return False
|
||||
if "apply_lg_attn_mask" not in keys:
|
||||
return False
|
||||
apply_lg = f.get_tensor("apply_lg_attn_mask").item()
|
||||
if bool(apply_lg) != self.apply_lg_attn_mask:
|
||||
return False
|
||||
if "apply_t5_attn_mask" not in keys:
|
||||
return False
|
||||
apply_t5 = f.get_tensor("apply_t5_attn_mask").item()
|
||||
if bool(apply_t5) != self.apply_t5_attn_mask:
|
||||
return False
|
||||
else:
|
||||
npz = np.load(npz_path)
|
||||
if "lg_out" not in npz:
|
||||
return False
|
||||
if "lg_pooled" not in npz:
|
||||
return False
|
||||
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz:
|
||||
return False
|
||||
if "apply_lg_attn_mask" not in npz:
|
||||
return False
|
||||
if "t5_out" not in npz:
|
||||
return False
|
||||
if "t5_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
|
||||
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
|
||||
return False
|
||||
if "apply_t5_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
|
||||
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
@@ -309,6 +338,20 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
return True
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
if npz_path.endswith(".safetensors"):
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
from library.strategy_base import _find_tensor_by_prefix
|
||||
|
||||
with MemoryEfficientSafeOpen(npz_path) as f:
|
||||
keys = f.keys()
|
||||
lg_out = f.get_tensor(_find_tensor_by_prefix(keys, "lg_out")).numpy()
|
||||
lg_pooled = f.get_tensor(_find_tensor_by_prefix(keys, "lg_pooled")).numpy()
|
||||
t5_out = f.get_tensor(_find_tensor_by_prefix(keys, "t5_out")).numpy()
|
||||
l_attn_mask = f.get_tensor("clip_l_attn_mask").numpy()
|
||||
g_attn_mask = f.get_tensor("clip_g_attn_mask").numpy()
|
||||
t5_attn_mask = f.get_tensor("t5_attn_mask").numpy()
|
||||
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
|
||||
|
||||
data = np.load(npz_path)
|
||||
lg_out = data["lg_out"]
|
||||
lg_pooled = data["lg_pooled"]
|
||||
@@ -339,65 +382,127 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
enable_dropout=False,
|
||||
)
|
||||
|
||||
if lg_out.dtype == torch.bfloat16:
|
||||
lg_out = lg_out.float()
|
||||
if lg_pooled.dtype == torch.bfloat16:
|
||||
lg_pooled = lg_pooled.float()
|
||||
if t5_out.dtype == torch.bfloat16:
|
||||
t5_out = t5_out.float()
|
||||
l_attn_mask_tokens = tokens_and_masks[3]
|
||||
g_attn_mask_tokens = tokens_and_masks[4]
|
||||
t5_attn_mask_tokens = tokens_and_masks[5]
|
||||
|
||||
lg_out = lg_out.cpu().numpy()
|
||||
lg_pooled = lg_pooled.cpu().numpy()
|
||||
t5_out = t5_out.cpu().numpy()
|
||||
if self.cache_format == "safetensors":
|
||||
self._cache_batch_outputs_safetensors(
|
||||
lg_out, t5_out, lg_pooled, l_attn_mask_tokens, g_attn_mask_tokens, t5_attn_mask_tokens, infos
|
||||
)
|
||||
else:
|
||||
if lg_out.dtype == torch.bfloat16:
|
||||
lg_out = lg_out.float()
|
||||
if lg_pooled.dtype == torch.bfloat16:
|
||||
lg_pooled = lg_pooled.float()
|
||||
if t5_out.dtype == torch.bfloat16:
|
||||
t5_out = t5_out.float()
|
||||
|
||||
l_attn_mask = tokens_and_masks[3].cpu().numpy()
|
||||
g_attn_mask = tokens_and_masks[4].cpu().numpy()
|
||||
t5_attn_mask = tokens_and_masks[5].cpu().numpy()
|
||||
lg_out = lg_out.cpu().numpy()
|
||||
lg_pooled = lg_pooled.cpu().numpy()
|
||||
t5_out = t5_out.cpu().numpy()
|
||||
|
||||
l_attn_mask = l_attn_mask_tokens.cpu().numpy()
|
||||
g_attn_mask = g_attn_mask_tokens.cpu().numpy()
|
||||
t5_attn_mask = t5_attn_mask_tokens.cpu().numpy()
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
lg_out_i = lg_out[i]
|
||||
t5_out_i = t5_out[i]
|
||||
lg_pooled_i = lg_pooled[i]
|
||||
l_attn_mask_i = l_attn_mask[i]
|
||||
g_attn_mask_i = g_attn_mask[i]
|
||||
t5_attn_mask_i = t5_attn_mask[i]
|
||||
apply_lg_attn_mask = self.apply_lg_attn_mask
|
||||
apply_t5_attn_mask = self.apply_t5_attn_mask
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
lg_out=lg_out_i,
|
||||
lg_pooled=lg_pooled_i,
|
||||
t5_out=t5_out_i,
|
||||
clip_l_attn_mask=l_attn_mask_i,
|
||||
clip_g_attn_mask=g_attn_mask_i,
|
||||
t5_attn_mask=t5_attn_mask_i,
|
||||
apply_lg_attn_mask=apply_lg_attn_mask,
|
||||
apply_t5_attn_mask=apply_t5_attn_mask,
|
||||
)
|
||||
else:
|
||||
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
|
||||
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
|
||||
|
||||
def _cache_batch_outputs_safetensors(
|
||||
self, lg_out, t5_out, lg_pooled, l_attn_mask_tokens, g_attn_mask_tokens, t5_attn_mask_tokens, infos
|
||||
):
|
||||
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
|
||||
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
|
||||
|
||||
lg_out = lg_out.cpu()
|
||||
t5_out = t5_out.cpu()
|
||||
lg_pooled = lg_pooled.cpu()
|
||||
l_attn_mask = l_attn_mask_tokens.cpu()
|
||||
g_attn_mask = g_attn_mask_tokens.cpu()
|
||||
t5_attn_mask = t5_attn_mask_tokens.cpu()
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
lg_out_i = lg_out[i]
|
||||
t5_out_i = t5_out[i]
|
||||
lg_pooled_i = lg_pooled[i]
|
||||
l_attn_mask_i = l_attn_mask[i]
|
||||
g_attn_mask_i = g_attn_mask[i]
|
||||
t5_attn_mask_i = t5_attn_mask[i]
|
||||
apply_lg_attn_mask = self.apply_lg_attn_mask
|
||||
apply_t5_attn_mask = self.apply_t5_attn_mask
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
lg_out=lg_out_i,
|
||||
lg_pooled=lg_pooled_i,
|
||||
t5_out=t5_out_i,
|
||||
clip_l_attn_mask=l_attn_mask_i,
|
||||
clip_g_attn_mask=g_attn_mask_i,
|
||||
t5_attn_mask=t5_attn_mask_i,
|
||||
apply_lg_attn_mask=apply_lg_attn_mask,
|
||||
apply_t5_attn_mask=apply_t5_attn_mask,
|
||||
)
|
||||
tensors = {}
|
||||
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
|
||||
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
|
||||
for key in f.keys():
|
||||
tensors[key] = f.get_tensor(key)
|
||||
|
||||
lg_out_i = lg_out[i]
|
||||
t5_out_i = t5_out[i]
|
||||
lg_pooled_i = lg_pooled[i]
|
||||
tensors[f"lg_out_{_dtype_to_str(lg_out_i.dtype)}"] = lg_out_i
|
||||
tensors[f"t5_out_{_dtype_to_str(t5_out_i.dtype)}"] = t5_out_i
|
||||
tensors[f"lg_pooled_{_dtype_to_str(lg_pooled_i.dtype)}"] = lg_pooled_i
|
||||
tensors["clip_l_attn_mask"] = l_attn_mask[i]
|
||||
tensors["clip_g_attn_mask"] = g_attn_mask[i]
|
||||
tensors["t5_attn_mask"] = t5_attn_mask[i]
|
||||
tensors["apply_lg_attn_mask"] = torch.tensor(self.apply_lg_attn_mask, dtype=torch.bool)
|
||||
tensors["apply_t5_attn_mask"] = torch.tensor(self.apply_t5_attn_mask, dtype=torch.bool)
|
||||
|
||||
metadata = {
|
||||
"architecture": "sd3",
|
||||
"caption1": info.caption,
|
||||
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
|
||||
}
|
||||
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
|
||||
else:
|
||||
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
|
||||
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
|
||||
info.text_encoder_outputs = (
|
||||
lg_out[i].numpy(),
|
||||
t5_out[i].numpy(),
|
||||
lg_pooled[i].numpy(),
|
||||
l_attn_mask[i].numpy(),
|
||||
g_attn_mask[i].numpy(),
|
||||
t5_attn_mask[i].numpy(),
|
||||
)
|
||||
|
||||
|
||||
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
|
||||
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
|
||||
SD3_LATENTS_ST_SUFFIX = "_sd3.safetensors"
|
||||
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
|
||||
@property
|
||||
def cache_suffix(self) -> str:
|
||||
return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
|
||||
return self.SD3_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.SD3_LATENTS_NPZ_SUFFIX
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
return (
|
||||
os.path.splitext(absolute_path)[0]
|
||||
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
|
||||
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
|
||||
+ self.cache_suffix
|
||||
)
|
||||
|
||||
def _get_architecture_name(self) -> str:
|
||||
return "sd3"
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
|
||||
|
||||
|
||||
@@ -221,6 +221,7 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
|
||||
|
||||
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
|
||||
SDXL_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_te_outputs.safetensors"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -233,7 +234,8 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted)
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
suffix = self.SDXL_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
return os.path.splitext(image_abs_path)[0] + suffix
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str):
|
||||
if not self.cache_to_disk:
|
||||
@@ -244,9 +246,22 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
return True
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
|
||||
return False
|
||||
if npz_path.endswith(".safetensors"):
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
from library.strategy_base import _find_tensor_by_prefix
|
||||
|
||||
with MemoryEfficientSafeOpen(npz_path) as f:
|
||||
keys = f.keys()
|
||||
if not _find_tensor_by_prefix(keys, "hidden_state1"):
|
||||
return False
|
||||
if not _find_tensor_by_prefix(keys, "hidden_state2"):
|
||||
return False
|
||||
if not _find_tensor_by_prefix(keys, "pool2"):
|
||||
return False
|
||||
else:
|
||||
npz = np.load(npz_path)
|
||||
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
@@ -254,6 +269,17 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
return True
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
if npz_path.endswith(".safetensors"):
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
from library.strategy_base import _find_tensor_by_prefix
|
||||
|
||||
with MemoryEfficientSafeOpen(npz_path) as f:
|
||||
keys = f.keys()
|
||||
hidden_state1 = f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state1")).numpy()
|
||||
hidden_state2 = f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state2")).numpy()
|
||||
pool2 = f.get_tensor(_find_tensor_by_prefix(keys, "pool2")).numpy()
|
||||
return [hidden_state1, hidden_state2, pool2]
|
||||
|
||||
data = np.load(npz_path)
|
||||
hidden_state1 = data["hidden_state1"]
|
||||
hidden_state2 = data["hidden_state2"]
|
||||
@@ -279,28 +305,68 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
tokenize_strategy, models, [tokens1, tokens2]
|
||||
)
|
||||
|
||||
if hidden_state1.dtype == torch.bfloat16:
|
||||
hidden_state1 = hidden_state1.float()
|
||||
if hidden_state2.dtype == torch.bfloat16:
|
||||
hidden_state2 = hidden_state2.float()
|
||||
if pool2.dtype == torch.bfloat16:
|
||||
pool2 = pool2.float()
|
||||
if self.cache_format == "safetensors":
|
||||
self._cache_batch_outputs_safetensors(hidden_state1, hidden_state2, pool2, infos)
|
||||
else:
|
||||
if hidden_state1.dtype == torch.bfloat16:
|
||||
hidden_state1 = hidden_state1.float()
|
||||
if hidden_state2.dtype == torch.bfloat16:
|
||||
hidden_state2 = hidden_state2.float()
|
||||
if pool2.dtype == torch.bfloat16:
|
||||
pool2 = pool2.float()
|
||||
|
||||
hidden_state1 = hidden_state1.cpu().numpy()
|
||||
hidden_state2 = hidden_state2.cpu().numpy()
|
||||
pool2 = pool2.cpu().numpy()
|
||||
hidden_state1 = hidden_state1.cpu().numpy()
|
||||
hidden_state2 = hidden_state2.cpu().numpy()
|
||||
pool2 = pool2.cpu().numpy()
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
hidden_state1_i = hidden_state1[i]
|
||||
hidden_state2_i = hidden_state2[i]
|
||||
pool2_i = pool2[i]
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
hidden_state1=hidden_state1_i,
|
||||
hidden_state2=hidden_state2_i,
|
||||
pool2=pool2_i,
|
||||
)
|
||||
else:
|
||||
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]
|
||||
|
||||
def _cache_batch_outputs_safetensors(self, hidden_state1, hidden_state2, pool2, infos):
|
||||
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
|
||||
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
|
||||
|
||||
hidden_state1 = hidden_state1.cpu()
|
||||
hidden_state2 = hidden_state2.cpu()
|
||||
pool2 = pool2.cpu()
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
hidden_state1_i = hidden_state1[i]
|
||||
hidden_state2_i = hidden_state2[i]
|
||||
pool2_i = pool2[i]
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
hidden_state1=hidden_state1_i,
|
||||
hidden_state2=hidden_state2_i,
|
||||
pool2=pool2_i,
|
||||
)
|
||||
tensors = {}
|
||||
# merge existing file if partial
|
||||
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
|
||||
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
|
||||
for key in f.keys():
|
||||
tensors[key] = f.get_tensor(key)
|
||||
|
||||
hs1 = hidden_state1[i]
|
||||
hs2 = hidden_state2[i]
|
||||
p2 = pool2[i]
|
||||
tensors[f"hidden_state1_{_dtype_to_str(hs1.dtype)}"] = hs1
|
||||
tensors[f"hidden_state2_{_dtype_to_str(hs2.dtype)}"] = hs2
|
||||
tensors[f"pool2_{_dtype_to_str(p2.dtype)}"] = p2
|
||||
|
||||
metadata = {
|
||||
"architecture": "sdxl",
|
||||
"caption1": info.caption,
|
||||
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
|
||||
}
|
||||
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
|
||||
else:
|
||||
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]
|
||||
info.text_encoder_outputs = [
|
||||
hidden_state1[i].numpy(),
|
||||
hidden_state2[i].numpy(),
|
||||
pool2[i].numpy(),
|
||||
]
|
||||
|
||||
@@ -1106,7 +1106,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
return all(
|
||||
[
|
||||
not (
|
||||
subset.caption_dropout_rate > 0 and not cache_supports_dropout
|
||||
subset.caption_dropout_rate > 0
|
||||
and not cache_supports_dropout
|
||||
or subset.shuffle_caption
|
||||
or subset.token_warmup_step > 0
|
||||
or subset.caption_tag_dropout_rate > 0
|
||||
@@ -4471,7 +4472,10 @@ def verify_training_args(args: argparse.Namespace):
|
||||
Verify training arguments. Also reflect highvram option to global variable
|
||||
学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する
|
||||
"""
|
||||
from library.strategy_base import set_cache_format
|
||||
|
||||
enable_high_vram(args)
|
||||
set_cache_format(args.cache_format)
|
||||
|
||||
if args.v2 and args.clip_skip is not None:
|
||||
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
||||
@@ -4637,6 +4641,14 @@ def add_dataset_arguments(
|
||||
help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist"
|
||||
" / cacheの内容の検証をスキップする(latentとテキストエンコーダの出力)。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_format",
|
||||
type=str,
|
||||
default="npz",
|
||||
choices=["npz", "safetensors"],
|
||||
help="format for latent and text encoder output caches (default: npz). safetensors saves in native dtype (e.g. bf16) for smaller files and faster I/O"
|
||||
" / latentおよびtext encoder出力キャッシュの保存形式(デフォルト: npz)。safetensorsはネイティブdtype(例: bf16)で保存し、ファイルサイズ削減と高速化が可能",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_bucket",
|
||||
action="store_true",
|
||||
|
||||
@@ -69,6 +69,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
|
||||
set_tokenize_strategy(is_sd, is_sdxl, is_flux, args)
|
||||
|
||||
strategy_base.set_cache_format(args.cache_format)
|
||||
|
||||
if is_sd or is_sdxl:
|
||||
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(is_sd, True, args.vae_batch_size, args.skip_cache_check)
|
||||
else:
|
||||
|
||||
@@ -156,6 +156,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
text_encoder.eval()
|
||||
|
||||
# build text encoder outputs caching strategy
|
||||
strategy_base.set_cache_format(args.cache_format)
|
||||
|
||||
if is_sdxl:
|
||||
text_encoder_outputs_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions
|
||||
|
||||
Reference in New Issue
Block a user