fix: block-wise scaling is overwritten by per-tensor scaling

commit 806d535ef1
parent 3876343fad
Author: Kohya S
Date:   2025-09-21 13:10:41 +09:00


@@ -220,10 +220,6 @@ def quantize_weight(
         tensor_max = torch.max(torch.abs(tensor).view(-1))
         scale = tensor_max / max_value
-    # Calculate scale factor
-    scale = torch.max(torch.abs(tensor.flatten())) / max_value
-    # print(f"Optimizing {key} with scale: {scale}")
     # numerical safety
     scale = torch.clamp(scale, min=1e-8)
     scale = scale.to(torch.float32)  # ensure scale is in float32 for division
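
For context, a minimal sketch of the failure mode this commit addresses, assuming quantize_weight is structured roughly along these lines (the signature, the block_size parameter, and all names here are illustrative, not the repository's actual code):

import torch

def quantize_weight_sketch(tensor: torch.Tensor, max_value: float,
                           block_size: int | None = None):
    """Return (scale, quantized) using block-wise or per-tensor scaling.

    Hypothetical sketch; assumes tensor.numel() is divisible by block_size.
    """
    if block_size is not None:
        # Block-wise: one scale per block of block_size elements.
        blocks = tensor.reshape(-1, block_size)
        scale = blocks.abs().amax(dim=1, keepdim=True) / max_value
    else:
        # Per-tensor: a single scale for the whole tensor.
        tensor_max = torch.max(torch.abs(tensor).view(-1))
        scale = tensor_max / max_value
        blocks = tensor.reshape(1, -1)

    # The bug this commit fixes: a leftover unconditional per-tensor
    # computation at this point would clobber scale in both branches, e.g.:
    #   scale = torch.max(torch.abs(tensor.flatten())) / max_value

    # numerical safety, then divide in float32
    scale = torch.clamp(scale, min=1e-8).to(torch.float32)
    quantized = torch.round(blocks / scale).reshape(tensor.shape)
    return scale, quantized

The deleted lines in the diff correspond to exactly such an unconditional per-tensor recomputation; removing them lets the branch-specific scale survive through the clamp and float32 cast.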