From f6b4bdc83fc2c290db4788ac0062f2728fb1e618 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Thu, 18 Sep 2025 21:20:54 +0900 Subject: [PATCH] feat: block-wise fp8 quantization --- library/fp8_optimization_utils.py | 245 ++++++++++++++++++++---------- library/hunyuan_image_models.py | 7 +- library/hunyuan_image_modules.py | 6 +- library/lora_utils.py | 30 ++-- 4 files changed, 186 insertions(+), 102 deletions(-) diff --git a/library/fp8_optimization_utils.py b/library/fp8_optimization_utils.py index ed7d3f76..82ec6bfc 100644 --- a/library/fp8_optimization_utils.py +++ b/library/fp8_optimization_utils.py @@ -1,5 +1,5 @@ import os -from typing import List, Union +from typing import List, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) def calculate_fp8_maxval(exp_bits=4, mantissa_bits=3, sign_bits=1): """ Calculate the maximum representable value in FP8 format. - Default is E4M3 format (4-bit exponent, 3-bit mantissa, 1-bit sign). + Default is E4M3 format (4-bit exponent, 3-bit mantissa, 1-bit sign). Only supports E4M3 and E5M2 with sign bit. 
Args: exp_bits (int): Number of exponent bits @@ -32,73 +32,73 @@ def calculate_fp8_maxval(exp_bits=4, mantissa_bits=3, sign_bits=1): float: Maximum value representable in FP8 format """ assert exp_bits + mantissa_bits + sign_bits == 8, "Total bits must be 8" - - # Calculate exponent bias - bias = 2 ** (exp_bits - 1) - 1 - - # Calculate maximum mantissa value - mantissa_max = 1.0 - for i in range(mantissa_bits - 1): - mantissa_max += 2 ** -(i + 1) - - # Calculate maximum value - max_value = mantissa_max * (2 ** (2**exp_bits - 1 - bias)) - - return max_value + if exp_bits == 4 and mantissa_bits == 3 and sign_bits == 1: + return torch.finfo(torch.float8_e4m3fn).max + elif exp_bits == 5 and mantissa_bits == 2 and sign_bits == 1: + return torch.finfo(torch.float8_e5m2).max + else: + raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits} with sign_bits={sign_bits}") -def quantize_tensor_to_fp8(tensor, scale, exp_bits=4, mantissa_bits=3, sign_bits=1, max_value=None, min_value=None): +# The following is a manual calculation method (wrong implementation for E5M2), kept for reference. +""" +# Calculate exponent bias +bias = 2 ** (exp_bits - 1) - 1 + +# Calculate maximum mantissa value +mantissa_max = 1.0 +for i in range(mantissa_bits - 1): + mantissa_max += 2 ** -(i + 1) + +# Calculate maximum value +max_value = mantissa_max * (2 ** (2**exp_bits - 1 - bias)) + +return max_value +""" + + +def quantize_fp8(tensor, scale, fp8_dtype, max_value, min_value): """ - Quantize a tensor to FP8 format. + Quantize a tensor to FP8 format using PyTorch's native FP8 dtype support. 
Args: tensor (torch.Tensor): Tensor to quantize scale (float or torch.Tensor): Scale factor - exp_bits (int): Number of exponent bits - mantissa_bits (int): Number of mantissa bits - sign_bits (int): Number of sign bits + fp8_dtype (torch.dtype): Target FP8 dtype (torch.float8_e4m3fn or torch.float8_e5m2) + max_value (float): Maximum representable value in FP8 + min_value (float): Minimum representable value in FP8 Returns: - tuple: (quantized_tensor, scale_factor) + torch.Tensor: Quantized tensor in FP8 format """ + tensor = tensor.to(torch.float32) # ensure tensor is in float32 for division + # Create scaled tensor - scaled_tensor = tensor / scale - - # Calculate FP8 parameters - bias = 2 ** (exp_bits - 1) - 1 - - if max_value is None: - # Calculate max and min values - max_value = calculate_fp8_maxval(exp_bits, mantissa_bits, sign_bits) - min_value = -max_value if sign_bits > 0 else 0.0 + tensor = torch.div(tensor, scale).nan_to_num_(0.0) # handle NaN values, equivalent to nonzero_mask in previous function # Clamp tensor to range - clamped_tensor = torch.clamp(scaled_tensor, min_value, max_value) + tensor = tensor.clamp_(min=min_value, max=max_value) - # Quantization process - abs_values = torch.abs(clamped_tensor) - nonzero_mask = abs_values > 0 + # Convert to FP8 dtype + tensor = tensor.to(fp8_dtype) - # Calculate log scales (only for non-zero elements) - log_scales = torch.zeros_like(clamped_tensor) - if nonzero_mask.any(): - log_scales[nonzero_mask] = torch.floor(torch.log2(abs_values[nonzero_mask]) + bias).detach() - - # Limit log scales and calculate quantization factor - log_scales = torch.clamp(log_scales, min=1.0) - quant_factor = 2.0 ** (log_scales - mantissa_bits - bias) - - # Quantize and dequantize - quantized = torch.round(clamped_tensor / quant_factor) * quant_factor - - return quantized, scale + return tensor def optimize_state_dict_with_fp8( - state_dict, calc_device, target_layer_keys=None, exclude_layer_keys=None, exp_bits=4, mantissa_bits=3, 
move_to_device=False + state_dict: dict, + calc_device: Union[str, torch.device], + target_layer_keys: Optional[list[str]] = None, + exclude_layer_keys: Optional[list[str]] = None, + exp_bits: int = 4, + mantissa_bits: int = 3, + move_to_device: bool = False, + quantization_mode: str = "block", + block_size: Optional[int] = 64, ): """ - Optimize Linear layer weights in a model's state dict to FP8 format. + Optimize Linear layer weights in a model's state dict to FP8 format. The state dict is modified in-place. + This function is a static version of load_safetensors_with_fp8_optimization without loading from files. Args: state_dict (dict): State dict to optimize, replaced in-place @@ -149,23 +149,17 @@ def optimize_state_dict_with_fp8( if calc_device is not None: value = value.to(calc_device) - # Calculate scale factor - scale = torch.max(torch.abs(value.flatten())) / max_value - # print(f"Optimizing {key} with scale: {scale}") - - # Quantize weight to FP8 - quantized_weight, _ = quantize_tensor_to_fp8(value, scale, exp_bits, mantissa_bits, 1, max_value, min_value) + quantized_weight, scale_tensor = quantize_weight(key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size) # Add to state dict using original key for weight and new key for scale fp8_key = key # Maintain original key scale_key = key.replace(".weight", ".scale_weight") - quantized_weight = quantized_weight.to(fp8_dtype) - if not move_to_device: quantized_weight = quantized_weight.to(original_device) - scale_tensor = torch.tensor([scale], dtype=original_dtype, device=quantized_weight.device) + # keep scale shape: [1] or [out,1] or [out, num_blocks, 1]. We can determine the quantization mode from the shape of scale_weight in the patched model. 
+    scale_tensor = scale_tensor.to(dtype=original_dtype, device=quantized_weight.device) state_dict[fp8_key] = quantized_weight state_dict[scale_key] = scale_tensor @@ -180,6 +174,70 @@ def optimize_state_dict_with_fp8( return state_dict +def quantize_weight( +    key: str, +    tensor: torch.Tensor, +    fp8_dtype: torch.dtype, +    max_value: float, +    min_value: float, +    quantization_mode: str = "block", +    block_size: int = 64, +): +    original_shape = tensor.shape + +    # Determine quantization mode +    if quantization_mode == "block": +        if tensor.ndim != 2: +            quantization_mode = "tensor"  # fallback to per-tensor +        else: +            out_features, in_features = tensor.shape +            if in_features % block_size != 0: +                quantization_mode = "channel"  # fallback to per-channel +                logger.warning( +                    f"Layer {key} with shape {tensor.shape} is not divisible by block_size {block_size}, fallback to per-channel quantization." +                ) +            else: +                num_blocks = in_features // block_size +                tensor = tensor.contiguous().view(out_features, num_blocks, block_size)  # [out, num_blocks, block_size] +    elif quantization_mode == "channel": +        if tensor.ndim != 2: +            quantization_mode = "tensor"  # fallback to per-tensor + +    # Calculate scale factor (per-tensor or per-output-channel with percentile or max) +    # value shape is expected to be [out_features, in_features] for Linear weights +    if quantization_mode == "channel" or quantization_mode == "block": +        # row-wise max along the quantization axis (dim 1 for channel, dim 2 for block) +        # result shape: [out_features, 1] or [out_features, num_blocks, 1] +        scale_dim = 1 if quantization_mode == "channel" else 2 +        abs_w = torch.abs(tensor) + +        # shape: [out_features, 1] or [out_features, num_blocks, 1] +        row_max = torch.max(abs_w, dim=scale_dim, keepdim=True).values +        scale = row_max / max_value + +    else: +        # per-tensor +        tensor_max = torch.max(torch.abs(tensor).view(-1)) +        scale = tensor_max / max_value + +    # NOTE(review): removed a leftover per-tensor recomputation of `scale` here; +    # it clobbered the per-channel / per-block scale computed above +    # 
print(f"Optimizing {key} with scale: {scale}") + + # numerical safety + scale = torch.clamp(scale, min=1e-8) + scale = scale.to(torch.float32) # ensure scale is in float32 for division + + # Quantize weight to FP8 (scale can be scalar or [out,1], broadcasting works) + quantized_weight = quantize_fp8(tensor, scale, fp8_dtype, max_value, min_value) + + # If block-wise, restore original shape + if quantization_mode == "block": + quantized_weight = quantized_weight.view(original_shape) # restore to original shape [out, in] + + return quantized_weight, scale + + def load_safetensors_with_fp8_optimization( model_files: List[str], calc_device: Union[str, torch.device], @@ -189,7 +247,9 @@ def load_safetensors_with_fp8_optimization( mantissa_bits=3, move_to_device=False, weight_hook=None, -): + quantization_mode: str = "block", + block_size: Optional[int] = 64, +) -> dict: """ Load weight tensors from safetensors files and merge LoRA weights into the state dict with explicit FP8 optimization. @@ -202,6 +262,8 @@ def load_safetensors_with_fp8_optimization( mantissa_bits (int): Number of mantissa bits move_to_device (bool): Move optimized tensors to the calculating device weight_hook (callable, optional): Function to apply to each weight tensor before optimization + quantization_mode (str): Quantization mode, "tensor", "channel", or "block" + block_size (int, optional): Block size for block-wise quantization (used if quantization_mode is "block") Returns: dict: FP8 optimized state dict @@ -234,40 +296,39 @@ def load_safetensors_with_fp8_optimization( keys = f.keys() for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"): value = f.get_tensor(key) + + # Save original device + original_device = value.device # usually cpu + if weight_hook is not None: # Apply weight hook if provided - value = weight_hook(key, value) + value = weight_hook(key, value, keep_on_calc_device=(calc_device is not None)) if not is_target_key(key): + target_device = calc_device 
if (calc_device is not None and move_to_device) else original_device + value = value.to(target_device) state_dict[key] = value continue - # Save original device and dtype - original_device = value.device - original_dtype = value.dtype - # Move to calculation device if calc_device is not None: value = value.to(calc_device) - # Calculate scale factor - scale = torch.max(torch.abs(value.flatten())) / max_value - # print(f"Optimizing {key} with scale: {scale}") - - # Quantize weight to FP8 - quantized_weight, _ = quantize_tensor_to_fp8(value, scale, exp_bits, mantissa_bits, 1, max_value, min_value) + original_dtype = value.dtype + quantized_weight, scale_tensor = quantize_weight( + key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size + ) # Add to state dict using original key for weight and new key for scale fp8_key = key # Maintain original key scale_key = key.replace(".weight", ".scale_weight") assert fp8_key != scale_key, "FP8 key and scale key must be different" - quantized_weight = quantized_weight.to(fp8_dtype) - if not move_to_device: quantized_weight = quantized_weight.to(original_device) - scale_tensor = torch.tensor([scale], dtype=original_dtype, device=quantized_weight.device) + # keep scale shape: [1] or [out,1] or [out, num_blocks, 1]. We can determine the quantization mode from the shape of scale_weight in the patched model. 
+ scale_tensor = scale_tensor.to(dtype=original_dtype, device=quantized_weight.device) state_dict[fp8_key] = quantized_weight state_dict[scale_key] = scale_tensor @@ -296,12 +357,15 @@ def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value= torch.Tensor: Result of linear transformation """ if use_scaled_mm: + # **not tested** + # _scaled_mm only works for per-tensor scale for now (per-channel scale does not work in certain cases) + if self.scale_weight.ndim != 1: + raise ValueError("scaled_mm only supports per-tensor scale_weight for now.") + input_dtype = x.dtype original_weight_dtype = self.scale_weight.dtype - weight_dtype = self.weight.dtype - target_dtype = torch.float8_e5m2 - assert weight_dtype == torch.float8_e4m3fn, "Only FP8 E4M3FN format is supported" - assert x.ndim == 3, "Input tensor must be 3D (batch_size, seq_len, hidden_dim)" + target_dtype = self.weight.dtype + # assert x.ndim == 3, "Input tensor must be 3D (batch_size, seq_len, hidden_dim)" if max_value is None: # no input quantization @@ -311,10 +375,12 @@ def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value= scale_x = (torch.max(torch.abs(x.flatten())) / max_value).to(torch.float32) # quantize input tensor to FP8: this seems to consume a lot of memory - x, _ = quantize_tensor_to_fp8(x, scale_x, 5, 2, 1, max_value, -max_value) + fp8_max_value = torch.finfo(target_dtype).max + fp8_min_value = torch.finfo(target_dtype).min + x = quantize_fp8(x, scale_x, target_dtype, fp8_max_value, fp8_min_value) original_shape = x.shape - x = x.reshape(-1, x.shape[2]).to(target_dtype) + x = x.reshape(-1, x.shape[-1]).to(target_dtype) weight = self.weight.t() scale_weight = self.scale_weight.to(torch.float32) @@ -325,12 +391,21 @@ def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value= else: o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight) - return o.reshape(original_shape[0], 
original_shape[1], -1).to(input_dtype) +        o = o.reshape(original_shape[0], original_shape[1], -1) if len(original_shape) == 3 else o.reshape(original_shape[0], -1) +        return o.to(input_dtype) else: # Dequantize the weight original_dtype = self.scale_weight.dtype -        dequantized_weight = self.weight.to(original_dtype) * self.scale_weight +        if self.scale_weight.ndim < 3: +            # per-tensor or per-channel quantization, we can broadcast +            dequantized_weight = self.weight.to(original_dtype) * self.scale_weight +        else: +            # block-wise quantization, need to reshape weight to match scale shape for broadcasting +            out_features, num_blocks, _ = self.scale_weight.shape +            dequantized_weight = self.weight.to(original_dtype).contiguous().view(out_features, num_blocks, -1) +            dequantized_weight = dequantized_weight * self.scale_weight +            dequantized_weight = dequantized_weight.view(self.weight.shape) # Perform linear transformation if self.bias is not None: @@ -362,11 +437,15 @@ def apply_fp8_monkey_patch(model, optimized_state_dict, use_scaled_mm=False): # Enumerate patched layers patched_module_paths = set() +    scale_shape_info = {} for scale_key in scale_keys: # Extract module path from scale key (remove .scale_weight) module_path = scale_key.rsplit(".scale_weight", 1)[0] patched_module_paths.add(module_path) +        # Store scale shape information +        scale_shape_info[module_path] = optimized_state_dict[scale_key].shape + patched_count = 0 # Apply monkey patch to each layer with FP8 weights @@ -377,7 +456,9 @@ # Apply patch if it's a Linear layer with FP8 scale if isinstance(module, nn.Linear) and has_scale: # register the scale_weight as a buffer to load the state_dict -            module.register_buffer("scale_weight", torch.tensor(1.0, dtype=module.weight.dtype)) +            # module.register_buffer("scale_weight", torch.tensor(1.0, dtype=module.weight.dtype)) +            scale_shape = scale_shape_info[name] +            module.register_buffer("scale_weight", 
torch.ones(scale_shape, dtype=module.weight.dtype)) # Create a new forward method with the patched version. def new_forward(self, x): diff --git a/library/hunyuan_image_models.py b/library/hunyuan_image_models.py index 2a6092ea..356ce4b4 100644 --- a/library/hunyuan_image_models.py +++ b/library/hunyuan_image_models.py @@ -30,7 +30,12 @@ from library.hunyuan_image_modules import ( from library.hunyuan_image_utils import get_nd_rotary_pos_embed FP8_OPTIMIZATION_TARGET_KEYS = ["double_blocks", "single_blocks"] -FP8_OPTIMIZATION_EXCLUDE_KEYS = ["norm", "_mod", "modulation", "_emb"] +# FP8_OPTIMIZATION_EXCLUDE_KEYS = ["norm", "_mod", "_emb"] # , "modulation" +FP8_OPTIMIZATION_EXCLUDE_KEYS = ["norm", "_emb"] # , "modulation", "_mod" + +# full exclude 24.2GB +# norm and _emb 19.7GB +# fp8 cast 19.7GB # region DiT Model diff --git a/library/hunyuan_image_modules.py b/library/hunyuan_image_modules.py index ef4d5e5d..555cb487 100644 --- a/library/hunyuan_image_modules.py +++ b/library/hunyuan_image_modules.py @@ -497,7 +497,9 @@ class RMSNorm(nn.Module): """ output = self._norm(x.float()).type_as(x) del x - output = output * self.weight + # output = output * self.weight + # fp8 support + output = output * self.weight.to(output.dtype) return output @@ -689,7 +691,7 @@ class MMDoubleStreamBlock(nn.Module): del qkv # Split attention outputs back to separate streams - img_attn, txt_attn = (attn[:, : img_seq_len].contiguous(), attn[:, img_seq_len :].contiguous()) + img_attn, txt_attn = (attn[:, :img_seq_len].contiguous(), attn[:, img_seq_len:].contiguous()) del attn # Apply attention projection and residual connection for image stream diff --git a/library/lora_utils.py b/library/lora_utils.py index b93eb9af..6f0fc228 100644 --- a/library/lora_utils.py +++ b/library/lora_utils.py @@ -1,12 +1,8 @@ -# copy from Musubi Tuner - import os import re from typing import Dict, List, Optional, Union import torch - from tqdm import tqdm - from library.device_utils import synchronize_device 
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization from library.safetensors_utils import MemoryEfficientSafeOpen @@ -84,7 +80,7 @@ def load_safetensors_with_lora_and_fp8( count = int(match.group(3)) state_dict = {} for i in range(count): - filename = f"{prefix}{i+1:05d}-of-{count:05d}.safetensors" + filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors" filepath = os.path.join(os.path.dirname(model_file), filename) if os.path.exists(filepath): extended_model_files.append(filepath) @@ -118,7 +114,7 @@ def load_safetensors_with_lora_and_fp8( logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}") # make hook for LoRA merging - def weight_hook_func(model_weight_key, model_weight): + def weight_hook_func(model_weight_key, model_weight, keep_on_calc_device=False): nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device if not model_weight_key.endswith(".weight"): @@ -176,7 +172,8 @@ def load_safetensors_with_lora_and_fp8( if alpha_key in lora_weight_keys: lora_weight_keys.remove(alpha_key) - model_weight = model_weight.to(original_device) # move back to original device + if not keep_on_calc_device and original_device != calc_device: + model_weight = model_weight.to(original_device) # move back to original device return model_weight weight_hook = weight_hook_func @@ -231,19 +228,18 @@ def load_safetensors_with_fp8_optimization_and_hook( for model_file in model_files: with MemoryEfficientSafeOpen(model_file) as f: for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False): - value = f.get_tensor(key) - if weight_hook is not None: - value = weight_hook(key, value) - if move_to_device: - if dit_weight_dtype is None: - value = value.to(calc_device, non_blocking=True) - else: + if weight_hook is None and move_to_device: + value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype) + else: + value = f.get_tensor(key) # we cannot directly load to 
device because get_tensor does non-blocking transfer + if weight_hook is not None: + value = weight_hook(key, value, keep_on_calc_device=move_to_device) + if move_to_device: value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True) - elif dit_weight_dtype is not None: - value = value.to(dit_weight_dtype) + elif dit_weight_dtype is not None: + value = value.to(dit_weight_dtype) state_dict[key] = value - if move_to_device: synchronize_device(calc_device)