feat: faster safetensors load and split safetensor utils

This commit is contained in:
Kohya S
2025-09-13 19:51:38 +09:00
parent 419a9c4af4
commit 8783f8aed3
17 changed files with 459 additions and 234 deletions

View File

@@ -456,13 +456,13 @@ if __name__ == "__main__":
# load clip_l (skip for chroma model) # load clip_l (skip for chroma model)
if args.model_type == "flux": if args.model_type == "flux":
logger.info(f"Loading clip_l from {args.clip_l}...") logger.info(f"Loading clip_l from {args.clip_l}...")
clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device) clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device, disable_mmap=True)
clip_l.eval() clip_l.eval()
else: else:
clip_l = None clip_l = None
logger.info(f"Loading t5xxl from {args.t5xxl}...") logger.info(f"Loading t5xxl from {args.t5xxl}...")
t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device) t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device, disable_mmap=True)
t5xxl.eval() t5xxl.eval()
# if is_fp8(clip_l_dtype): # if is_fp8(clip_l_dtype):
@@ -471,7 +471,9 @@ if __name__ == "__main__":
# t5xxl = accelerator.prepare(t5xxl) # t5xxl = accelerator.prepare(t5xxl)
# DiT # DiT
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type) is_schnell, model = flux_utils.load_flow_model(
args.ckpt_path, None, loading_device, disable_mmap=True, model_type=args.model_type
)
model.eval() model.eval()
logger.info(f"Casting model to {flux_dtype}") logger.info(f"Casting model to {flux_dtype}")
model.to(flux_dtype) # make sure model is dtype model.to(flux_dtype) # make sure model is dtype

View File

