diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index 355e074b..8bfbc41f 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -25,6 +25,7 @@ __all__ = [
     "FP8_E4M3_DATA",
     "FP4_E2M1_DATA",
+    "BFLOAT16_DATA",
     "FloatArgs",
     "QuantizationType",
     "QuantizationStrategy",
@@ -38,9 +39,9 @@ class FloatArgs:
     exponent: int
     mantissa: int
-    bits: int
-    max: float
-    min: float
+    bits: Optional[int] = None
+    max: Optional[float] = None
+    min: Optional[float] = None
     dtype: Optional[torch.dtype] = None
@@ -76,6 +77,11 @@ class FP8_E4M3_DATA(FloatArgs):
     dtype = torch.float8_e4m3fn
 
 
+class BFLOAT16_DATA(FloatArgs):
+    exponent = 8
+    mantissa = 7
+
+
 class QuantizationType(str, Enum):
     """
     Enum storing quantization type options
diff --git a/src/compressed_tensors/quantization/utils/__init__.py b/src/compressed_tensors/quantization/utils/__init__.py
index a91f9e5d..0198b374 100644
--- a/src/compressed_tensors/quantization/utils/__init__.py
+++ b/src/compressed_tensors/quantization/utils/__init__.py
@@ -14,3 +14,4 @@
 
 # flake8: noqa
 from .helpers import *
+from .mxfp4_utils import *
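For context, BFLOAT16_DATA describes the bfloat16 bit layout (1 sign bit, 8
exponent bits, 7 mantissa bits). A standalone sketch of how the bit-twiddling
constants in mxfp4_utils.py below fall out of these fields; the names here are
illustrative, not identifiers from the PR:

    # Derive round_to_power_2's mask and rounding offset from the bit counts
    BF16_EXPONENT = 8  # BFLOAT16_DATA.exponent
    BF16_MANTISSA = 7  # BFLOAT16_DATA.mantissa
    FP4_MANTISSA = 1   # FP4_E2M1_DATA.mantissa (E2M1 has 1 mantissa bit)

    # Sign bit plus all exponent bits set, mantissa bits zeroed
    sign_exponent_mask = ((1 << (BF16_EXPONENT + 1)) - 1) << BF16_MANTISSA
    assert sign_exponent_mask == 0b1_11111111_0000000

    # Offset added before masking; a value only rounds up to the next power
    # of 2 when its mantissa fraction is >= 0.75 (96/128)
    round_offset = 1 << (BF16_MANTISSA - FP4_MANTISSA - 1)
    assert round_offset == 32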
diff --git a/src/compressed_tensors/quantization/utils/mxfp4_utils.py b/src/compressed_tensors/quantization/utils/mxfp4_utils.py
new file mode 100644
index 00000000..17821ae7
--- /dev/null
+++ b/src/compressed_tensors/quantization/utils/mxfp4_utils.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.quantization.quant_args import BFLOAT16_DATA, FP4_E2M1_DATA
+
+
+__all__ = ["convert_mxfp4_exp_scale", "generate_mxfp4_scales", "round_to_power_2"]
+
+# Reference: https://github.com/vllm-project/vllm/blob/main/tests/quantization/reference_mxfp4.py  # noqa: E501
+
+
+def convert_mxfp4_exp_scale(
+    scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16
+) -> torch.Tensor:
+    """
+    Convert mxfp4 scales to a dense dtype. The scales are powers of 2 whose
+    biased exponents are stored in uint8; converting to a dense dtype lets
+    them be applied to the weights and activations during QDQ.
+
+    :param scale: uint8 exponent scale
+    :param dtype: dense dtype to convert to
+    :return: dense power-of-2 scales
+    """
+    assert scale.dtype == torch.uint8
+    # Remove the 127 bias, then rebuild the power-of-2 scale
+    scale_exp = scale.to(torch.int32) - 127
+    scale = 2.0 ** (scale_exp.to(torch.float))
+    return scale.to(dtype)
+
+
+def round_to_power_2(x: torch.Tensor) -> torch.Tensor:
+    """
+    Round values to the closest power of 2. This is done by masking the
+    values with BFLOAT16_SIGN_EXPONENT_MASK, which removes the mantissa and
+    keeps the sign and exponent, i.e. the closest power of 2 for the input
+    value.
+
+    E.g.:
+    0.0825 = 1.32 (mantissa) x 2**-4 (exponent)
+    0.0825 ==> -4 (exponent) + 127 = 123 = 01111011 (8 exponent bits in bfloat16)
+    0.0825 ==> 0.32 (mantissa) = 0101001 (7 mantissa bits in bfloat16)
+    0.0825 == 0b0_01111011_0101001 (bfloat16: sign, exponent, mantissa)
+    0b0_01111011_0101001 & 0b1_11111111_0000000 == 0b0_01111011_0000000
+    Keeping the sign bit + exponent gives the closest power of 2, 0.0625
+
+    :param x: tensor to round to the closest power of 2
+    :return: tensor of the same shape with values rounded to powers of 2
+    """
+    assert x.dtype == torch.bfloat16
+    # Reinterpret the bfloat16 bits as integers for bit manipulation
+    x = x.view(torch.uint16).to(torch.int32)
+
+    # Offset added to push values with a large mantissa up to the next exponent
+    BFLOAT16_VAL_TO_ADD = 1 << (BFLOAT16_DATA.mantissa - FP4_E2M1_DATA.mantissa - 1)
+    # Mask that keeps only the sign and exponent bits; since the offset is
+    # below the rounding midpoint, we conservatively round down to better
+    # represent smaller numbers / prevent overflow
+    BFLOAT16_SIGN_EXPONENT_MASK = (
+        (1 << (BFLOAT16_DATA.exponent + 1)) - 1
+    ) << BFLOAT16_DATA.mantissa
+    block_max_uint = torch.bitwise_and(
+        x + BFLOAT16_VAL_TO_ADD, BFLOAT16_SIGN_EXPONENT_MASK
+    )
+    return block_max_uint.to(torch.uint16).view(torch.bfloat16)
+
+
+def generate_mxfp4_scales(x: torch.Tensor) -> torch.Tensor:
+    """
+    Generate mxfp4 scales. The scales require the following steps:
+    1. Round to the closest power of 2
+    2. Convert to a biased exponent
+    3. Store in uint8
+
+    Called when calculating qparams using observers.
+
+    :param x: tensor of block maxima to generate scales for
+    :return: uint8 scales stored as biased exponents
+    """
+    # Round to the closest power of 2
+    scale_power_2 = round_to_power_2(x)
+    # Convert to a biased exponent; the extra -2 divides out fp4's largest
+    # exponent (E2M1 max value 6.0 = 1.5 * 2**2) so the block max maps into
+    # the fp4 representable range
+    scale_exp = 127 + torch.floor(torch.log2(scale_power_2)).to(torch.int32) - 2
+    # Clamp and store in uint8, as expected by mxfp4
+    scale_exp = torch.clamp(
+        scale_exp,
+        max=torch.iinfo(torch.uint8).max,
+        min=torch.iinfo(torch.uint8).min,
+    )
+    return scale_exp.to(torch.uint8)
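A quick round trip of the two helpers above (a usage sketch; the input values
are illustrative):

    import torch
    from compressed_tensors.quantization.utils import (
        convert_mxfp4_exp_scale,
        generate_mxfp4_scales,
    )

    block_max = torch.tensor([0.0825, 1.5, 6.0], dtype=torch.bfloat16)
    scales_u8 = generate_mxfp4_scales(block_max)  # uint8 biased exponents
    scales = convert_mxfp4_exp_scale(scales_u8)   # dense powers of 2
    # 0.0825 rounds down to 2**-4; subtracting fp4's max exponent of 2
    # yields 2**-6. Likewise 1.5 -> 2**-2 and 6.0 -> 2**0.
    assert torch.equal(
        scales, torch.tensor([2**-6, 2**-2, 2**0], dtype=torch.bfloat16)
    )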
diff --git a/tests/test_quantization/test_utils/test_mxfp4_utils.py b/tests/test_quantization/test_utils/test_mxfp4_utils.py
new file mode 100644
index 00000000..723228be
--- /dev/null
+++ b/tests/test_quantization/test_utils/test_mxfp4_utils.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.quantization.utils import (
+    convert_mxfp4_exp_scale,
+    generate_mxfp4_scales,
+    round_to_power_2,
+)
+
+
+def test_round_power_2_noise():
+    powers = torch.Tensor(
+        [
+            [2**-10, 2**-9, 2**-8, 2**-7, 2**-6],
+            [2**-5, 2**-4, 2**-3, 2**-2, 2**-1],
+            [2**0, 2**1, 2**-10, 2**-9, 2**-8],
+            [2**-7, 2**-6, 2**-5, 2**-4, 2**-3],
+            [2**-2, 2**-1, 2**0, 2**1, 2**-10],
+        ]
+    ).to(torch.bfloat16)
+
+    # Multiplicative noise in [0, 0.2) stays below the conservative round-up
+    # threshold, so rounding should recover the original powers of 2
+    noise = torch.rand_like(powers) * 0.2
+    powers_noisy = powers * (1 + noise)
+    rounded = round_to_power_2(powers_noisy)
+    assert torch.equal(rounded, powers)
+
+
+def test_round_power_2():
+    x = torch.Tensor(
+        [
+            [5.687891, -8.291567, -1.540329, -0.315635, 0.965272],
+            [-6.944130, 0.073246, -0.451778, 8.571118, -9.856593],
+            [-0.040571, -0.708509, 2.485657, -4.003352, -0.995600],
+            [0.224199, 5.032586, -1.309816, -0.621958, 7.290238],
+            [-9.848001, -0.290731, 1.501562, 0.379829, -5.312081],
+        ]
+    ).to(torch.bfloat16)
+    x_rounded = torch.Tensor(
+        [
+            [4.000000, -8.000000, -1.000000, -0.250000, 1.000000],
+            [-4.000000, 0.062500, -0.500000, 8.000000, -8.000000],
+            [-0.031250, -0.500000, 2.000000, -4.000000, -1.000000],
+            [0.250000, 4.000000, -1.000000, -0.500000, 8.000000],
+            [-8.000000, -0.250000, 1.000000, 0.250000, -4.000000],
+        ]
+    ).to(torch.bfloat16)
+    rounded = round_to_power_2(x)
+    assert torch.equal(rounded, x_rounded)
+
+
+def test_mxfp4_scales_e2e():
+    mock_weight = torch.normal(mean=0.0002, std=0.0576, size=(2880, 2880))
+
+    # Split the last dim into mxfp4 blocks of 32 and find each block's absmax
+    x = mock_weight.reshape(*mock_weight.shape[:-1], -1, 32).to(torch.bfloat16)
+    min_vals = torch.amin(x, dim=-1)
+    max_vals = torch.amax(x, dim=-1)
+
+    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
+    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+    block_max = torch.max(torch.abs(min_vals), torch.abs(max_vals))
+
+    scales_generated = generate_mxfp4_scales(block_max)
+    converted_ct = convert_mxfp4_exp_scale(scales_generated)
+
+    # The dense scale's exponent should equal the rounded block max's
+    # exponent shifted down by fp4's largest exponent (2)
+    scales_exp = torch.log2(converted_ct)
+    block_max_exp = torch.floor(torch.log2(round_to_power_2(block_max))) - 2
+    assert torch.equal(scales_exp, block_max_exp)
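For context on where the converted scales get applied, a hedged QDQ sketch
(not part of this PR; the nearest-value snap to the E2M1 grid below is a
placeholder, not the library's actual fp4 quantization path):

    import torch
    from compressed_tensors.quantization.utils import (
        convert_mxfp4_exp_scale,
        generate_mxfp4_scales,
    )

    weight = torch.randn(64, 64, dtype=torch.bfloat16)
    blocks = weight.reshape(*weight.shape[:-1], -1, 32)  # mxfp4 blocks of 32
    block_max = blocks.abs().amax(dim=-1)

    scales = convert_mxfp4_exp_scale(generate_mxfp4_scales(block_max))
    scaled = blocks / scales.unsqueeze(-1)  # bring values into fp4's range

    # Snap magnitudes to the E2M1 representable values {0, .5, 1, 1.5, 2, 3, 4, 6}
    fp4_grid = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.bfloat16
    )
    idx = (scaled.abs().unsqueeze(-1) - fp4_grid).abs().argmin(dim=-1)
    dq = (fp4_grid[idx] * scaled.sign() * scales.unsqueeze(-1)).reshape(weight.shape)
    # dq is the dequantized (QDQ) approximation of weight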