feat: block-wise fp8 quantization

Kohya S
2025-09-18 21:20:54 +09:00
parent 2ce506e187
commit f6b4bdc83f
4 changed files with 186 additions and 102 deletions

View File

@@ -1,5 +1,5 @@
import os
from typing import List, Union
from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
def calculate_fp8_maxval(exp_bits=4, mantissa_bits=3, sign_bits=1):
"""
Calculate the maximum representable value in FP8 format.
Default is E4M3 format (4-bit exponent, 3-bit mantissa, 1-bit sign).
Default is E4M3 format (4-bit exponent, 3-bit mantissa, 1-bit sign). Only supports E4M3 and E5M2 with sign bit.
Args:
exp_bits (int): Number of exponent bits
@@ -32,73 +32,73 @@ def calculate_fp8_maxval(exp_bits=4, mantissa_bits=3, sign_bits=1):
float: Maximum value representable in FP8 format
"""
assert exp_bits + mantissa_bits + sign_bits == 8, "Total bits must be 8"
# Calculate exponent bias
bias = 2 ** (exp_bits - 1) - 1
# Calculate maximum mantissa value
mantissa_max = 1.0
for i in range(mantissa_bits - 1):
mantissa_max += 2 ** -(i + 1)
# Calculate maximum value
max_value = mantissa_max * (2 ** (2**exp_bits - 1 - bias))
return max_value
if exp_bits == 4 and mantissa_bits == 3 and sign_bits == 1:
return torch.finfo(torch.float8_e4m3fn).max
elif exp_bits == 5 and mantissa_bits == 2 and sign_bits == 1:
return torch.finfo(torch.float8_e5m2).max
else:
raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits} with sign_bits={sign_bits}")
def quantize_tensor_to_fp8(tensor, scale, exp_bits=4, mantissa_bits=3, sign_bits=1, max_value=None, min_value=None):
# The following is a manual calculation method (wrong implementation for E5M2), kept for reference.
"""
# Calculate exponent bias
bias = 2 ** (exp_bits - 1) - 1
# Calculate maximum mantissa value
mantissa_max = 1.0
for i in range(mantissa_bits - 1):
mantissa_max += 2 ** -(i + 1)
# Calculate maximum value
max_value = mantissa_max * (2 ** (2**exp_bits - 1 - bias))
return max_value
"""
def quantize_fp8(tensor, scale, fp8_dtype, max_value, min_value):
"""
Quantize a tensor to FP8 format.
Quantize a tensor to FP8 format using PyTorch's native FP8 dtype support.
Args:
tensor (torch.Tensor): Tensor to quantize
scale (float or torch.Tensor): Scale factor
exp_bits (int): Number of exponent bits
mantissa_bits (int): Number of mantissa bits
sign_bits (int): Number of sign bits
fp8_dtype (torch.dtype): Target FP8 dtype (torch.float8_e4m3fn or torch.float8_e5m2)
max_value (float): Maximum representable value in FP8
min_value (float): Minimum representable value in FP8
Returns:
tuple: (quantized_tensor, scale_factor)
torch.Tensor: Quantized tensor in FP8 format
"""
tensor = tensor.to(torch.float32) # ensure tensor is in float32 for division
# Create scaled tensor
scaled_tensor = tensor / scale
# Calculate FP8 parameters
bias = 2 ** (exp_bits - 1) - 1
if max_value is None:
# Calculate max and min values
max_value = calculate_fp8_maxval(exp_bits, mantissa_bits, sign_bits)
min_value = -max_value if sign_bits > 0 else 0.0
tensor = torch.div(tensor, scale).nan_to_num_(0.0) # handle NaN values, equivalent to nonzero_mask in previous function
# Clamp tensor to range
clamped_tensor = torch.clamp(scaled_tensor, min_value, max_value)
tensor = tensor.clamp_(min=min_value, max=max_value)
# Quantization process
abs_values = torch.abs(clamped_tensor)
nonzero_mask = abs_values > 0
# Convert to FP8 dtype
tensor = tensor.to(fp8_dtype)
# Calculate log scales (only for non-zero elements)
log_scales = torch.zeros_like(clamped_tensor)
if nonzero_mask.any():
log_scales[nonzero_mask] = torch.floor(torch.log2(abs_values[nonzero_mask]) + bias).detach()
# Limit log scales and calculate quantization factor
log_scales = torch.clamp(log_scales, min=1.0)
quant_factor = 2.0 ** (log_scales - mantissa_bits - bias)
# Quantize and dequantize
quantized = torch.round(clamped_tensor / quant_factor) * quant_factor
return quantized, scale
return tensor
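A minimal round trip through quantize_fp8 with a per-tensor scale, matching the convention used by the functions below (illustrative sketch; shapes and values are arbitrary):

import torch

w = torch.randn(128, 256)                      # e.g. a Linear weight in float32
fp8_max = torch.finfo(torch.float8_e4m3fn).max
scale = torch.clamp(w.abs().max() / fp8_max, min=1e-8)

w_fp8 = quantize_fp8(w, scale, torch.float8_e4m3fn, fp8_max, -fp8_max)
w_dequant = w_fp8.to(torch.float32) * scale    # dequantize to check the error
print((w - w_dequant).abs().max())             # small, bounded by the FP8 step size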
def optimize_state_dict_with_fp8(
state_dict, calc_device, target_layer_keys=None, exclude_layer_keys=None, exp_bits=4, mantissa_bits=3, move_to_device=False
state_dict: dict,
calc_device: Union[str, torch.device],
target_layer_keys: Optional[list[str]] = None,
exclude_layer_keys: Optional[list[str]] = None,
exp_bits: int = 4,
mantissa_bits: int = 3,
move_to_device: bool = False,
quantization_mode: str = "block",
block_size: Optional[int] = 64,
):
"""
Optimize Linear layer weights in a model's state dict to FP8 format.
Optimize Linear layer weights in a model's state dict to FP8 format. The state dict is modified in-place.
This function performs the same optimization as load_safetensors_with_fp8_optimization, but on a state dict that is already in memory instead of reading safetensors files.
Args:
state_dict (dict): State dict to optimize, replaced in-place
@@ -149,23 +149,17 @@ def optimize_state_dict_with_fp8(
if calc_device is not None:
value = value.to(calc_device)
# Calculate scale factor
scale = torch.max(torch.abs(value.flatten())) / max_value
# print(f"Optimizing {key} with scale: {scale}")
# Quantize weight to FP8
quantized_weight, _ = quantize_tensor_to_fp8(value, scale, exp_bits, mantissa_bits, 1, max_value, min_value)
quantized_weight, scale_tensor = quantize_weight(key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size)
# Add to state dict using original key for weight and new key for scale
fp8_key = key # Maintain original key
scale_key = key.replace(".weight", ".scale_weight")
quantized_weight = quantized_weight.to(fp8_dtype)
if not move_to_device:
quantized_weight = quantized_weight.to(original_device)
scale_tensor = torch.tensor([scale], dtype=original_dtype, device=quantized_weight.device)
# keep scale shape: [1] or [out,1] or [out, num_blocks, 1]. We can determine the quantization mode from the shape of scale_weight in the patched model.
scale_tensor = scale_tensor.to(dtype=original_dtype, device=quantized_weight.device)
state_dict[fp8_key] = quantized_weight
state_dict[scale_key] = scale_tensor
@@ -180,6 +174,70 @@ def optimize_state_dict_with_fp8(
return state_dict
def quantize_weight(
key: str,
tensor: torch.Tensor,
fp8_dtype: torch.dtype,
max_value: float,
min_value: float,
quantization_mode: str = "block",
block_size: int = 64,
):
original_shape = tensor.shape
# Determine quantization mode
if quantization_mode == "block":
if tensor.ndim != 2:
quantization_mode = "tensor" # fallback to per-tensor
else:
out_features, in_features = tensor.shape
if in_features % block_size != 0:
quantization_mode = "channel" # fallback to per-channel
logger.warning(
f"Layer {key} with shape {tensor.shape} is not divisible by block_size {block_size}, fallback to per-channel quantization."
)
else:
num_blocks = in_features // block_size
tensor = tensor.contiguous().view(out_features, num_blocks, block_size) # [out, num_blocks, block_size]
elif quantization_mode == "channel":
if tensor.ndim != 2:
quantization_mode = "tensor" # fallback to per-tensor
# Calculate scale factor (per-tensor, per-output-channel, or per-block, using the max of absolute values)
# value shape is expected to be [out_features, in_features] for Linear weights
if quantization_mode == "channel" or quantization_mode == "block":
# row-wise max: each output channel (or block) gets its own scale, so one outlier row does not inflate the scale of the whole tensor
# result shape: [out_features, 1] or [out_features, num_blocks, 1]
scale_dim = 1 if quantization_mode == "channel" else 2
abs_w = torch.abs(tensor)
# shape: [out_features, 1] or [out_features, num_blocks, 1]
row_max = torch.max(abs_w, dim=scale_dim, keepdim=True).values
scale = row_max / max_value
else:
# per-tensor
tensor_max = torch.max(torch.abs(tensor).view(-1))
scale = tensor_max / max_value
# numerical safety
scale = torch.clamp(scale, min=1e-8)
scale = scale.to(torch.float32) # ensure scale is in float32 for division
# Quantize weight to FP8 (scale can be scalar or [out,1], broadcasting works)
quantized_weight = quantize_fp8(tensor, scale, fp8_dtype, max_value, min_value)
# If block-wise, restore original shape
if quantization_mode == "block":
quantized_weight = quantized_weight.view(original_shape) # restore to original shape [out, in]
return quantized_weight, scale
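The scale shapes returned by quantize_weight for each mode, shown on a hypothetical 128x256 weight with the default block_size of 64 (illustrative only; the key name is made up):

import torch

w = torch.randn(128, 256)
fp8 = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8).max

_, s = quantize_weight("blk.attn.weight", w, fp8, fp8_max, -fp8_max, "tensor")
print(s.shape)             # torch.Size([])        one scalar scale for the whole tensor

_, s = quantize_weight("blk.attn.weight", w, fp8, fp8_max, -fp8_max, "channel")
print(s.shape)             # torch.Size([128, 1])  one scale per output channel

qw, s = quantize_weight("blk.attn.weight", w, fp8, fp8_max, -fp8_max, "block", block_size=64)
print(qw.shape, s.shape)   # torch.Size([128, 256]) torch.Size([128, 4, 1])  256 / 64 = 4 blocks per row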
def load_safetensors_with_fp8_optimization(
model_files: List[str],
calc_device: Union[str, torch.device],
@@ -189,7 +247,9 @@ def load_safetensors_with_fp8_optimization(
mantissa_bits=3,
move_to_device=False,
weight_hook=None,
):
quantization_mode: str = "block",
block_size: Optional[int] = 64,
) -> dict:
"""
Load weight tensors from safetensors files and merge LoRA weights into the state dict with explicit FP8 optimization.
@@ -202,6 +262,8 @@ def load_safetensors_with_fp8_optimization(
mantissa_bits (int): Number of mantissa bits
move_to_device (bool): Move optimized tensors to the calculating device
weight_hook (callable, optional): Function to apply to each weight tensor before optimization
quantization_mode (str): Quantization mode, "tensor", "channel", or "block"
block_size (int, optional): Block size for block-wise quantization (used if quantization_mode is "block")
Returns:
dict: FP8 optimized state dict
@@ -234,40 +296,39 @@ def load_safetensors_with_fp8_optimization(
keys = f.keys()
for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"):
value = f.get_tensor(key)
# Save original device
original_device = value.device # usually cpu
if weight_hook is not None:
# Apply weight hook if provided
value = weight_hook(key, value)
value = weight_hook(key, value, keep_on_calc_device=(calc_device is not None))
if not is_target_key(key):
target_device = calc_device if (calc_device is not None and move_to_device) else original_device
value = value.to(target_device)
state_dict[key] = value
continue
# Save original device and dtype
original_device = value.device
original_dtype = value.dtype
# Move to calculation device
if calc_device is not None:
value = value.to(calc_device)
# Calculate scale factor
scale = torch.max(torch.abs(value.flatten())) / max_value
# print(f"Optimizing {key} with scale: {scale}")
# Quantize weight to FP8
quantized_weight, _ = quantize_tensor_to_fp8(value, scale, exp_bits, mantissa_bits, 1, max_value, min_value)
original_dtype = value.dtype
quantized_weight, scale_tensor = quantize_weight(
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
)
# Add to state dict using original key for weight and new key for scale
fp8_key = key # Maintain original key
scale_key = key.replace(".weight", ".scale_weight")
assert fp8_key != scale_key, "FP8 key and scale key must be different"
quantized_weight = quantized_weight.to(fp8_dtype)
if not move_to_device:
quantized_weight = quantized_weight.to(original_device)
scale_tensor = torch.tensor([scale], dtype=original_dtype, device=quantized_weight.device)
# keep scale shape: [1] or [out,1] or [out, num_blocks, 1]. We can determine the quantization mode from the shape of scale_weight in the patched model.
scale_tensor = scale_tensor.to(dtype=original_dtype, device=quantized_weight.device)
state_dict[fp8_key] = quantized_weight
state_dict[scale_key] = scale_tensor
@@ -296,12 +357,15 @@ def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=
torch.Tensor: Result of linear transformation
"""
if use_scaled_mm:
# **not tested**
# _scaled_mm only works for per-tensor scale for now (per-channel scale does not work in certain cases)
if self.scale_weight.ndim != 1:
raise ValueError("scaled_mm only supports per-tensor scale_weight for now.")
input_dtype = x.dtype
original_weight_dtype = self.scale_weight.dtype
weight_dtype = self.weight.dtype
target_dtype = torch.float8_e5m2
assert weight_dtype == torch.float8_e4m3fn, "Only FP8 E4M3FN format is supported"
assert x.ndim == 3, "Input tensor must be 3D (batch_size, seq_len, hidden_dim)"
target_dtype = self.weight.dtype
# assert x.ndim == 3, "Input tensor must be 3D (batch_size, seq_len, hidden_dim)"
if max_value is None:
# no input quantization
@@ -311,10 +375,12 @@ def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=
scale_x = (torch.max(torch.abs(x.flatten())) / max_value).to(torch.float32)
# quantize input tensor to FP8: this seems to consume a lot of memory
x, _ = quantize_tensor_to_fp8(x, scale_x, 5, 2, 1, max_value, -max_value)
fp8_max_value = torch.finfo(target_dtype).max
fp8_min_value = torch.finfo(target_dtype).min
x = quantize_fp8(x, scale_x, target_dtype, fp8_max_value, fp8_min_value)
original_shape = x.shape
x = x.reshape(-1, x.shape[2]).to(target_dtype)
x = x.reshape(-1, x.shape[-1]).to(target_dtype)
weight = self.weight.t()
scale_weight = self.scale_weight.to(torch.float32)
@@ -325,12 +391,21 @@ def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=
else:
o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight)
return o.reshape(original_shape[0], original_shape[1], -1).to(input_dtype)
o = o.reshape(original_shape[0], original_shape[1], -1) if len(original_shape) == 3 else o.reshape(original_shape[0], -1)
return o.to(input_dtype)
else:
# Dequantize the weight
original_dtype = self.scale_weight.dtype
dequantized_weight = self.weight.to(original_dtype) * self.scale_weight
if self.scale_weight.ndim < 3:
# per-tensor or per-channel quantization, we can broadcast
dequantized_weight = self.weight.to(original_dtype) * self.scale_weight
else:
# block-wise quantization, need to reshape weight to match scale shape for broadcasting
out_features, num_blocks, _ = self.scale_weight.shape
dequantized_weight = self.weight.to(original_dtype).contiguous().view(out_features, num_blocks, -1)
dequantized_weight = dequantized_weight * self.scale_weight
dequantized_weight = dequantized_weight.view(self.weight.shape)
# Perform linear transformation
if self.bias is not None:
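The block-wise dequantization branch above is plain broadcasting over a reshaped view; the same operation in isolation (illustrative sizes):

import torch

out_features, in_features, block_size = 128, 256, 64
num_blocks = in_features // block_size

w_fp8 = torch.randn(out_features, in_features).to(torch.float8_e4m3fn)
scale_weight = torch.rand(out_features, num_blocks, 1)  # as stored under ".scale_weight"

w = w_fp8.to(torch.float32).contiguous().view(out_features, num_blocks, block_size)
w = (w * scale_weight).view(out_features, in_features)  # back to the original [out, in] shape
print(w.shape)  # torch.Size([128, 256])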
@@ -362,11 +437,15 @@ def apply_fp8_monkey_patch(model, optimized_state_dict, use_scaled_mm=False):
# Enumerate patched layers
patched_module_paths = set()
scale_shape_info = {}
for scale_key in scale_keys:
# Extract module path from scale key (remove .scale_weight)
module_path = scale_key.rsplit(".scale_weight", 1)[0]
patched_module_paths.add(module_path)
# Store scale shape information
scale_shape_info[module_path] = optimized_state_dict[scale_key].shape
patched_count = 0
# Apply monkey patch to each layer with FP8 weights
@@ -377,7 +456,9 @@ def apply_fp8_monkey_patch(model, optimized_state_dict, use_scaled_mm=False):
# Apply patch if it's a Linear layer with FP8 scale
if isinstance(module, nn.Linear) and has_scale:
# register scale_weight as a buffer, with the shape stored in the optimized state dict, so load_state_dict can populate it
module.register_buffer("scale_weight", torch.tensor(1.0, dtype=module.weight.dtype))
# module.register_buffer("scale_weight", torch.tensor(1.0, dtype=module.weight.dtype))
scale_shape = scale_shape_info[name]
module.register_buffer("scale_weight", torch.ones(scale_shape, dtype=module.weight.dtype))
# Create a new forward method with the patched version.
def new_forward(self, x):

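How the pieces fit together for an already-loaded model (a sketch: the target/exclude lists mirror the constants added in the next file, and loading with strict=False, assign=True is one plausible way to put the FP8 tensors back, not necessarily what the callers in this repo do):

state_dict = optimize_state_dict_with_fp8(
    model.state_dict(),                       # assumes `model` is the DiT nn.Module
    calc_device="cuda",
    target_layer_keys=["double_blocks", "single_blocks"],
    exclude_layer_keys=["norm", "_emb"],
    quantization_mode="block",
    block_size=64,
)
apply_fp8_monkey_patch(model, state_dict, use_scaled_mm=False)  # registers scale_weight buffers, patches forward
model.load_state_dict(state_dict, strict=False, assign=True)    # assign=True keeps the FP8 dtypes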
View File

@@ -30,7 +30,12 @@ from library.hunyuan_image_modules import (
from library.hunyuan_image_utils import get_nd_rotary_pos_embed
FP8_OPTIMIZATION_TARGET_KEYS = ["double_blocks", "single_blocks"]
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["norm", "_mod", "modulation", "_emb"]
# FP8_OPTIMIZATION_EXCLUDE_KEYS = ["norm", "_mod", "_emb"] # , "modulation"
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["norm", "_emb"] # , "modulation", "_mod"
# full exclude 24.2GB
# norm and _emb 19.7GB
# fp8 cast 19.7GB
# region DiT Model

View File

@@ -497,7 +497,9 @@ class RMSNorm(nn.Module):
"""
output = self._norm(x.float()).type_as(x)
del x
output = output * self.weight
# output = output * self.weight
# fp8 support
output = output * self.weight.to(output.dtype)
return output
@@ -689,7 +691,7 @@ class MMDoubleStreamBlock(nn.Module):
del qkv
# Split attention outputs back to separate streams
img_attn, txt_attn = (attn[:, : img_seq_len].contiguous(), attn[:, img_seq_len :].contiguous())
img_attn, txt_attn = (attn[:, :img_seq_len].contiguous(), attn[:, img_seq_len:].contiguous())
del attn
# Apply attention projection and residual connection for image stream

View File

@@ -1,12 +1,8 @@
# copy from Musubi Tuner
import os
import re
from typing import Dict, List, Optional, Union
import torch
from tqdm import tqdm
from library.device_utils import synchronize_device
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
from library.safetensors_utils import MemoryEfficientSafeOpen
@@ -84,7 +80,7 @@ def load_safetensors_with_lora_and_fp8(
count = int(match.group(3))
state_dict = {}
for i in range(count):
filename = f"{prefix}{i+1:05d}-of-{count:05d}.safetensors"
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(model_file), filename)
if os.path.exists(filepath):
extended_model_files.append(filepath)
@@ -118,7 +114,7 @@ def load_safetensors_with_lora_and_fp8(
logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")
# make hook for LoRA merging
def weight_hook_func(model_weight_key, model_weight):
def weight_hook_func(model_weight_key, model_weight, keep_on_calc_device=False):
nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device
if not model_weight_key.endswith(".weight"):
@@ -176,7 +172,8 @@ def load_safetensors_with_lora_and_fp8(
if alpha_key in lora_weight_keys:
lora_weight_keys.remove(alpha_key)
model_weight = model_weight.to(original_device) # move back to original device
if not keep_on_calc_device and original_device != calc_device:
model_weight = model_weight.to(original_device) # move back to original device
return model_weight
weight_hook = weight_hook_func
@@ -231,19 +228,18 @@ def load_safetensors_with_fp8_optimization_and_hook(
for model_file in model_files:
with MemoryEfficientSafeOpen(model_file) as f:
for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
value = f.get_tensor(key)
if weight_hook is not None:
value = weight_hook(key, value)
if move_to_device:
if dit_weight_dtype is None:
value = value.to(calc_device, non_blocking=True)
else:
if weight_hook is None and move_to_device:
value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)
else:
value = f.get_tensor(key)  # load without a target device: direct-to-device get_tensor is non-blocking, which is unsafe when a weight_hook will use the tensor immediately
if weight_hook is not None:
value = weight_hook(key, value, keep_on_calc_device=move_to_device)
if move_to_device:
value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True)
elif dit_weight_dtype is not None:
value = value.to(dit_weight_dtype)
elif dit_weight_dtype is not None:
value = value.to(dit_weight_dtype)
state_dict[key] = value
if move_to_device:
synchronize_device(calc_device)
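With the new keep_on_calc_device argument, any custom weight_hook passed to these loaders needs the extra keyword; a minimal compatible hook (illustrative, name is made up):

def my_weight_hook(key, weight, keep_on_calc_device=False):
    # inspect or transform the tensor here (e.g. merge a LoRA delta into a ".weight" tensor);
    # when keep_on_calc_device is True, the caller wants the tensor left on the calc device
    # and will handle any further placement and dtype casting itself
    return weight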