@@ -1,13 +1,28 @@
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import gc
import time import time
from typing import Optional, Union, Callable, Tuple from typing import Optional, Union, Callable, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from library.device_utils import clean_memory_on_device
# Keep these functions here for portability, and private to avoid confusion with the ones in device_utils.py
def _clean_memory_on_device(device: torch.device):
r"""
Clean memory on the specified device, will be called from training scripts.
"""
gc.collect()
# device may "cuda" or "cuda:0", so we need to check the type of device
if device.type == "cuda":
torch.cuda.empty_cache()
if device.type == "xpu":
torch.xpu.empty_cache()
if device.type == "mps":
torch.mps.empty_cache()
def synchronize_device(device: torch.device): def _synchronize_device(device: torch.device):
if device.type == "cuda": if device.type == "cuda":
torch.cuda.synchronize() torch.cuda.synchronize()
elif device.type == "xpu": elif device.type == "xpu":
@@ -71,19 +86,18 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
# device to cpu # device to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
synchronize_device(device) _synchronize_device(device)
# cpu to device # cpu to device
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view module_to_cuda.weight.data = cuda_data_view
synchronize_device(device) _synchronize_device(device)
def weighs_to_device(layer: nn.Module, device: torch.device): def weighs_to_device(layer: nn.Module, device: torch.device):
@@ -152,12 +166,15 @@ class Offloader:
# Gradient tensors # Gradient tensors
_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor] _grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]
class ModelOffloader(Offloader): class ModelOffloader(Offloader):
""" """
supports forward offloading supports forward offloading
""" """
def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False): def __init__(
self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False
):
super().__init__(len(blocks), blocks_to_swap, device, debug) super().__init__(len(blocks), blocks_to_swap, device, debug)
# register backward hooks # register backward hooks
@@ -172,7 +189,9 @@ class ModelOffloader(Offloader):
for handle in self.remove_handles: for handle in self.remove_handles:
handle.remove() handle.remove()
def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]: def create_backward_hook(
self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int
) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
# -1 for 0-based index # -1 for 0-based index
num_blocks_propagated = self.num_blocks - block_index - 1 num_blocks_propagated = self.num_blocks - block_index - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
@@ -213,8 +232,8 @@ class ModelOffloader(Offloader):
b.to(self.device) # move block to device first b.to(self.device) # move block to device first
weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu
synchronize_device(self.device) _synchronize_device(self.device)
clean_memory_on_device(self.device) _clean_memory_on_device(self.device)
def wait_for_block(self, block_idx: int): def wait_for_block(self, block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0: if self.blocks_to_swap is None or self.blocks_to_swap == 0:

View File

@@ -1,7 +1,9 @@
import functools import functools
import gc import gc
from typing import Optional, Union
import torch import torch
try: try:
# intel gpu support for pytorch older than 2.5 # intel gpu support for pytorch older than 2.5
# ipex is not needed after pytorch 2.5 # ipex is not needed after pytorch 2.5
@@ -36,12 +38,15 @@ def clean_memory():
torch.mps.empty_cache() torch.mps.empty_cache()
def clean_memory_on_device(device: torch.device): def clean_memory_on_device(device: Optional[Union[str, torch.device]]):
r""" r"""
Clean memory on the specified device, will be called from training scripts. Clean memory on the specified device, will be called from training scripts.
""" """
gc.collect() gc.collect()
if device is None:
return
if isinstance(device, str):
device = torch.device(device)
# device may "cuda" or "cuda:0", so we need to check the type of device # device may "cuda" or "cuda:0", so we need to check the type of device
if device.type == "cuda": if device.type == "cuda":
torch.cuda.empty_cache() torch.cuda.empty_cache()
@@ -51,6 +56,19 @@ def clean_memory_on_device(device: torch.device):
torch.mps.empty_cache() torch.mps.empty_cache()
def synchronize_device(device: Optional[Union[str, torch.device]]):
if device is None:
return
if isinstance(device, str):
device = torch.device(device)
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "xpu":
torch.xpu.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device: def get_preferred_device() -> torch.device:
r""" r"""

View File

@@ -16,10 +16,11 @@ from safetensors.torch import save_file
from library import flux_models, flux_utils, strategy_base, train_util from library import flux_models, flux_utils, strategy_base, train_util
from library.device_utils import init_ipex, clean_memory_on_device from library.device_utils import init_ipex, clean_memory_on_device
from library.safetensors_utils import mem_eff_save_file
init_ipex() init_ipex()
from .utils import setup_logging, mem_eff_save_file from .utils import setup_logging
setup_logging() setup_logging()
import logging import logging

View File

@@ -18,7 +18,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from library import flux_models from library import flux_models
from library.utils import load_safetensors from library.safetensors_utils import load_safetensors
MODEL_VERSION_FLUX_V1 = "flux1" MODEL_VERSION_FLUX_V1 = "flux1"
MODEL_NAME_DEV = "dev" MODEL_NAME_DEV = "dev"
@@ -124,7 +124,7 @@ def load_flow_model(
logger.info(f"Loading state dict from {ckpt_path}") logger.info(f"Loading state dict from {ckpt_path}")
sd = {} sd = {}
for ckpt_path in ckpt_paths: for ckpt_path in ckpt_paths:
sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)) sd.update(load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype))
# convert Diffusers to BFL # convert Diffusers to BFL
if is_diffusers: if is_diffusers:

View File

@@ -18,10 +18,11 @@ from library import lumina_models, strategy_base, strategy_lumina, train_util
from library.flux_models import AutoEncoder from library.flux_models import AutoEncoder
from library.device_utils import init_ipex, clean_memory_on_device from library.device_utils import init_ipex, clean_memory_on_device
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
from library.safetensors_utils import mem_eff_save_file
init_ipex() init_ipex()
from .utils import setup_logging, mem_eff_save_file from .utils import setup_logging
setup_logging() setup_logging()
import logging import logging

View File

@@ -12,7 +12,7 @@ from transformers import Gemma2Config, Gemma2Model
from library.utils import setup_logging from library.utils import setup_logging
from library import lumina_models, flux_models from library import lumina_models, flux_models
from library.utils import load_safetensors from library.safetensors_utils import load_safetensors
import logging import logging
setup_logging() setup_logging()

View File

