From 8783f8aed395e82678e0f7a48b0415b95e819484 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sat, 13 Sep 2025 19:51:38 +0900 Subject: [PATCH 1/3] feat: faster safetensors load and split safetensor utils --- flux_minimal_inference.py | 8 +- library/custom_offloading_utils.py | 37 ++- library/device_utils.py | 22 +- library/flux_train_utils.py | 3 +- library/flux_utils.py | 4 +- library/lumina_train_util.py | 3 +- library/lumina_util.py | 2 +- library/safetensors_utils.py | 352 ++++++++++++++++++++++++++ library/sd3_utils.py | 4 +- library/utils.py | 221 ++-------------- networks/flux_merge_lora.py | 3 +- sd3_minimal_inference.py | 2 +- sd3_train.py | 5 +- sd3_train_network.py | 3 +- tests/test_custom_offloading_utils.py | 18 +- tools/convert_diffusers_to_flux.py | 3 +- tools/merge_sd3_safetensors.py | 3 +- 17 files changed, 459 insertions(+), 234 deletions(-) create mode 100644 library/safetensors_utils.py diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index d5f2d8d9..0664b3c7 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -456,13 +456,13 @@ if __name__ == "__main__": # load clip_l (skip for chroma model) if args.model_type == "flux": logger.info(f"Loading clip_l from {args.clip_l}...") - clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device) + clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device, disable_mmap=True) clip_l.eval() else: clip_l = None logger.info(f"Loading t5xxl from {args.t5xxl}...") - t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device, disable_mmap=True) t5xxl.eval() # if is_fp8(clip_l_dtype): @@ -471,7 +471,9 @@ if __name__ == "__main__": # t5xxl = accelerator.prepare(t5xxl) # DiT - is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type) + is_schnell, model = flux_utils.load_flow_model( + args.ckpt_path, None, loading_device, disable_mmap=True, model_type=args.model_type + ) model.eval() logger.info(f"Casting model to {flux_dtype}") model.to(flux_dtype) # make sure model is dtype diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py index 55ff08b6..fce3747e 100644 --- a/library/custom_offloading_utils.py +++ b/library/custom_offloading_utils.py @@ -1,13 +1,28 @@ from concurrent.futures import ThreadPoolExecutor +import gc import time from typing import Optional, Union, Callable, Tuple import torch import torch.nn as nn -from library.device_utils import clean_memory_on_device + +# Keep these functions here for portability, and private to avoid confusion with the ones in device_utils.py +def _clean_memory_on_device(device: torch.device): + r""" + Clean memory on the specified device, will be called from training scripts. + """ + gc.collect() + + # device may "cuda" or "cuda:0", so we need to check the type of device + if device.type == "cuda": + torch.cuda.empty_cache() + if device.type == "xpu": + torch.xpu.empty_cache() + if device.type == "mps": + torch.mps.empty_cache() -def synchronize_device(device: torch.device): +def _synchronize_device(device: torch.device): if device.type == "cuda": torch.cuda.synchronize() elif device.type == "xpu": @@ -71,19 +86,18 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) - # device to cpu for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) - synchronize_device(device) + _synchronize_device(device) # cpu to device for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) module_to_cuda.weight.data = cuda_data_view - synchronize_device(device) + _synchronize_device(device) def weighs_to_device(layer: nn.Module, device: torch.device): @@ -152,12 +166,15 @@ class Offloader: # Gradient tensors _grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor] + class ModelOffloader(Offloader): """ supports forward offloading """ - def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False): + def __init__( + self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False + ): super().__init__(len(blocks), blocks_to_swap, device, debug) # register backward hooks @@ -172,7 +189,9 @@ class ModelOffloader(Offloader): for handle in self.remove_handles: handle.remove() - def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]: + def create_backward_hook( + self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int + ) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]: # -1 for 0-based index num_blocks_propagated = self.num_blocks - block_index - 1 swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap @@ -213,8 +232,8 @@ class ModelOffloader(Offloader): b.to(self.device) # move block to device first weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu - synchronize_device(self.device) - clean_memory_on_device(self.device) + _synchronize_device(self.device) + _clean_memory_on_device(self.device) def wait_for_block(self, block_idx: int): if self.blocks_to_swap is None or self.blocks_to_swap == 0: diff --git a/library/device_utils.py b/library/device_utils.py index d2e19745..9d7757ed 100644 --- a/library/device_utils.py +++ b/library/device_utils.py @@ -1,7 +1,9 @@ import functools import gc +from typing import Optional, Union import torch + try: # intel gpu support for pytorch older than 2.5 # ipex is not needed after pytorch 2.5 @@ -36,12 +38,15 @@ def clean_memory(): torch.mps.empty_cache() -def clean_memory_on_device(device: torch.device): +def clean_memory_on_device(device: Optional[Union[str, torch.device]]): r""" Clean memory on the specified device, will be called from training scripts. """ gc.collect() - + if device is None: + return + if isinstance(device, str): + device = torch.device(device) # device may "cuda" or "cuda:0", so we need to check the type of device if device.type == "cuda": torch.cuda.empty_cache() @@ -51,6 +56,19 @@ def clean_memory_on_device(device: torch.device): torch.mps.empty_cache() +def synchronize_device(device: Optional[Union[str, torch.device]]): + if device is None: + return + if isinstance(device, str): + device = torch.device(device) + if device.type == "cuda": + torch.cuda.synchronize() + elif device.type == "xpu": + torch.xpu.synchronize() + elif device.type == "mps": + torch.mps.synchronize() + + @functools.lru_cache(maxsize=None) def get_preferred_device() -> torch.device: r""" diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f3eb8199..06fe0b95 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -16,10 +16,11 @@ from safetensors.torch import save_file from library import flux_models, flux_utils, strategy_base, train_util from library.device_utils import init_ipex, clean_memory_on_device +from library.safetensors_utils import mem_eff_save_file init_ipex() -from .utils import setup_logging, mem_eff_save_file +from .utils import setup_logging setup_logging() import logging diff --git a/library/flux_utils.py b/library/flux_utils.py index 22054854..410b34ce 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -18,7 +18,7 @@ import logging logger = logging.getLogger(__name__) from library import flux_models -from library.utils import load_safetensors +from library.safetensors_utils import load_safetensors MODEL_VERSION_FLUX_V1 = "flux1" MODEL_NAME_DEV = "dev" @@ -124,7 +124,7 @@ def load_flow_model( logger.info(f"Loading state dict from {ckpt_path}") sd = {} for ckpt_path in ckpt_paths: - sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)) + sd.update(load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)) # convert Diffusers to BFL if is_diffusers: diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 0645a8ae..d5d5db05 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -18,10 +18,11 @@ from library import lumina_models, strategy_base, strategy_lumina, train_util from library.flux_models import AutoEncoder from library.device_utils import init_ipex, clean_memory_on_device from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler +from library.safetensors_utils import mem_eff_save_file init_ipex() -from .utils import setup_logging, mem_eff_save_file +from .utils import setup_logging setup_logging() import logging diff --git a/library/lumina_util.py b/library/lumina_util.py index 87853ef6..f7f3c823 100644 --- a/library/lumina_util.py +++ b/library/lumina_util.py @@ -12,7 +12,7 @@ from transformers import Gemma2Config, Gemma2Model from library.utils import setup_logging from library import lumina_models, flux_models -from library.utils import load_safetensors +from library.safetensors_utils import load_safetensors import logging setup_logging() diff --git a/library/safetensors_utils.py b/library/safetensors_utils.py new file mode 100644 index 00000000..dcd2309e --- /dev/null +++ b/library/safetensors_utils.py @@ -0,0 +1,352 @@ +import os +import re +import numpy as np +import torch +import json +import struct +from typing import Dict, Any, Union, Optional + +from safetensors.torch import load_file + +from library.device_utils import synchronize_device + + +def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None): + """ + memory efficient save file + """ + + _TYPES = { + torch.float64: "F64", + torch.float32: "F32", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int64: "I64", + torch.int32: "I32", + torch.int16: "I16", + torch.int8: "I8", + torch.uint8: "U8", + torch.bool: "BOOL", + getattr(torch, "float8_e5m2", None): "F8_E5M2", + getattr(torch, "float8_e4m3fn", None): "F8_E4M3", + } + _ALIGN = 256 + + def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]: + validated = {} + for key, value in metadata.items(): + if not isinstance(key, str): + raise ValueError(f"Metadata key must be a string, got {type(key)}") + if not isinstance(value, str): + print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.") + validated[key] = str(value) + else: + validated[key] = value + return validated + + # print(f"Using memory efficient save file: {filename}") + + header = {} + offset = 0 + if metadata: + header["__metadata__"] = validate_metadata(metadata) + for k, v in tensors.items(): + if v.numel() == 0: # empty tensor + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]} + else: + size = v.numel() * v.element_size() + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]} + offset += size + + hjson = json.dumps(header).encode("utf-8") + hjson += b" " * (-(len(hjson) + 8) % _ALIGN) + + with open(filename, "wb") as f: + f.write(struct.pack(" Dict[str, str]: + """Get metadata from the file. + + Returns: + Dict[str, str]: Metadata dictionary. + """ + return self.header.get("__metadata__", {}) + + def _read_header(self): + """Read and parse the header from the safetensors file. + + Returns: + tuple: (header_dict, header_size) containing parsed header and its size. + """ + # Read header size (8 bytes, little-endian unsigned long long) + header_size = struct.unpack("10MB) and the target device is CUDA, memory mapping with numpy.memmap is used to avoid intermediate copies. + + Args: + key (str): Name of the tensor to load. + device (Optional[torch.device]): Target device for the tensor. + dtype (Optional[torch.dtype]): Target dtype for the tensor. + + Returns: + torch.Tensor: The loaded tensor. + + Raises: + KeyError: If the tensor key is not found in the file. + """ + if key not in self.header: + raise KeyError(f"Tensor '{key}' not found in the file") + + metadata = self.header[key] + offset_start, offset_end = metadata["data_offsets"] + num_bytes = offset_end - offset_start + + original_dtype = self._get_torch_dtype(metadata["dtype"]) + target_dtype = dtype if dtype is not None else original_dtype + + # Handle empty tensors + if num_bytes == 0: + return torch.empty(metadata["shape"], dtype=target_dtype, device=device) + + # Determine if we should use pinned memory for GPU transfer + non_blocking = device is not None and device.type == "cuda" + + # Calculate absolute file offset + tensor_offset = self.header_size + 8 + offset_start # adjust offset by header size + + # Memory mapping strategy for large tensors to GPU + # Use memmap for large tensors to avoid intermediate copies. + # If device is cpu, tensor is not copied to gpu, so using memmap locks the file, which is not desired. + # So we only use memmap if device is not cpu. + if num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu": + # Create memory map for zero-copy reading + mm = np.memmap(self.filename, mode="c", dtype=np.uint8, offset=tensor_offset, shape=(num_bytes,)) + byte_tensor = torch.from_numpy(mm) # zero copy + del mm + + # Deserialize tensor (view and reshape) + cpu_tensor = self._deserialize_tensor(byte_tensor, metadata) # view and reshape + del byte_tensor + + # Transfer to target device and dtype + gpu_tensor = cpu_tensor.to(device=device, dtype=target_dtype, non_blocking=non_blocking) + del cpu_tensor + return gpu_tensor + + # Standard file reading strategy for smaller tensors or CPU target + # seek to the specified position + self.file.seek(tensor_offset) + + # read directly into a numpy array by numpy.fromfile without intermediate copy + numpy_array = np.fromfile(self.file, dtype=np.uint8, count=num_bytes) + byte_tensor = torch.from_numpy(numpy_array) + del numpy_array + + # deserialize (view and reshape) + deserialized_tensor = self._deserialize_tensor(byte_tensor, metadata) + del byte_tensor + + # cast to target dtype and move to device + return deserialized_tensor.to(device=device, dtype=target_dtype, non_blocking=non_blocking) + + def _deserialize_tensor(self, byte_tensor: torch.Tensor, metadata: Dict): + """Deserialize byte tensor to the correct shape and dtype. + + Args: + byte_tensor (torch.Tensor): Raw byte tensor from file. + metadata (Dict): Tensor metadata containing dtype and shape info. + + Returns: + torch.Tensor: Deserialized tensor with correct shape and dtype. + """ + dtype = self._get_torch_dtype(metadata["dtype"]) + shape = metadata["shape"] + + # Handle special float8 types + if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]: + return self._convert_float8(byte_tensor, metadata["dtype"], shape) + + # Standard conversion: view as target dtype and reshape + return byte_tensor.view(dtype).reshape(shape) + + @staticmethod + def _get_torch_dtype(dtype_str): + """Convert string dtype to PyTorch dtype. + + Args: + dtype_str (str): String representation of the dtype. + + Returns: + torch.dtype: Corresponding PyTorch dtype. + """ + # Standard dtype mappings + dtype_map = { + "F64": torch.float64, + "F32": torch.float32, + "F16": torch.float16, + "BF16": torch.bfloat16, + "I64": torch.int64, + "I32": torch.int32, + "I16": torch.int16, + "I8": torch.int8, + "U8": torch.uint8, + "BOOL": torch.bool, + } + # Add float8 types if available in PyTorch version + if hasattr(torch, "float8_e5m2"): + dtype_map["F8_E5M2"] = torch.float8_e5m2 + if hasattr(torch, "float8_e4m3fn"): + dtype_map["F8_E4M3"] = torch.float8_e4m3fn + return dtype_map.get(dtype_str) + + @staticmethod + def _convert_float8(byte_tensor, dtype_str, shape): + """Convert byte tensor to float8 format if supported. + + Args: + byte_tensor (torch.Tensor): Raw byte tensor. + dtype_str (str): Float8 dtype string ("F8_E5M2" or "F8_E4M3"). + shape (tuple): Target tensor shape. + + Returns: + torch.Tensor: Tensor with float8 dtype. + + Raises: + ValueError: If float8 type is not supported in current PyTorch version. + """ + # Convert to specific float8 types if available + if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"): + return byte_tensor.view(torch.float8_e5m2).reshape(shape) + elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"): + return byte_tensor.view(torch.float8_e4m3fn).reshape(shape) + else: + # Float8 not supported in this PyTorch version + raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)") + + +def load_safetensors( + path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = None +) -> dict[str, torch.Tensor]: + if disable_mmap: + # return safetensors.torch.load(open(path, "rb").read()) + # use experimental loader + # logger.info(f"Loading without mmap (experimental)") + state_dict = {} + device = torch.device(device) if device is not None else None + with MemoryEfficientSafeOpen(path) as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key, device=device, dtype=dtype) + synchronize_device(device) + return state_dict + else: + try: + state_dict = load_file(path, device=device) + except: + state_dict = load_file(path) # prevent device invalid Error + if dtype is not None: + for key in state_dict.keys(): + state_dict[key] = state_dict[key].to(dtype=dtype) + return state_dict + + +def load_split_weights( + file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False, dtype: Optional[torch.dtype] = None +) -> Dict[str, torch.Tensor]: + """ + Load split weights from a file. If the file name ends with 00001-of-00004 etc, it will load all files with the same prefix. + dtype is as is, no conversion is done. + """ + device = torch.device(device) + + # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix + basename = os.path.basename(file_path) + match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename) + if match: + prefix = basename[: match.start(2)] + count = int(match.group(3)) + state_dict = {} + for i in range(count): + filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors" + filepath = os.path.join(os.path.dirname(file_path), filename) + if os.path.exists(filepath): + state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap, dtype=dtype)) + else: + raise FileNotFoundError(f"File {filepath} not found") + else: + state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap, dtype=dtype) + return state_dict + + +def find_key(safetensors_file: str, starts_with: Optional[str] = None, ends_with: Optional[str] = None) -> Optional[str]: + """ + Find a key in a safetensors file that starts with `starts_with` and ends with `ends_with`. + If `starts_with` is None, it will match any key. + If `ends_with` is None, it will match any key. + Returns the first matching key or None if no key matches. + """ + with MemoryEfficientSafeOpen(safetensors_file) as f: + for key in f.keys(): + if (starts_with is None or key.startswith(starts_with)) and (ends_with is None or key.endswith(ends_with)): + return key + return None diff --git a/library/sd3_utils.py b/library/sd3_utils.py index d2ea6fff..5fbaa4c3 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -23,7 +23,7 @@ from library import sdxl_model_util # region models # TODO remove dependency on flux_utils -from library.utils import load_safetensors +from library.safetensors_utils import load_safetensors from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl @@ -246,7 +246,7 @@ def load_vae( vae_sd = {} if vae_path: logger.info(f"Loading VAE from {vae_path}...") - vae_sd = load_safetensors(vae_path, device, disable_mmap) + vae_sd = load_safetensors(vae_path, device, disable_mmap, dtype=vae_dtype) else: # remove prefix "first_stage_model." vae_sd = {} diff --git a/library/utils.py b/library/utils.py index d0586b84..296fc415 100644 --- a/library/utils.py +++ b/library/utils.py @@ -2,8 +2,6 @@ import logging import sys import threading from typing import * -import json -import struct import torch import torch.nn as nn @@ -14,7 +12,7 @@ from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncest import cv2 from PIL import Image import numpy as np -from safetensors.torch import load_file + def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() @@ -88,6 +86,7 @@ def setup_logging(args=None, log_level=None, reset=False): logger = logging.getLogger(__name__) logger.info(msg_init) + setup_logging() logger = logging.getLogger(__name__) @@ -190,190 +189,6 @@ def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) raise ValueError(f"Unsupported dtype: {s}") -def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None): - """ - memory efficient save file - """ - - _TYPES = { - torch.float64: "F64", - torch.float32: "F32", - torch.float16: "F16", - torch.bfloat16: "BF16", - torch.int64: "I64", - torch.int32: "I32", - torch.int16: "I16", - torch.int8: "I8", - torch.uint8: "U8", - torch.bool: "BOOL", - getattr(torch, "float8_e5m2", None): "F8_E5M2", - getattr(torch, "float8_e4m3fn", None): "F8_E4M3", - } - _ALIGN = 256 - - def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]: - validated = {} - for key, value in metadata.items(): - if not isinstance(key, str): - raise ValueError(f"Metadata key must be a string, got {type(key)}") - if not isinstance(value, str): - print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.") - validated[key] = str(value) - else: - validated[key] = value - return validated - - print(f"Using memory efficient save file: {filename}") - - header = {} - offset = 0 - if metadata: - header["__metadata__"] = validate_metadata(metadata) - for k, v in tensors.items(): - if v.numel() == 0: # empty tensor - header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]} - else: - size = v.numel() * v.element_size() - header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]} - offset += size - - hjson = json.dumps(header).encode("utf-8") - hjson += b" " * (-(len(hjson) + 8) % _ALIGN) - - with open(filename, "wb") as f: - f.write(struct.pack(" Dict[str, str]: - return self.header.get("__metadata__", {}) - - def get_tensor(self, key): - if key not in self.header: - raise KeyError(f"Tensor '{key}' not found in the file") - - metadata = self.header[key] - offset_start, offset_end = metadata["data_offsets"] - - if offset_start == offset_end: - tensor_bytes = None - else: - # adjust offset by header size - self.file.seek(self.header_size + 8 + offset_start) - tensor_bytes = self.file.read(offset_end - offset_start) - - return self._deserialize_tensor(tensor_bytes, metadata) - - def _read_header(self): - header_size = struct.unpack(" dict[str, torch.Tensor]: - if disable_mmap: - # return safetensors.torch.load(open(path, "rb").read()) - # use experimental loader - # logger.info(f"Loading without mmap (experimental)") - state_dict = {} - with MemoryEfficientSafeOpen(path) as f: - for key in f.keys(): - state_dict[key] = f.get_tensor(key).to(device, dtype=dtype) - return state_dict - else: - try: - state_dict = load_file(path, device=device) - except: - state_dict = load_file(path) # prevent device invalid Error - if dtype is not None: - for key in state_dict.keys(): - state_dict[key] = state_dict[key].to(dtype=dtype) - return state_dict - - # endregion # region Image utils @@ -398,7 +213,14 @@ def pil_resize(image, size, interpolation): return resized_cv2 -def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None): +def resize_image( + image: np.ndarray, + width: int, + height: int, + resized_width: int, + resized_height: int, + resize_interpolation: Optional[str] = None, +): """ Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS. @@ -449,29 +271,30 @@ def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]: https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121 """ if interpolation is None: - return None + return None if interpolation == "lanczos" or interpolation == "lanczos4": - # Lanczos interpolation over 8x8 neighborhood + # Lanczos interpolation over 8x8 neighborhood return cv2.INTER_LANCZOS4 elif interpolation == "nearest": - # Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab. + # Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab. return cv2.INTER_NEAREST_EXACT elif interpolation == "bilinear" or interpolation == "linear": # bilinear interpolation return cv2.INTER_LINEAR elif interpolation == "bicubic" or interpolation == "cubic": - # bicubic interpolation + # bicubic interpolation return cv2.INTER_CUBIC elif interpolation == "area": - # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. + # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. return cv2.INTER_AREA elif interpolation == "box": - # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. + # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. return cv2.INTER_AREA else: return None + def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]: """ Convert interpolation value to PIL interpolation @@ -479,7 +302,7 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters """ if interpolation is None: - return None + return None if interpolation == "lanczos": return Image.Resampling.LANCZOS @@ -493,7 +316,7 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp # For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used. return Image.Resampling.BICUBIC elif interpolation == "area": - # Image.Resampling.BOX may be more appropriate if upscaling + # Image.Resampling.BOX may be more appropriate if upscaling # Area interpolation is related to cv2.INTER_AREA # Produces a sharper image than Resampling.BILINEAR, doesn’t have dislocations on local level like with Resampling.BOX. return Image.Resampling.HAMMING @@ -503,12 +326,14 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp else: return None + def validate_interpolation_fn(interpolation_str: str) -> bool: """ Check if a interpolation function is supported """ return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"] + # endregion # TODO make inf_utils.py @@ -642,7 +467,9 @@ class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler): elif self.config.prediction_type == "sample": raise NotImplementedError("prediction_type not implemented yet: sample") else: - raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`") + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) sigma_from = self.sigmas[self.step_index] sigma_to = self.sigmas[self.step_index + 1] diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index 5e100a3b..855c0ed9 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -9,7 +9,8 @@ from safetensors import safe_open from safetensors.torch import load_file, save_file from tqdm import tqdm -from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file +from library.utils import setup_logging, str_to_dtype +from library.safetensors_utils import MemoryEfficientSafeOpen, mem_eff_save_file setup_logging() import logging diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index 86dba246..d7b97a59 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -28,7 +28,7 @@ import logging logger = logging.getLogger(__name__) from library import sd3_models, sd3_utils, strategy_sd3 -from library.utils import load_safetensors +from library.safetensors_utils import load_safetensors def get_noise(seed, latent, device="cpu"): diff --git a/sd3_train.py b/sd3_train.py index 355e13dd..c6a2fdd8 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -14,6 +14,7 @@ from tqdm import tqdm import torch from library import utils from library.device_utils import init_ipex, clean_memory_on_device +from library.safetensors_utils import load_safetensors init_ipex() @@ -206,7 +207,7 @@ def train(args): # t5xxl_dtype = weight_dtype model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx) if args.clip_l is None: - sd3_state_dict = utils.load_safetensors( + sd3_state_dict = load_safetensors( args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype ) else: @@ -322,7 +323,7 @@ def train(args): # load VAE for caching latents if sd3_state_dict is None: logger.info(f"load state dict for MMDiT and VAE from {args.pretrained_model_name_or_path}") - sd3_state_dict = utils.load_safetensors( + sd3_state_dict = load_safetensors( args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype ) diff --git a/sd3_train_network.py b/sd3_train_network.py index cdb7aa4e..c9b06a38 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -8,6 +8,7 @@ import torch from accelerate import Accelerator from library import sd3_models, strategy_sd3, utils from library.device_utils import init_ipex, clean_memory_on_device +from library.safetensors_utils import load_safetensors init_ipex() @@ -77,7 +78,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): loading_dtype = None if args.fp8_base else weight_dtype # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future - state_dict = utils.load_safetensors( + state_dict = load_safetensors( args.pretrained_model_name_or_path, "cpu", disable_mmap=args.disable_mmap_load_safetensors, dtype=loading_dtype ) mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu") diff --git a/tests/test_custom_offloading_utils.py b/tests/test_custom_offloading_utils.py index 5fa40b76..8c23bdf5 100644 --- a/tests/test_custom_offloading_utils.py +++ b/tests/test_custom_offloading_utils.py @@ -4,7 +4,7 @@ import torch.nn as nn from unittest.mock import patch, MagicMock from library.custom_offloading_utils import ( - synchronize_device, + _synchronize_device, swap_weight_devices_cuda, swap_weight_devices_no_cuda, weighs_to_device, @@ -50,21 +50,21 @@ class SimpleModel(nn.Module): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_cuda_synchronize(mock_cuda_sync): device = torch.device('cuda') - synchronize_device(device) + _synchronize_device(device) mock_cuda_sync.assert_called_once() @patch('torch.xpu.synchronize') @pytest.mark.skipif(not torch.xpu.is_available(), reason="XPU not available") def test_xpu_synchronize(mock_xpu_sync): device = torch.device('xpu') - synchronize_device(device) + _synchronize_device(device) mock_xpu_sync.assert_called_once() @patch('torch.mps.synchronize') @pytest.mark.skipif(not torch.xpu.is_available(), reason="MPS not available") def test_mps_synchronize(mock_mps_sync): device = torch.device('mps') - synchronize_device(device) + _synchronize_device(device) mock_mps_sync.assert_called_once() @@ -111,7 +111,7 @@ def test_swap_weight_devices_cuda(): -@patch('library.custom_offloading_utils.synchronize_device') +@patch('library.custom_offloading_utils._synchronize_device') def test_swap_weight_devices_no_cuda(mock_sync_device): device = torch.device('cpu') layer_to_cpu = SimpleModel() @@ -121,7 +121,7 @@ def test_swap_weight_devices_no_cuda(mock_sync_device): with patch('torch.Tensor.copy_'): swap_weight_devices_no_cuda(device, layer_to_cpu, layer_to_cuda) - # Verify synchronize_device was called twice + # Verify _synchronize_device was called twice assert mock_sync_device.call_count == 2 @@ -279,8 +279,8 @@ def test_backward_hook_execution(mock_wait, mock_submit): @patch('library.custom_offloading_utils.weighs_to_device') -@patch('library.custom_offloading_utils.synchronize_device') -@patch('library.custom_offloading_utils.clean_memory_on_device') +@patch('library.custom_offloading_utils._synchronize_device') +@patch('library.custom_offloading_utils._clean_memory_on_device') def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weights_to_device, model_offloader): model = SimpleModel(4) blocks = model.blocks @@ -291,7 +291,7 @@ def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weight # Check that weighs_to_device was called for each block assert mock_weights_to_device.call_count == 4 - # Check that synchronize_device and clean_memory_on_device were called + # Check that _synchronize_device and _clean_memory_on_device were called mock_sync.assert_called_once_with(model_offloader.device) mock_clean.assert_called_once_with(model_offloader.device) diff --git a/tools/convert_diffusers_to_flux.py b/tools/convert_diffusers_to_flux.py index 65ba7321..a11093c9 100644 --- a/tools/convert_diffusers_to_flux.py +++ b/tools/convert_diffusers_to_flux.py @@ -30,7 +30,8 @@ import torch from tqdm import tqdm from library import flux_utils -from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file +from library.utils import setup_logging, str_to_dtype +from library.safetensors_utils import MemoryEfficientSafeOpen, mem_eff_save_file setup_logging() import logging diff --git a/tools/merge_sd3_safetensors.py b/tools/merge_sd3_safetensors.py index 6bc1003e..6ec045dd 100644 --- a/tools/merge_sd3_safetensors.py +++ b/tools/merge_sd3_safetensors.py @@ -6,7 +6,8 @@ import torch from safetensors.torch import safe_open from library.utils import setup_logging -from library.utils import load_safetensors, mem_eff_save_file, str_to_dtype +from library.utils import str_to_dtype +from library.safetensors_utils import load_safetensors, mem_eff_save_file setup_logging() import logging From e1c666e97f99f50e381ab88b8878392ca26870bb Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Sat, 13 Sep 2025 20:03:55 +0900 Subject: [PATCH 2/3] Update library/safetensors_utils.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- library/safetensors_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/library/safetensors_utils.py b/library/safetensors_utils.py index dcd2309e..c65cdfab 100644 --- a/library/safetensors_utils.py +++ b/library/safetensors_utils.py @@ -44,7 +44,6 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: validated[key] = value return validated - # print(f"Using memory efficient save file: {filename}") header = {} offset = 0 From 4568631b43f348ea4360b021315a3da8064f3d7b Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sat, 13 Sep 2025 20:05:39 +0900 Subject: [PATCH 3/3] docs: update README to reflect improved loading speed of .safetensors files --- README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 843cf71b..da38a241 100644 --- a/README.md +++ b/README.md @@ -13,11 +13,13 @@ For RTX 50 series GPUs, PyTorch 2.8.0 with CUDA 12.8/9 should be used. `requirem If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed` (appropriate version is not confirmed yet). -- [FLUX.1 training](#flux1-training) -- [SD3 training](#sd3-training) - ### Recent Updates +Sep 13, 2025: +- The loading speed of `.safetensors` files has been improved for SD3, FLUX.1 and Lumina. See [PR #2200](https://github.com/kohya-ss/sd-scripts/pull/2200) for more details. + - Model loading can be up to 1.5 times faster. + - This is a wide-ranging update, so there may be bugs. Please let us know if you encounter any issues. + Sep 4, 2025: - The information about FLUX.1 and SD3/SD3.5 training that was described in the README has been organized and divided into the following documents: - [LoRA Training Overview](./docs/train_network.md)