mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
feat: faster safetensors load and split safetensor utils
This commit is contained in:
@@ -456,13 +456,13 @@ if __name__ == "__main__":
|
||||
# load clip_l (skip for chroma model)
|
||||
if args.model_type == "flux":
|
||||
logger.info(f"Loading clip_l from {args.clip_l}...")
|
||||
clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
|
||||
clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device, disable_mmap=True)
|
||||
clip_l.eval()
|
||||
else:
|
||||
clip_l = None
|
||||
|
||||
logger.info(f"Loading t5xxl from {args.t5xxl}...")
|
||||
t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
|
||||
t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device, disable_mmap=True)
|
||||
t5xxl.eval()
|
||||
|
||||
# if is_fp8(clip_l_dtype):
|
||||
@@ -471,7 +471,9 @@ if __name__ == "__main__":
|
||||
# t5xxl = accelerator.prepare(t5xxl)
|
||||
|
||||
# DiT
|
||||
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type)
|
||||
is_schnell, model = flux_utils.load_flow_model(
|
||||
args.ckpt_path, None, loading_device, disable_mmap=True, model_type=args.model_type
|
||||
)
|
||||
model.eval()
|
||||
logger.info(f"Casting model to {flux_dtype}")
|
||||
model.to(flux_dtype) # make sure model is dtype
|
||||
|
||||
@@ -1,13 +1,28 @@
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import gc
|
||||
import time
|
||||
from typing import Optional, Union, Callable, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from library.device_utils import clean_memory_on_device
|
||||
|
||||
# Keep these functions here for portability, and private to avoid confusion with the ones in device_utils.py
|
||||
def _clean_memory_on_device(device: torch.device):
|
||||
r"""
|
||||
Clean memory on the specified device, will be called from training scripts.
|
||||
"""
|
||||
gc.collect()
|
||||
|
||||
# device may "cuda" or "cuda:0", so we need to check the type of device
|
||||
if device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
if device.type == "xpu":
|
||||
torch.xpu.empty_cache()
|
||||
if device.type == "mps":
|
||||
torch.mps.empty_cache()
|
||||
|
||||
|
||||
def synchronize_device(device: torch.device):
|
||||
def _synchronize_device(device: torch.device):
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
elif device.type == "xpu":
|
||||
@@ -71,19 +86,18 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
|
||||
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
||||
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
||||
|
||||
|
||||
# device to cpu
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||||
|
||||
synchronize_device(device)
|
||||
_synchronize_device(device)
|
||||
|
||||
# cpu to device
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||
module_to_cuda.weight.data = cuda_data_view
|
||||
|
||||
synchronize_device(device)
|
||||
_synchronize_device(device)
|
||||
|
||||
|
||||
def weighs_to_device(layer: nn.Module, device: torch.device):
|
||||
@@ -152,12 +166,15 @@ class Offloader:
|
||||
# Gradient tensors
|
||||
_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]
|
||||
|
||||
|
||||
class ModelOffloader(Offloader):
|
||||
"""
|
||||
supports forward offloading
|
||||
"""
|
||||
|
||||
def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False):
|
||||
def __init__(
|
||||
self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False
|
||||
):
|
||||
super().__init__(len(blocks), blocks_to_swap, device, debug)
|
||||
|
||||
# register backward hooks
|
||||
@@ -172,7 +189,9 @@ class ModelOffloader(Offloader):
|
||||
for handle in self.remove_handles:
|
||||
handle.remove()
|
||||
|
||||
def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
|
||||
def create_backward_hook(
|
||||
self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int
|
||||
) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
|
||||
# -1 for 0-based index
|
||||
num_blocks_propagated = self.num_blocks - block_index - 1
|
||||
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
|
||||
@@ -213,8 +232,8 @@ class ModelOffloader(Offloader):
|
||||
b.to(self.device) # move block to device first
|
||||
weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu
|
||||
|
||||
synchronize_device(self.device)
|
||||
clean_memory_on_device(self.device)
|
||||
_synchronize_device(self.device)
|
||||
_clean_memory_on_device(self.device)
|
||||
|
||||
def wait_for_block(self, block_idx: int):
|
||||
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import functools
|
||||
import gc
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
# intel gpu support for pytorch older than 2.5
|
||||
# ipex is not needed after pytorch 2.5
|
||||
@@ -36,12 +38,15 @@ def clean_memory():
|
||||
torch.mps.empty_cache()
|
||||
|
||||
|
||||
def clean_memory_on_device(device: torch.device):
|
||||
def clean_memory_on_device(device: Optional[Union[str, torch.device]]):
|
||||
r"""
|
||||
Clean memory on the specified device, will be called from training scripts.
|
||||
"""
|
||||
gc.collect()
|
||||
|
||||
if device is None:
|
||||
return
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# device may "cuda" or "cuda:0", so we need to check the type of device
|
||||
if device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
@@ -51,6 +56,19 @@ def clean_memory_on_device(device: torch.device):
|
||||
torch.mps.empty_cache()
|
||||
|
||||
|
||||
def synchronize_device(device: Optional[Union[str, torch.device]]):
|
||||
if device is None:
|
||||
return
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
elif device.type == "xpu":
|
||||
torch.xpu.synchronize()
|
||||
elif device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_preferred_device() -> torch.device:
|
||||
r"""
|
||||
|
||||
@@ -16,10 +16,11 @@ from safetensors.torch import save_file
|
||||
|
||||
from library import flux_models, flux_utils, strategy_base, train_util
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library.safetensors_utils import mem_eff_save_file
|
||||
|
||||
init_ipex()
|
||||
|
||||
from .utils import setup_logging, mem_eff_save_file
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
@@ -18,7 +18,7 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import flux_models
|
||||
from library.utils import load_safetensors
|
||||
from library.safetensors_utils import load_safetensors
|
||||
|
||||
MODEL_VERSION_FLUX_V1 = "flux1"
|
||||
MODEL_NAME_DEV = "dev"
|
||||
@@ -124,7 +124,7 @@ def load_flow_model(
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = {}
|
||||
for ckpt_path in ckpt_paths:
|
||||
sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))
|
||||
sd.update(load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype))
|
||||
|
||||
# convert Diffusers to BFL
|
||||
if is_diffusers:
|
||||
|
||||
@@ -18,10 +18,11 @@ from library import lumina_models, strategy_base, strategy_lumina, train_util
|
||||
from library.flux_models import AutoEncoder
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
from library.safetensors_utils import mem_eff_save_file
|
||||
|
||||
init_ipex()
|
||||
|
||||
from .utils import setup_logging, mem_eff_save_file
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
@@ -12,7 +12,7 @@ from transformers import Gemma2Config, Gemma2Model
|
||||
|
||||
from library.utils import setup_logging
|
||||
from library import lumina_models, flux_models
|
||||
from library.utils import load_safetensors
|
||||
from library.safetensors_utils import load_safetensors
|
||||
import logging
|
||||
|
||||
setup_logging()
|
||||
|
||||
352
library/safetensors_utils.py
Normal file
352
library/safetensors_utils.py
Normal file
@@ -0,0 +1,352 @@
|
||||
import os
|
||||
import re
|
||||
import numpy as np
|
||||
import torch
|
||||
import json
|
||||
import struct
|
||||
from typing import Dict, Any, Union, Optional
|
||||
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from library.device_utils import synchronize_device
|
||||
|
||||
|
||||
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
|
||||
"""
|
||||
memory efficient save file
|
||||
"""
|
||||
|
||||
_TYPES = {
|
||||
torch.float64: "F64",
|
||||
torch.float32: "F32",
|
||||
torch.float16: "F16",
|
||||
torch.bfloat16: "BF16",
|
||||
torch.int64: "I64",
|
||||
torch.int32: "I32",
|
||||
torch.int16: "I16",
|
||||
torch.int8: "I8",
|
||||
torch.uint8: "U8",
|
||||
torch.bool: "BOOL",
|
||||
getattr(torch, "float8_e5m2", None): "F8_E5M2",
|
||||
getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
|
||||
}
|
||||
_ALIGN = 256
|
||||
|
||||
def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
|
||||
validated = {}
|
||||
for key, value in metadata.items():
|
||||
if not isinstance(key, str):
|
||||
raise ValueError(f"Metadata key must be a string, got {type(key)}")
|
||||
if not isinstance(value, str):
|
||||
print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
|
||||
validated[key] = str(value)
|
||||
else:
|
||||
validated[key] = value
|
||||
return validated
|
||||
|
||||
# print(f"Using memory efficient save file: {filename}")
|
||||
|
||||
header = {}
|
||||
offset = 0
|
||||
if metadata:
|
||||
header["__metadata__"] = validate_metadata(metadata)
|
||||
for k, v in tensors.items():
|
||||
if v.numel() == 0: # empty tensor
|
||||
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
|
||||
else:
|
||||
size = v.numel() * v.element_size()
|
||||
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
|
||||
offset += size
|
||||
|
||||
hjson = json.dumps(header).encode("utf-8")
|
||||
hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
|
||||
|
||||
with open(filename, "wb") as f:
|
||||
f.write(struct.pack("<Q", len(hjson)))
|
||||
f.write(hjson)
|
||||
|
||||
for k, v in tensors.items():
|
||||
if v.numel() == 0:
|
||||
continue
|
||||
if v.is_cuda:
|
||||
# Direct GPU to disk save
|
||||
with torch.cuda.device(v.device):
|
||||
if v.dim() == 0: # if scalar, need to add a dimension to work with view
|
||||
v = v.unsqueeze(0)
|
||||
tensor_bytes = v.contiguous().view(torch.uint8)
|
||||
tensor_bytes.cpu().numpy().tofile(f)
|
||||
else:
|
||||
# CPU tensor save
|
||||
if v.dim() == 0: # if scalar, need to add a dimension to work with view
|
||||
v = v.unsqueeze(0)
|
||||
v.contiguous().view(torch.uint8).numpy().tofile(f)
|
||||
|
||||
|
||||
class MemoryEfficientSafeOpen:
|
||||
"""Memory-efficient reader for safetensors files.
|
||||
|
||||
This class provides a memory-efficient way to read tensors from safetensors files
|
||||
by using memory mapping for large tensors and avoiding unnecessary copies.
|
||||
"""
|
||||
|
||||
def __init__(self, filename):
|
||||
"""Initialize the SafeTensor reader.
|
||||
|
||||
Args:
|
||||
filename (str): Path to the safetensors file to read.
|
||||
"""
|
||||
self.filename = filename
|
||||
self.file = open(filename, "rb")
|
||||
self.header, self.header_size = self._read_header()
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter context manager."""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Exit context manager and close file."""
|
||||
self.file.close()
|
||||
|
||||
def keys(self):
|
||||
"""Get all tensor keys in the file.
|
||||
|
||||
Returns:
|
||||
list: List of tensor names (excludes metadata).
|
||||
"""
|
||||
return [k for k in self.header.keys() if k != "__metadata__"]
|
||||
|
||||
def metadata(self) -> Dict[str, str]:
|
||||
"""Get metadata from the file.
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: Metadata dictionary.
|
||||
"""
|
||||
return self.header.get("__metadata__", {})
|
||||
|
||||
def _read_header(self):
|
||||
"""Read and parse the header from the safetensors file.
|
||||
|
||||
Returns:
|
||||
tuple: (header_dict, header_size) containing parsed header and its size.
|
||||
"""
|
||||
# Read header size (8 bytes, little-endian unsigned long long)
|
||||
header_size = struct.unpack("<Q", self.file.read(8))[0]
|
||||
# Read and decode header JSON
|
||||
header_json = self.file.read(header_size).decode("utf-8")
|
||||
return json.loads(header_json), header_size
|
||||
|
||||
def get_tensor(self, key: str, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
"""Load a tensor from the file with memory-efficient strategies.
|
||||
|
||||
**Note:**
|
||||
If device is 'cuda' , the transfer to GPU is done efficiently using pinned memory and non-blocking transfer.
|
||||
So you must ensure that the transfer is completed before using the tensor (e.g., by `torch.cuda.synchronize()`).
|
||||
|
||||
If the tensor is large (>10MB) and the target device is CUDA, memory mapping with numpy.memmap is used to avoid intermediate copies.
|
||||
|
||||
Args:
|
||||
key (str): Name of the tensor to load.
|
||||
device (Optional[torch.device]): Target device for the tensor.
|
||||
dtype (Optional[torch.dtype]): Target dtype for the tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The loaded tensor.
|
||||
|
||||
Raises:
|
||||
KeyError: If the tensor key is not found in the file.
|
||||
"""
|
||||
if key not in self.header:
|
||||
raise KeyError(f"Tensor '{key}' not found in the file")
|
||||
|
||||
metadata = self.header[key]
|
||||
offset_start, offset_end = metadata["data_offsets"]
|
||||
num_bytes = offset_end - offset_start
|
||||
|
||||
original_dtype = self._get_torch_dtype(metadata["dtype"])
|
||||
target_dtype = dtype if dtype is not None else original_dtype
|
||||
|
||||
# Handle empty tensors
|
||||
if num_bytes == 0:
|
||||
return torch.empty(metadata["shape"], dtype=target_dtype, device=device)
|
||||
|
||||
# Determine if we should use pinned memory for GPU transfer
|
||||
non_blocking = device is not None and device.type == "cuda"
|
||||
|
||||
# Calculate absolute file offset
|
||||
tensor_offset = self.header_size + 8 + offset_start # adjust offset by header size
|
||||
|
||||
# Memory mapping strategy for large tensors to GPU
|
||||
# Use memmap for large tensors to avoid intermediate copies.
|
||||
# If device is cpu, tensor is not copied to gpu, so using memmap locks the file, which is not desired.
|
||||
# So we only use memmap if device is not cpu.
|
||||
if num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
|
||||
# Create memory map for zero-copy reading
|
||||
mm = np.memmap(self.filename, mode="c", dtype=np.uint8, offset=tensor_offset, shape=(num_bytes,))
|
||||
byte_tensor = torch.from_numpy(mm) # zero copy
|
||||
del mm
|
||||
|
||||
# Deserialize tensor (view and reshape)
|
||||
cpu_tensor = self._deserialize_tensor(byte_tensor, metadata) # view and reshape
|
||||
del byte_tensor
|
||||
|
||||
# Transfer to target device and dtype
|
||||
gpu_tensor = cpu_tensor.to(device=device, dtype=target_dtype, non_blocking=non_blocking)
|
||||
del cpu_tensor
|
||||
return gpu_tensor
|
||||
|
||||
# Standard file reading strategy for smaller tensors or CPU target
|
||||
# seek to the specified position
|
||||
self.file.seek(tensor_offset)
|
||||
|
||||
# read directly into a numpy array by numpy.fromfile without intermediate copy
|
||||
numpy_array = np.fromfile(self.file, dtype=np.uint8, count=num_bytes)
|
||||
byte_tensor = torch.from_numpy(numpy_array)
|
||||
del numpy_array
|
||||
|
||||
# deserialize (view and reshape)
|
||||
deserialized_tensor = self._deserialize_tensor(byte_tensor, metadata)
|
||||
del byte_tensor
|
||||
|
||||
# cast to target dtype and move to device
|
||||
return deserialized_tensor.to(device=device, dtype=target_dtype, non_blocking=non_blocking)
|
||||
|
||||
def _deserialize_tensor(self, byte_tensor: torch.Tensor, metadata: Dict):
|
||||
"""Deserialize byte tensor to the correct shape and dtype.
|
||||
|
||||
Args:
|
||||
byte_tensor (torch.Tensor): Raw byte tensor from file.
|
||||
metadata (Dict): Tensor metadata containing dtype and shape info.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Deserialized tensor with correct shape and dtype.
|
||||
"""
|
||||
dtype = self._get_torch_dtype(metadata["dtype"])
|
||||
shape = metadata["shape"]
|
||||
|
||||
# Handle special float8 types
|
||||
if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
|
||||
return self._convert_float8(byte_tensor, metadata["dtype"], shape)
|
||||
|
||||
# Standard conversion: view as target dtype and reshape
|
||||
return byte_tensor.view(dtype).reshape(shape)
|
||||
|
||||
@staticmethod
|
||||
def _get_torch_dtype(dtype_str):
|
||||
"""Convert string dtype to PyTorch dtype.
|
||||
|
||||
Args:
|
||||
dtype_str (str): String representation of the dtype.
|
||||
|
||||
Returns:
|
||||
torch.dtype: Corresponding PyTorch dtype.
|
||||
"""
|
||||
# Standard dtype mappings
|
||||
dtype_map = {
|
||||
"F64": torch.float64,
|
||||
"F32": torch.float32,
|
||||
"F16": torch.float16,
|
||||
"BF16": torch.bfloat16,
|
||||
"I64": torch.int64,
|
||||
"I32": torch.int32,
|
||||
"I16": torch.int16,
|
||||
"I8": torch.int8,
|
||||
"U8": torch.uint8,
|
||||
"BOOL": torch.bool,
|
||||
}
|
||||
# Add float8 types if available in PyTorch version
|
||||
if hasattr(torch, "float8_e5m2"):
|
||||
dtype_map["F8_E5M2"] = torch.float8_e5m2
|
||||
if hasattr(torch, "float8_e4m3fn"):
|
||||
dtype_map["F8_E4M3"] = torch.float8_e4m3fn
|
||||
return dtype_map.get(dtype_str)
|
||||
|
||||
@staticmethod
|
||||
def _convert_float8(byte_tensor, dtype_str, shape):
|
||||
"""Convert byte tensor to float8 format if supported.
|
||||
|
||||
Args:
|
||||
byte_tensor (torch.Tensor): Raw byte tensor.
|
||||
dtype_str (str): Float8 dtype string ("F8_E5M2" or "F8_E4M3").
|
||||
shape (tuple): Target tensor shape.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Tensor with float8 dtype.
|
||||
|
||||
Raises:
|
||||
ValueError: If float8 type is not supported in current PyTorch version.
|
||||
"""
|
||||
# Convert to specific float8 types if available
|
||||
if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
|
||||
return byte_tensor.view(torch.float8_e5m2).reshape(shape)
|
||||
elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
|
||||
return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
|
||||
else:
|
||||
# Float8 not supported in this PyTorch version
|
||||
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
|
||||
|
||||
|
||||
def load_safetensors(
|
||||
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
|
||||
) -> dict[str, torch.Tensor]:
|
||||
if disable_mmap:
|
||||
# return safetensors.torch.load(open(path, "rb").read())
|
||||
# use experimental loader
|
||||
# logger.info(f"Loading without mmap (experimental)")
|
||||
state_dict = {}
|
||||
device = torch.device(device) if device is not None else None
|
||||
with MemoryEfficientSafeOpen(path) as f:
|
||||
for key in f.keys():
|
||||
state_dict[key] = f.get_tensor(key, device=device, dtype=dtype)
|
||||
synchronize_device(device)
|
||||
return state_dict
|
||||
else:
|
||||
try:
|
||||
state_dict = load_file(path, device=device)
|
||||
except:
|
||||
state_dict = load_file(path) # prevent device invalid Error
|
||||
if dtype is not None:
|
||||
for key in state_dict.keys():
|
||||
state_dict[key] = state_dict[key].to(dtype=dtype)
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_split_weights(
|
||||
file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Load split weights from a file. If the file name ends with 00001-of-00004 etc, it will load all files with the same prefix.
|
||||
dtype is as is, no conversion is done.
|
||||
"""
|
||||
device = torch.device(device)
|
||||
|
||||
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
|
||||
basename = os.path.basename(file_path)
|
||||
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
|
||||
if match:
|
||||
prefix = basename[: match.start(2)]
|
||||
count = int(match.group(3))
|
||||
state_dict = {}
|
||||
for i in range(count):
|
||||
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
|
||||
filepath = os.path.join(os.path.dirname(file_path), filename)
|
||||
if os.path.exists(filepath):
|
||||
state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap, dtype=dtype))
|
||||
else:
|
||||
raise FileNotFoundError(f"File {filepath} not found")
|
||||
else:
|
||||
state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
|
||||
return state_dict
|
||||
|
||||
|
||||
def find_key(safetensors_file: str, starts_with: Optional[str] = None, ends_with: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Find a key in a safetensors file that starts with `starts_with` and ends with `ends_with`.
|
||||
If `starts_with` is None, it will match any key.
|
||||
If `ends_with` is None, it will match any key.
|
||||
Returns the first matching key or None if no key matches.
|
||||
"""
|
||||
with MemoryEfficientSafeOpen(safetensors_file) as f:
|
||||
for key in f.keys():
|
||||
if (starts_with is None or key.startswith(starts_with)) and (ends_with is None or key.endswith(ends_with)):
|
||||
return key
|
||||
return None
|
||||
@@ -23,7 +23,7 @@ from library import sdxl_model_util
|
||||
# region models
|
||||
|
||||
# TODO remove dependency on flux_utils
|
||||
from library.utils import load_safetensors
|
||||
from library.safetensors_utils import load_safetensors
|
||||
from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl
|
||||
|
||||
|
||||
@@ -246,7 +246,7 @@ def load_vae(
|
||||
vae_sd = {}
|
||||
if vae_path:
|
||||
logger.info(f"Loading VAE from {vae_path}...")
|
||||
vae_sd = load_safetensors(vae_path, device, disable_mmap)
|
||||
vae_sd = load_safetensors(vae_path, device, disable_mmap, dtype=vae_dtype)
|
||||
else:
|
||||
# remove prefix "first_stage_model."
|
||||
vae_sd = {}
|
||||
|
||||
205
library/utils.py
205
library/utils.py
@@ -2,8 +2,6 @@ import logging
|
||||
import sys
|
||||
import threading
|
||||
from typing import *
|
||||
import json
|
||||
import struct
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -14,7 +12,7 @@ from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncest
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from safetensors.torch import load_file
|
||||
|
||||
|
||||
def fire_in_thread(f, *args, **kwargs):
|
||||
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
||||
@@ -88,6 +86,7 @@ def setup_logging(args=None, log_level=None, reset=False):
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(msg_init)
|
||||
|
||||
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -190,190 +189,6 @@ def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None)
|
||||
raise ValueError(f"Unsupported dtype: {s}")
|
||||
|
||||
|
||||
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
|
||||
"""
|
||||
memory efficient save file
|
||||
"""
|
||||
|
||||
_TYPES = {
|
||||
torch.float64: "F64",
|
||||
torch.float32: "F32",
|
||||
torch.float16: "F16",
|
||||
torch.bfloat16: "BF16",
|
||||
torch.int64: "I64",
|
||||
torch.int32: "I32",
|
||||
torch.int16: "I16",
|
||||
torch.int8: "I8",
|
||||
torch.uint8: "U8",
|
||||
torch.bool: "BOOL",
|
||||
getattr(torch, "float8_e5m2", None): "F8_E5M2",
|
||||
getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
|
||||
}
|
||||
_ALIGN = 256
|
||||
|
||||
def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
|
||||
validated = {}
|
||||
for key, value in metadata.items():
|
||||
if not isinstance(key, str):
|
||||
raise ValueError(f"Metadata key must be a string, got {type(key)}")
|
||||
if not isinstance(value, str):
|
||||
print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
|
||||
validated[key] = str(value)
|
||||
else:
|
||||
validated[key] = value
|
||||
return validated
|
||||
|
||||
print(f"Using memory efficient save file: {filename}")
|
||||
|
||||
header = {}
|
||||
offset = 0
|
||||
if metadata:
|
||||
header["__metadata__"] = validate_metadata(metadata)
|
||||
for k, v in tensors.items():
|
||||
if v.numel() == 0: # empty tensor
|
||||
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
|
||||
else:
|
||||
size = v.numel() * v.element_size()
|
||||
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
|
||||
offset += size
|
||||
|
||||
hjson = json.dumps(header).encode("utf-8")
|
||||
hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
|
||||
|
||||
with open(filename, "wb") as f:
|
||||
f.write(struct.pack("<Q", len(hjson)))
|
||||
f.write(hjson)
|
||||
|
||||
for k, v in tensors.items():
|
||||
if v.numel() == 0:
|
||||
continue
|
||||
if v.is_cuda:
|
||||
# Direct GPU to disk save
|
||||
with torch.cuda.device(v.device):
|
||||
if v.dim() == 0: # if scalar, need to add a dimension to work with view
|
||||
v = v.unsqueeze(0)
|
||||
tensor_bytes = v.contiguous().view(torch.uint8)
|
||||
tensor_bytes.cpu().numpy().tofile(f)
|
||||
else:
|
||||
# CPU tensor save
|
||||
if v.dim() == 0: # if scalar, need to add a dimension to work with view
|
||||
v = v.unsqueeze(0)
|
||||
v.contiguous().view(torch.uint8).numpy().tofile(f)
|
||||
|
||||
|
||||
class MemoryEfficientSafeOpen:
|
||||
def __init__(self, filename):
|
||||
self.filename = filename
|
||||
self.file = open(filename, "rb")
|
||||
self.header, self.header_size = self._read_header()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.file.close()
|
||||
|
||||
def keys(self):
|
||||
return [k for k in self.header.keys() if k != "__metadata__"]
|
||||
|
||||
def metadata(self) -> Dict[str, str]:
|
||||
return self.header.get("__metadata__", {})
|
||||
|
||||
def get_tensor(self, key):
|
||||
if key not in self.header:
|
||||
raise KeyError(f"Tensor '{key}' not found in the file")
|
||||
|
||||
metadata = self.header[key]
|
||||
offset_start, offset_end = metadata["data_offsets"]
|
||||
|
||||
if offset_start == offset_end:
|
||||
tensor_bytes = None
|
||||
else:
|
||||
# adjust offset by header size
|
||||
self.file.seek(self.header_size + 8 + offset_start)
|
||||
tensor_bytes = self.file.read(offset_end - offset_start)
|
||||
|
||||
return self._deserialize_tensor(tensor_bytes, metadata)
|
||||
|
||||
def _read_header(self):
|
||||
header_size = struct.unpack("<Q", self.file.read(8))[0]
|
||||
header_json = self.file.read(header_size).decode("utf-8")
|
||||
return json.loads(header_json), header_size
|
||||
|
||||
def _deserialize_tensor(self, tensor_bytes, metadata):
|
||||
dtype = self._get_torch_dtype(metadata["dtype"])
|
||||
shape = metadata["shape"]
|
||||
|
||||
if tensor_bytes is None:
|
||||
byte_tensor = torch.empty(0, dtype=torch.uint8)
|
||||
else:
|
||||
tensor_bytes = bytearray(tensor_bytes) # make it writable
|
||||
byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
|
||||
|
||||
# process float8 types
|
||||
if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
|
||||
return self._convert_float8(byte_tensor, metadata["dtype"], shape)
|
||||
|
||||
# convert to the target dtype and reshape
|
||||
return byte_tensor.view(dtype).reshape(shape)
|
||||
|
||||
@staticmethod
|
||||
def _get_torch_dtype(dtype_str):
|
||||
dtype_map = {
|
||||
"F64": torch.float64,
|
||||
"F32": torch.float32,
|
||||
"F16": torch.float16,
|
||||
"BF16": torch.bfloat16,
|
||||
"I64": torch.int64,
|
||||
"I32": torch.int32,
|
||||
"I16": torch.int16,
|
||||
"I8": torch.int8,
|
||||
"U8": torch.uint8,
|
||||
"BOOL": torch.bool,
|
||||
}
|
||||
# add float8 types if available
|
||||
if hasattr(torch, "float8_e5m2"):
|
||||
dtype_map["F8_E5M2"] = torch.float8_e5m2
|
||||
if hasattr(torch, "float8_e4m3fn"):
|
||||
dtype_map["F8_E4M3"] = torch.float8_e4m3fn
|
||||
return dtype_map.get(dtype_str)
|
||||
|
||||
@staticmethod
|
||||
def _convert_float8(byte_tensor, dtype_str, shape):
|
||||
if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
|
||||
return byte_tensor.view(torch.float8_e5m2).reshape(shape)
|
||||
elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
|
||||
return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
|
||||
else:
|
||||
# # convert to float16 if float8 is not supported
|
||||
# print(f"Warning: {dtype_str} is not supported in this PyTorch version. Converting to float16.")
|
||||
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
|
||||
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
|
||||
|
||||
|
||||
def load_safetensors(
|
||||
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
|
||||
) -> dict[str, torch.Tensor]:
|
||||
if disable_mmap:
|
||||
# return safetensors.torch.load(open(path, "rb").read())
|
||||
# use experimental loader
|
||||
# logger.info(f"Loading without mmap (experimental)")
|
||||
state_dict = {}
|
||||
with MemoryEfficientSafeOpen(path) as f:
|
||||
for key in f.keys():
|
||||
state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
|
||||
return state_dict
|
||||
else:
|
||||
try:
|
||||
state_dict = load_file(path, device=device)
|
||||
except:
|
||||
state_dict = load_file(path) # prevent device invalid Error
|
||||
if dtype is not None:
|
||||
for key in state_dict.keys():
|
||||
state_dict[key] = state_dict[key].to(dtype=dtype)
|
||||
return state_dict
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Image utils
|
||||
@@ -398,7 +213,14 @@ def pil_resize(image, size, interpolation):
|
||||
return resized_cv2
|
||||
|
||||
|
||||
def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None):
|
||||
def resize_image(
|
||||
image: np.ndarray,
|
||||
width: int,
|
||||
height: int,
|
||||
resized_width: int,
|
||||
resized_height: int,
|
||||
resize_interpolation: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS.
|
||||
|
||||
@@ -472,6 +294,7 @@ def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
|
||||
"""
|
||||
Convert interpolation value to PIL interpolation
|
||||
@@ -503,12 +326,14 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def validate_interpolation_fn(interpolation_str: str) -> bool:
|
||||
"""
|
||||
Check if a interpolation function is supported
|
||||
"""
|
||||
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# TODO make inf_utils.py
|
||||
@@ -642,7 +467,9 @@ class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
|
||||
elif self.config.prediction_type == "sample":
|
||||
raise NotImplementedError("prediction_type not implemented yet: sample")
|
||||
else:
|
||||
raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
||||
)
|
||||
|
||||
sigma_from = self.sigmas[self.step_index]
|
||||
sigma_to = self.sigmas[self.step_index + 1]
|
||||
|
||||
@@ -9,7 +9,8 @@ from safetensors import safe_open
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
|
||||
from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file
|
||||
from library.utils import setup_logging, str_to_dtype
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen, mem_eff_save_file
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
@@ -28,7 +28,7 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import sd3_models, sd3_utils, strategy_sd3
|
||||
from library.utils import load_safetensors
|
||||
from library.safetensors_utils import load_safetensors
|
||||
|
||||
|
||||
def get_noise(seed, latent, device="cpu"):
|
||||
|
||||
@@ -14,6 +14,7 @@ from tqdm import tqdm
|
||||
import torch
|
||||
from library import utils
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library.safetensors_utils import load_safetensors
|
||||
|
||||
init_ipex()
|
||||
|
||||
@@ -206,7 +207,7 @@ def train(args):
|
||||
# t5xxl_dtype = weight_dtype
|
||||
model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx)
|
||||
if args.clip_l is None:
|
||||
sd3_state_dict = utils.load_safetensors(
|
||||
sd3_state_dict = load_safetensors(
|
||||
args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype
|
||||
)
|
||||
else:
|
||||
@@ -322,7 +323,7 @@ def train(args):
|
||||
# load VAE for caching latents
|
||||
if sd3_state_dict is None:
|
||||
logger.info(f"load state dict for MMDiT and VAE from {args.pretrained_model_name_or_path}")
|
||||
sd3_state_dict = utils.load_safetensors(
|
||||
sd3_state_dict = load_safetensors(
|
||||
args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype
|
||||
)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import torch
|
||||
from accelerate import Accelerator
|
||||
from library import sd3_models, strategy_sd3, utils
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library.safetensors_utils import load_safetensors
|
||||
|
||||
init_ipex()
|
||||
|
||||
@@ -77,7 +78,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
loading_dtype = None if args.fp8_base else weight_dtype
|
||||
|
||||
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
|
||||
state_dict = utils.load_safetensors(
|
||||
state_dict = load_safetensors(
|
||||
args.pretrained_model_name_or_path, "cpu", disable_mmap=args.disable_mmap_load_safetensors, dtype=loading_dtype
|
||||
)
|
||||
mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu")
|
||||
|
||||
@@ -4,7 +4,7 @@ import torch.nn as nn
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from library.custom_offloading_utils import (
|
||||
synchronize_device,
|
||||
_synchronize_device,
|
||||
swap_weight_devices_cuda,
|
||||
swap_weight_devices_no_cuda,
|
||||
weighs_to_device,
|
||||
@@ -50,21 +50,21 @@ class SimpleModel(nn.Module):
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_cuda_synchronize(mock_cuda_sync):
|
||||
device = torch.device('cuda')
|
||||
synchronize_device(device)
|
||||
_synchronize_device(device)
|
||||
mock_cuda_sync.assert_called_once()
|
||||
|
||||
@patch('torch.xpu.synchronize')
|
||||
@pytest.mark.skipif(not torch.xpu.is_available(), reason="XPU not available")
|
||||
def test_xpu_synchronize(mock_xpu_sync):
|
||||
device = torch.device('xpu')
|
||||
synchronize_device(device)
|
||||
_synchronize_device(device)
|
||||
mock_xpu_sync.assert_called_once()
|
||||
|
||||
@patch('torch.mps.synchronize')
|
||||
@pytest.mark.skipif(not torch.xpu.is_available(), reason="MPS not available")
|
||||
def test_mps_synchronize(mock_mps_sync):
|
||||
device = torch.device('mps')
|
||||
synchronize_device(device)
|
||||
_synchronize_device(device)
|
||||
mock_mps_sync.assert_called_once()
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ def test_swap_weight_devices_cuda():
|
||||
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.synchronize_device')
|
||||
@patch('library.custom_offloading_utils._synchronize_device')
|
||||
def test_swap_weight_devices_no_cuda(mock_sync_device):
|
||||
device = torch.device('cpu')
|
||||
layer_to_cpu = SimpleModel()
|
||||
@@ -121,7 +121,7 @@ def test_swap_weight_devices_no_cuda(mock_sync_device):
|
||||
with patch('torch.Tensor.copy_'):
|
||||
swap_weight_devices_no_cuda(device, layer_to_cpu, layer_to_cuda)
|
||||
|
||||
# Verify synchronize_device was called twice
|
||||
# Verify _synchronize_device was called twice
|
||||
assert mock_sync_device.call_count == 2
|
||||
|
||||
|
||||
@@ -279,8 +279,8 @@ def test_backward_hook_execution(mock_wait, mock_submit):
|
||||
|
||||
|
||||
@patch('library.custom_offloading_utils.weighs_to_device')
|
||||
@patch('library.custom_offloading_utils.synchronize_device')
|
||||
@patch('library.custom_offloading_utils.clean_memory_on_device')
|
||||
@patch('library.custom_offloading_utils._synchronize_device')
|
||||
@patch('library.custom_offloading_utils._clean_memory_on_device')
|
||||
def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weights_to_device, model_offloader):
|
||||
model = SimpleModel(4)
|
||||
blocks = model.blocks
|
||||
@@ -291,7 +291,7 @@ def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weight
|
||||
# Check that weighs_to_device was called for each block
|
||||
assert mock_weights_to_device.call_count == 4
|
||||
|
||||
# Check that synchronize_device and clean_memory_on_device were called
|
||||
# Check that _synchronize_device and _clean_memory_on_device were called
|
||||
mock_sync.assert_called_once_with(model_offloader.device)
|
||||
mock_clean.assert_called_once_with(model_offloader.device)
|
||||
|
||||
|
||||
@@ -30,7 +30,8 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from library import flux_utils
|
||||
from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file
|
||||
from library.utils import setup_logging, str_to_dtype
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen, mem_eff_save_file
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
@@ -6,7 +6,8 @@ import torch
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
from library.utils import setup_logging
|
||||
from library.utils import load_safetensors, mem_eff_save_file, str_to_dtype
|
||||
from library.utils import str_to_dtype
|
||||
from library.safetensors_utils import load_safetensors, mem_eff_save_file
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
Reference in New Issue
Block a user