@@ -0,0 +1,352 @@
import os
import re
import numpy as np
import torch
import json
import struct
from typing import Dict, Any, Union, Optional
from safetensors.torch import load_file
from library.device_utils import synchronize_device
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
"""
memory efficient save file
"""
_TYPES = {
torch.float64: "F64",
torch.float32: "F32",
torch.float16: "F16",
torch.bfloat16: "BF16",
torch.int64: "I64",
torch.int32: "I32",
torch.int16: "I16",
torch.int8: "I8",
torch.uint8: "U8",
torch.bool: "BOOL",
getattr(torch, "float8_e5m2", None): "F8_E5M2",
getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
}
_ALIGN = 256
def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
validated = {}
for key, value in metadata.items():
if not isinstance(key, str):
raise ValueError(f"Metadata key must be a string, got {type(key)}")
if not isinstance(value, str):
print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
validated[key] = str(value)
else:
validated[key] = value
return validated
# print(f"Using memory efficient save file: {filename}")
header = {}
offset = 0
if metadata:
header["__metadata__"] = validate_metadata(metadata)
for k, v in tensors.items():
if v.numel() == 0: # empty tensor
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
else:
size = v.numel() * v.element_size()
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
offset += size
hjson = json.dumps(header).encode("utf-8")
hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
with open(filename, "wb") as f:
f.write(struct.pack("<Q", len(hjson)))
f.write(hjson)
for k, v in tensors.items():
if v.numel() == 0:
continue
if v.is_cuda:
# Direct GPU to disk save
with torch.cuda.device(v.device):
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
tensor_bytes = v.contiguous().view(torch.uint8)
tensor_bytes.cpu().numpy().tofile(f)
else:
# CPU tensor save
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
v.contiguous().view(torch.uint8).numpy().tofile(f)
class MemoryEfficientSafeOpen:
"""Memory-efficient reader for safetensors files.
This class provides a memory-efficient way to read tensors from safetensors files
by using memory mapping for large tensors and avoiding unnecessary copies.
"""
def __init__(self, filename):
"""Initialize the SafeTensor reader.
Args:
filename (str): Path to the safetensors file to read.
"""
self.filename = filename
self.file = open(filename, "rb")
self.header, self.header_size = self._read_header()
def __enter__(self):
"""Enter context manager."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Exit context manager and close file."""
self.file.close()
def keys(self):
"""Get all tensor keys in the file.
Returns:
list: List of tensor names (excludes metadata).
"""
return [k for k in self.header.keys() if k != "__metadata__"]
def metadata(self) -> Dict[str, str]:
"""Get metadata from the file.
Returns:
Dict[str, str]: Metadata dictionary.
"""
return self.header.get("__metadata__", {})
def _read_header(self):
"""Read and parse the header from the safetensors file.
Returns:
tuple: (header_dict, header_size) containing parsed header and its size.
"""
# Read header size (8 bytes, little-endian unsigned long long)
header_size = struct.unpack("<Q", self.file.read(8))[0]
# Read and decode header JSON
header_json = self.file.read(header_size).decode("utf-8")
return json.loads(header_json), header_size
def get_tensor(self, key: str, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
"""Load a tensor from the file with memory-efficient strategies.
**Note:**
If device is 'cuda' , the transfer to GPU is done efficiently using pinned memory and non-blocking transfer.
So you must ensure that the transfer is completed before using the tensor (e.g., by `torch.cuda.synchronize()`).
If the tensor is large (>10MB) and the target device is CUDA, memory mapping with numpy.memmap is used to avoid intermediate copies.
Args:
key (str): Name of the tensor to load.
device (Optional[torch.device]): Target device for the tensor.
dtype (Optional[torch.dtype]): Target dtype for the tensor.
Returns:
torch.Tensor: The loaded tensor.
Raises:
KeyError: If the tensor key is not found in the file.
"""
if key not in self.header:
raise KeyError(f"Tensor '{key}' not found in the file")
metadata = self.header[key]
offset_start, offset_end = metadata["data_offsets"]
num_bytes = offset_end - offset_start
original_dtype = self._get_torch_dtype(metadata["dtype"])
target_dtype = dtype if dtype is not None else original_dtype
# Handle empty tensors
if num_bytes == 0:
return torch.empty(metadata["shape"], dtype=target_dtype, device=device)
# Determine if we should use pinned memory for GPU transfer
non_blocking = device is not None and device.type == "cuda"
# Calculate absolute file offset
tensor_offset = self.header_size + 8 + offset_start # adjust offset by header size
# Memory mapping strategy for large tensors to GPU
# Use memmap for large tensors to avoid intermediate copies.
# If device is cpu, tensor is not copied to gpu, so using memmap locks the file, which is not desired.
# So we only use memmap if device is not cpu.
if num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
# Create memory map for zero-copy reading
mm = np.memmap(self.filename, mode="c", dtype=np.uint8, offset=tensor_offset, shape=(num_bytes,))
byte_tensor = torch.from_numpy(mm) # zero copy
del mm
# Deserialize tensor (view and reshape)
cpu_tensor = self._deserialize_tensor(byte_tensor, metadata) # view and reshape
del byte_tensor
# Transfer to target device and dtype
gpu_tensor = cpu_tensor.to(device=device, dtype=target_dtype, non_blocking=non_blocking)
del cpu_tensor
return gpu_tensor
# Standard file reading strategy for smaller tensors or CPU target
# seek to the specified position
self.file.seek(tensor_offset)
# read directly into a numpy array by numpy.fromfile without intermediate copy
numpy_array = np.fromfile(self.file, dtype=np.uint8, count=num_bytes)
byte_tensor = torch.from_numpy(numpy_array)
del numpy_array
# deserialize (view and reshape)
deserialized_tensor = self._deserialize_tensor(byte_tensor, metadata)
del byte_tensor
# cast to target dtype and move to device
return deserialized_tensor.to(device=device, dtype=target_dtype, non_blocking=non_blocking)
def _deserialize_tensor(self, byte_tensor: torch.Tensor, metadata: Dict):
"""Deserialize byte tensor to the correct shape and dtype.
Args:
byte_tensor (torch.Tensor): Raw byte tensor from file.
metadata (Dict): Tensor metadata containing dtype and shape info.
Returns:
torch.Tensor: Deserialized tensor with correct shape and dtype.
"""
dtype = self._get_torch_dtype(metadata["dtype"])
shape = metadata["shape"]
# Handle special float8 types
if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
return self._convert_float8(byte_tensor, metadata["dtype"], shape)
# Standard conversion: view as target dtype and reshape
return byte_tensor.view(dtype).reshape(shape)
@staticmethod
def _get_torch_dtype(dtype_str):
"""Convert string dtype to PyTorch dtype.
Args:
dtype_str (str): String representation of the dtype.
Returns:
torch.dtype: Corresponding PyTorch dtype.
"""
# Standard dtype mappings
dtype_map = {
"F64": torch.float64,
"F32": torch.float32,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I64": torch.int64,
"I32": torch.int32,
"I16": torch.int16,
"I8": torch.int8,
"U8": torch.uint8,
"BOOL": torch.bool,
}
# Add float8 types if available in PyTorch version
if hasattr(torch, "float8_e5m2"):
dtype_map["F8_E5M2"] = torch.float8_e5m2
if hasattr(torch, "float8_e4m3fn"):
dtype_map["F8_E4M3"] = torch.float8_e4m3fn
return dtype_map.get(dtype_str)
@staticmethod
def _convert_float8(byte_tensor, dtype_str, shape):
"""Convert byte tensor to float8 format if supported.
Args:
byte_tensor (torch.Tensor): Raw byte tensor.
dtype_str (str): Float8 dtype string ("F8_E5M2" or "F8_E4M3").
shape (tuple): Target tensor shape.
Returns:
torch.Tensor: Tensor with float8 dtype.
Raises:
ValueError: If float8 type is not supported in current PyTorch version.
"""
# Convert to specific float8 types if available
if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
return byte_tensor.view(torch.float8_e5m2).reshape(shape)
elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
else:
# Float8 not supported in this PyTorch version
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
) -> dict[str, torch.Tensor]:
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
# use experimental loader
# logger.info(f"Loading without mmap (experimental)")
state_dict = {}
device = torch.device(device) if device is not None else None
with MemoryEfficientSafeOpen(path) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key, device=device, dtype=dtype)
synchronize_device(device)
return state_dict
else:
try:
state_dict = load_file(path, device=device)
except:
state_dict = load_file(path) # prevent device invalid Error
if dtype is not None:
for key in state_dict.keys():
state_dict[key] = state_dict[key].to(dtype=dtype)
return state_dict
def load_split_weights(
file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
) -> Dict[str, torch.Tensor]:
"""
Load split weights from a file. If the file name ends with 00001-of-00004 etc, it will load all files with the same prefix.
dtype is as is, no conversion is done.
"""
device = torch.device(device)
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
basename = os.path.basename(file_path)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
if match:
prefix = basename[: match.start(2)]
count = int(match.group(3))
state_dict = {}
for i in range(count):
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(file_path), filename)
if os.path.exists(filepath):
state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap, dtype=dtype))
else:
raise FileNotFoundError(f"File {filepath} not found")
else:
state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
return state_dict
def find_key(safetensors_file: str, starts_with: Optional[str] = None, ends_with: Optional[str] = None) -> Optional[str]:
"""
Find a key in a safetensors file that starts with `starts_with` and ends with `ends_with`.
If `starts_with` is None, it will match any key.
If `ends_with` is None, it will match any key.
Returns the first matching key or None if no key matches.
"""
with MemoryEfficientSafeOpen(safetensors_file) as f:
for key in f.keys():
if (starts_with is None or key.startswith(starts_with)) and (ends_with is None or key.endswith(ends_with)):
return key
return None

View File

@@ -23,7 +23,7 @@ from library import sdxl_model_util
# region models # region models
# TODO remove dependency on flux_utils # TODO remove dependency on flux_utils
from library.utils import load_safetensors from library.safetensors_utils import load_safetensors
from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl
@@ -246,7 +246,7 @@ def load_vae(
vae_sd = {} vae_sd = {}
if vae_path: if vae_path:
logger.info(f"Loading VAE from {vae_path}...") logger.info(f"Loading VAE from {vae_path}...")
vae_sd = load_safetensors(vae_path, device, disable_mmap) vae_sd = load_safetensors(vae_path, device, disable_mmap, dtype=vae_dtype)
else: else:
# remove prefix "first_stage_model." # remove prefix "first_stage_model."
vae_sd = {} vae_sd = {}

View File

@@ -2,8 +2,6 @@ import logging
import sys import sys
import threading import threading
from typing import * from typing import *
import json
import struct
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -14,7 +12,7 @@ from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncest
import cv2 import cv2
from PIL import Image from PIL import Image
import numpy as np import numpy as np
from safetensors.torch import load_file
def fire_in_thread(f, *args, **kwargs): def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start() threading.Thread(target=f, args=args, kwargs=kwargs).start()
@@ -88,6 +86,7 @@ def setup_logging(args=None, log_level=None, reset=False):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.info(msg_init) logger.info(msg_init)
setup_logging() setup_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -190,190 +189,6 @@ def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None)
raise ValueError(f"Unsupported dtype: {s}") raise ValueError(f"Unsupported dtype: {s}")
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
"""
memory efficient save file
"""
_TYPES = {
torch.float64: "F64",
torch.float32: "F32",
torch.float16: "F16",
torch.bfloat16: "BF16",
torch.int64: "I64",
torch.int32: "I32",
torch.int16: "I16",
torch.int8: "I8",
torch.uint8: "U8",
torch.bool: "BOOL",
getattr(torch, "float8_e5m2", None): "F8_E5M2",
getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
}
_ALIGN = 256
def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
validated = {}
for key, value in metadata.items():
if not isinstance(key, str):
raise ValueError(f"Metadata key must be a string, got {type(key)}")
if not isinstance(value, str):
print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
validated[key] = str(value)
else:
validated[key] = value
return validated
print(f"Using memory efficient save file: {filename}")
header = {}
offset = 0
if metadata:
header["__metadata__"] = validate_metadata(metadata)
for k, v in tensors.items():
if v.numel() == 0: # empty tensor
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
else:
size = v.numel() * v.element_size()
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
offset += size
hjson = json.dumps(header).encode("utf-8")
hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
with open(filename, "wb") as f:
f.write(struct.pack("<Q", len(hjson)))
f.write(hjson)
for k, v in tensors.items():
if v.numel() == 0:
continue
if v.is_cuda:
# Direct GPU to disk save
with torch.cuda.device(v.device):
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
tensor_bytes = v.contiguous().view(torch.uint8)
tensor_bytes.cpu().numpy().tofile(f)
else:
# CPU tensor save
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
v.contiguous().view(torch.uint8).numpy().tofile(f)
class MemoryEfficientSafeOpen:
def __init__(self, filename):
self.filename = filename
self.file = open(filename, "rb")
self.header, self.header_size = self._read_header()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.file.close()
def keys(self):
return [k for k in self.header.keys() if k != "__metadata__"]
def metadata(self) -> Dict[str, str]:
return self.header.get("__metadata__", {})
def get_tensor(self, key):
if key not in self.header:
raise KeyError(f"Tensor '{key}' not found in the file")
metadata = self.header[key]
offset_start, offset_end = metadata["data_offsets"]
if offset_start == offset_end:
tensor_bytes = None
else:
# adjust offset by header size
self.file.seek(self.header_size + 8 + offset_start)
tensor_bytes = self.file.read(offset_end - offset_start)
return self._deserialize_tensor(tensor_bytes, metadata)
def _read_header(self):
header_size = struct.unpack("<Q", self.file.read(8))[0]
header_json = self.file.read(header_size).decode("utf-8")
return json.loads(header_json), header_size
def _deserialize_tensor(self, tensor_bytes, metadata):
dtype = self._get_torch_dtype(metadata["dtype"])
shape = metadata["shape"]
if tensor_bytes is None:
byte_tensor = torch.empty(0, dtype=torch.uint8)
else:
tensor_bytes = bytearray(tensor_bytes) # make it writable
byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
# process float8 types
if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
return self._convert_float8(byte_tensor, metadata["dtype"], shape)
# convert to the target dtype and reshape
return byte_tensor.view(dtype).reshape(shape)
@staticmethod
def _get_torch_dtype(dtype_str):
dtype_map = {
"F64": torch.float64,
"F32": torch.float32,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I64": torch.int64,
"I32": torch.int32,
"I16": torch.int16,
"I8": torch.int8,
"U8": torch.uint8,
"BOOL": torch.bool,
}
# add float8 types if available
if hasattr(torch, "float8_e5m2"):
dtype_map["F8_E5M2"] = torch.float8_e5m2
if hasattr(torch, "float8_e4m3fn"):
dtype_map["F8_E4M3"] = torch.float8_e4m3fn
return dtype_map.get(dtype_str)
@staticmethod
def _convert_float8(byte_tensor, dtype_str, shape):
if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
return byte_tensor.view(torch.float8_e5m2).reshape(shape)
elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
else:
# # convert to float16 if float8 is not supported
# print(f"Warning: {dtype_str} is not supported in this PyTorch version. Converting to float16.")
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
) -> dict[str, torch.Tensor]:
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
# use experimental loader
# logger.info(f"Loading without mmap (experimental)")
state_dict = {}
with MemoryEfficientSafeOpen(path) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
return state_dict
else:
try:
state_dict = load_file(path, device=device)
except:
state_dict = load_file(path) # prevent device invalid Error
if dtype is not None:
for key in state_dict.keys():
state_dict[key] = state_dict[key].to(dtype=dtype)
return state_dict
# endregion # endregion
# region Image utils # region Image utils
@@ -398,7 +213,14 @@ def pil_resize(image, size, interpolation):
return resized_cv2 return resized_cv2
def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None): def resize_image(
image: np.ndarray,
width: int,
height: int,
resized_width: int,
resized_height: int,
resize_interpolation: Optional[str] = None,
):
""" """
Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS. Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS.
@@ -472,6 +294,7 @@ def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
else: else:
return None return None
def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]: def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
""" """
Convert interpolation value to PIL interpolation Convert interpolation value to PIL interpolation
@@ -503,12 +326,14 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp
else: else:
return None return None
def validate_interpolation_fn(interpolation_str: str) -> bool: def validate_interpolation_fn(interpolation_str: str) -> bool:
""" """
Check if a interpolation function is supported Check if a interpolation function is supported
""" """
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"] return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]
# endregion # endregion
# TODO make inf_utils.py # TODO make inf_utils.py
@@ -642,7 +467,9 @@ class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
raise NotImplementedError("prediction_type not implemented yet: sample") raise NotImplementedError("prediction_type not implemented yet: sample")
else: else:
raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`") raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
)
sigma_from = self.sigmas[self.step_index] sigma_from = self.sigmas[self.step_index]
sigma_to = self.sigmas[self.step_index + 1] sigma_to = self.sigmas[self.step_index + 1]

View File

@@ -9,7 +9,8 @@ from safetensors import safe_open
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
from tqdm import tqdm from tqdm import tqdm
from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file from library.utils import setup_logging, str_to_dtype
from library.safetensors_utils import MemoryEfficientSafeOpen, mem_eff_save_file
setup_logging() setup_logging()
import logging import logging

View File

@@ -28,7 +28,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from library import sd3_models, sd3_utils, strategy_sd3 from library import sd3_models, sd3_utils, strategy_sd3
from library.utils import load_safetensors from library.safetensors_utils import load_safetensors
def get_noise(seed, latent, device="cpu"): def get_noise(seed, latent, device="cpu"):

View File

@@ -14,6 +14,7 @@ from tqdm import tqdm
import torch import torch
from library import utils from library import utils
from library.device_utils import init_ipex, clean_memory_on_device from library.device_utils import init_ipex, clean_memory_on_device
from library.safetensors_utils import load_safetensors
init_ipex() init_ipex()
@@ -206,7 +207,7 @@ def train(args):
# t5xxl_dtype = weight_dtype # t5xxl_dtype = weight_dtype
model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx) model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx)
if args.clip_l is None: if args.clip_l is None:
sd3_state_dict = utils.load_safetensors( sd3_state_dict = load_safetensors(
args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype
) )
else: else:
@@ -322,7 +323,7 @@ def train(args):
# load VAE for caching latents # load VAE for caching latents
if sd3_state_dict is None: if sd3_state_dict is None:
logger.info(f"load state dict for MMDiT and VAE from {args.pretrained_model_name_or_path}") logger.info(f"load state dict for MMDiT and VAE from {args.pretrained_model_name_or_path}")
sd3_state_dict = utils.load_safetensors( sd3_state_dict = load_safetensors(
args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype
) )

View File

@@ -8,6 +8,7 @@ import torch
from accelerate import Accelerator from accelerate import Accelerator
from library import sd3_models, strategy_sd3, utils from library import sd3_models, strategy_sd3, utils
from library.device_utils import init_ipex, clean_memory_on_device from library.device_utils import init_ipex, clean_memory_on_device
from library.safetensors_utils import load_safetensors
init_ipex() init_ipex()
@@ -77,7 +78,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
loading_dtype = None if args.fp8_base else weight_dtype loading_dtype = None if args.fp8_base else weight_dtype
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
state_dict = utils.load_safetensors( state_dict = load_safetensors(
args.pretrained_model_name_or_path, "cpu", disable_mmap=args.disable_mmap_load_safetensors, dtype=loading_dtype args.pretrained_model_name_or_path, "cpu", disable_mmap=args.disable_mmap_load_safetensors, dtype=loading_dtype
) )
mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu") mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu")

View File

@@ -4,7 +4,7 @@ import torch.nn as nn
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
from library.custom_offloading_utils import ( from library.custom_offloading_utils import (
synchronize_device, _synchronize_device,
swap_weight_devices_cuda, swap_weight_devices_cuda,
swap_weight_devices_no_cuda, swap_weight_devices_no_cuda,
weighs_to_device, weighs_to_device,
@@ -50,21 +50,21 @@ class SimpleModel(nn.Module):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_synchronize(mock_cuda_sync): def test_cuda_synchronize(mock_cuda_sync):
device = torch.device('cuda') device = torch.device('cuda')
synchronize_device(device) _synchronize_device(device)
mock_cuda_sync.assert_called_once() mock_cuda_sync.assert_called_once()
@patch('torch.xpu.synchronize') @patch('torch.xpu.synchronize')
@pytest.mark.skipif(not torch.xpu.is_available(), reason="XPU not available") @pytest.mark.skipif(not torch.xpu.is_available(), reason="XPU not available")
def test_xpu_synchronize(mock_xpu_sync): def test_xpu_synchronize(mock_xpu_sync):
device = torch.device('xpu') device = torch.device('xpu')
synchronize_device(device) _synchronize_device(device)
mock_xpu_sync.assert_called_once() mock_xpu_sync.assert_called_once()
@patch('torch.mps.synchronize') @patch('torch.mps.synchronize')
@pytest.mark.skipif(not torch.xpu.is_available(), reason="MPS not available") @pytest.mark.skipif(not torch.xpu.is_available(), reason="MPS not available")
def test_mps_synchronize(mock_mps_sync): def test_mps_synchronize(mock_mps_sync):
device = torch.device('mps') device = torch.device('mps')
synchronize_device(device) _synchronize_device(device)
mock_mps_sync.assert_called_once() mock_mps_sync.assert_called_once()
@@ -111,7 +111,7 @@ def test_swap_weight_devices_cuda():
@patch('library.custom_offloading_utils.synchronize_device') @patch('library.custom_offloading_utils._synchronize_device')
def test_swap_weight_devices_no_cuda(mock_sync_device): def test_swap_weight_devices_no_cuda(mock_sync_device):
device = torch.device('cpu') device = torch.device('cpu')
layer_to_cpu = SimpleModel() layer_to_cpu = SimpleModel()
@@ -121,7 +121,7 @@ def test_swap_weight_devices_no_cuda(mock_sync_device):
with patch('torch.Tensor.copy_'): with patch('torch.Tensor.copy_'):
swap_weight_devices_no_cuda(device, layer_to_cpu, layer_to_cuda) swap_weight_devices_no_cuda(device, layer_to_cpu, layer_to_cuda)
# Verify synchronize_device was called twice # Verify _synchronize_device was called twice
assert mock_sync_device.call_count == 2 assert mock_sync_device.call_count == 2
@@ -279,8 +279,8 @@ def test_backward_hook_execution(mock_wait, mock_submit):
@patch('library.custom_offloading_utils.weighs_to_device') @patch('library.custom_offloading_utils.weighs_to_device')
@patch('library.custom_offloading_utils.synchronize_device') @patch('library.custom_offloading_utils._synchronize_device')
@patch('library.custom_offloading_utils.clean_memory_on_device') @patch('library.custom_offloading_utils._clean_memory_on_device')
def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weights_to_device, model_offloader): def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weights_to_device, model_offloader):
model = SimpleModel(4) model = SimpleModel(4)
blocks = model.blocks blocks = model.blocks
@@ -291,7 +291,7 @@ def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weight
# Check that weighs_to_device was called for each block # Check that weighs_to_device was called for each block
assert mock_weights_to_device.call_count == 4 assert mock_weights_to_device.call_count == 4
# Check that synchronize_device and clean_memory_on_device were called # Check that _synchronize_device and _clean_memory_on_device were called
mock_sync.assert_called_once_with(model_offloader.device) mock_sync.assert_called_once_with(model_offloader.device)
mock_clean.assert_called_once_with(model_offloader.device) mock_clean.assert_called_once_with(model_offloader.device)

View File

@@ -30,7 +30,8 @@ import torch
from tqdm import tqdm from tqdm import tqdm
from library import flux_utils from library import flux_utils
from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file from library.utils import setup_logging, str_to_dtype
from library.safetensors_utils import MemoryEfficientSafeOpen, mem_eff_save_file
setup_logging() setup_logging()
import logging import logging

View File

@@ -6,7 +6,8 @@ import torch
from safetensors.torch import safe_open from safetensors.torch import safe_open
from library.utils import setup_logging from library.utils import setup_logging
from library.utils import load_safetensors, mem_eff_save_file, str_to_dtype from library.utils import str_to_dtype
from library.safetensors_utils import load_safetensors, mem_eff_save_file
setup_logging() setup_logging()
import logging import logging