[Compressor][NVFP4] Support FP4 Compression #311

Merged: 4 commits, May 9, 2025

1 change: 1 addition & 0 deletions src/compressed_tensors/compressors/quantized_compressors/__init__.py
@@ -15,4 +15,5 @@

from .base import *
from .naive_quantized import *
from .nvfp4_quantized import *
from .pack_quantized import *
190 changes: 190 additions & 0 deletions src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
@@ -0,0 +1,190 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict, Optional, Tuple

import numpy
import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.compressors.quantized_compressors.base import (
BaseQuantizationCompressor,
)
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
from torch import Tensor


__all__ = ["pack_fp4_to_uint8", "unpack_fp4_from_uint8"]

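# The eight non-negative magnitudes representable in E2M1 (1 sign bit,
# 2 exponent bits, 1 mantissa bit); in the packed 4-bit code, bits 0-2
# index into this table and bit 3 carries the sign.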
FLOAT_TO_E2M1 = [
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
]


@BaseCompressor.register(name=CompressionFormat.nvfp4_pack_quantized.value)
class NVFP4PackedCompressor(BaseQuantizationCompressor):
"""
Implements compression of FP4 values. Weights of each quantized layer
are packed into uint8. Only supports symmetric weight compression for now.
"""

@property
def compression_param_names(self) -> Tuple[str]:
"""
Returns a tuple of compression parameter names introduced by
the compressor during compression
"""
return (
"weight_packed",
"weight_scale",
"weight_zero_point",
"weight_global_scale",
)

def compress_weight(
self,
weight: Tensor,
scale: Tensor,
global_scale: Tensor,
quantization_args: QuantizationArgs,
device: Optional[torch.device] = None,
zero_point: Optional[torch.Tensor] = None,
g_idx: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
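        """
        Quantize the dense weight to the FP4 grid using the provided scale
        and global scale, then pack pairs of FP4 values into uint8.

        :return: dictionary containing the packed weight under "weight_packed"
        """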

quantized_weight = quantize(
x=weight,
scale=scale,
global_scale=global_scale,
zero_point=zero_point,
args=quantization_args,
)
compressed_dict = {}
weight_packed = pack_fp4_to_uint8(quantized_weight)
if device is not None:
weight_packed = weight_packed.to(device)
compressed_dict["weight_packed"] = weight_packed
return compressed_dict

def decompress_weight(
self,
compressed_data: Dict[str, Tensor],
quantization_args: Optional[QuantizationArgs] = None,
) -> torch.Tensor:
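        """
        Unpack the uint8-packed weight back into FP4 values and dequantize
        them using the stored scale and global scale.

        :return: decompressed dense weight tensor
        """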

weight = compressed_data["weight_packed"]
scale = compressed_data["weight_scale"]
global_scale = compressed_data["weight_global_scale"]
m, n = weight.shape
# TODO: use a user provided dequant dtype
unpacked = unpack_fp4_from_uint8(weight, m, n * 2)
decompressed_weight = dequantize(
x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype
)

return decompressed_weight


def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
"""
Packs a tensor with values in the fp4 range into uint8.
As there are 16 valid fp4 values, two fp4 values can be
packed into one uint8. Each fp4 value is mapped to its
particular index (e.g. 0.5 is mapped to index 1, 6.0 is mapped
to index 7) which is then represented using 4 bits. Consecutive
pairs of 4 bits are then packed into an uint8.

:param x: tensor to pack
returns: a packed tensor in uint8
"""

m, n = x.shape
device = x.device

# Create lookup table for FP4 values to indices
# Map the absolute values to 0-7 indices
kE2M1 = torch.tensor(FLOAT_TO_E2M1, device=device, dtype=x.dtype)

# Find closest valid FP4 value index for each element
abs_x = torch.abs(x)
abs_indices = torch.zeros_like(abs_x, dtype=torch.long)
for i, val in enumerate(kE2M1):
abs_indices = torch.where(torch.isclose(abs_x, val), i, abs_indices)

# Apply sign bit (bit 3) to get final 4-bit representation
indices = abs_indices + (torch.signbit(x) << 3).to(torch.long)

# Reshape to prepare for packing pairs of values
indices = indices.reshape(-1)

# Handle odd length by padding if necessary
if indices.numel() % 2 != 0:
indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)])

# Reshape to pair consecutive elements
indices = indices.reshape(-1, 2)

# Pack pairs of 4-bit values into 8-bit values
packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8)

return packed.reshape(m, n // 2)


kE2M1ToFloat = torch.tensor(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)

# reference: https://github.com/vllm-project/vllm/pull/16362
def unpack_fp4_from_uint8(
a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
) -> torch.Tensor:
"""
Unpacks uint8 values into fp4. Each uint8 consists of two fp4 values
(i.e. first four bits correspond to one fp4 value, last four corresond to a consecutive
fp4 value). The bits represent an index, which are mapped to an fp4 value.

:param a: tensor to unpack
:param m: original dim 0 size of the unpacked tensor
:param n: original dim 1 size of the unpacked tensor
:param dtype: dense dtype to cast the unpacked tensor to
"""
assert a.dtype == torch.uint8

# Vectorized nibble processing
a_flat = a.flatten()
high = (a_flat & 0xF0) >> 4 # Upper nibbles
low = a_flat & 0x0F # Lower nibbles

# Combine nibbles for batch processing
combined = torch.stack((low, high), dim=1).flatten()

# Vectorized sign and magnitude extraction
signs = (combined & 0x08).to(torch.bool) # Sign bits
abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices

# Device-aware lookup and sign application
kE2M1 = kE2M1ToFloat.to(device=a.device)
values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)

# Reshape to final form
return values.reshape(m, n).to(dtype=dtype)
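
To make the nibble layout concrete, the following sketch (not part of the diff) encodes one pair of values by hand and cross-checks the result against the helpers above. The standalone encode_fp4 helper is purely illustrative; note that pack_fp4_to_uint8 itself derives the sign from torch.signbit, so -0.0 also sets the sign bit.

import torch
from compressed_tensors.compressors.quantized_compressors.nvfp4_quantized import (
    pack_fp4_to_uint8,
    unpack_fp4_from_uint8,
)

FLOAT_TO_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def encode_fp4(value: float) -> int:
    # bit 3 = sign, bits 0-2 = index of the magnitude in the E2M1 table
    index = FLOAT_TO_E2M1.index(abs(value))
    return index | (0b1000 if value < 0 else 0)

# Pack the pair (-0.5, 6.0): the first value occupies the low nibble,
# the second the high nibble.
low, high = encode_fp4(-0.5), encode_fp4(6.0)  # 0b1001 and 0b0111
byte = low | (high << 4)                       # 0x79 == 121

# Cross-check against the helpers, laying the pair out as a 1x2 tensor.
x = torch.tensor([[-0.5, 6.0]], dtype=torch.bfloat16)
packed = pack_fp4_to_uint8(x)
assert packed.shape == (1, 1) and packed.item() == byte
assert torch.equal(unpack_fp4_from_uint8(packed, 1, 2, dtype=torch.bfloat16), x)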
1 change: 1 addition & 0 deletions src/compressed_tensors/config/base.py
@@ -32,6 +32,7 @@ class CompressionFormat(Enum):
naive_quantized = "naive-quantized"
pack_quantized = "pack-quantized"
marlin_24 = "marlin-24"
nvfp4_pack_quantized = "nvfp4-pack-quantized"


@unique
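
For reference, the new enum value is the same string used to register NVFP4PackedCompressor above, so the serialized format name lines up with the compressor registration. A quick sanity check (not from the PR):

from compressed_tensors.config import CompressionFormat

assert CompressionFormat.nvfp4_pack_quantized.value == "nvfp4-pack-quantized"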
43 changes: 43 additions & 0 deletions tests/test_compressors/quantized_compressors/test_nvfp4_quant.py
@@ -0,0 +1,43 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from compressed_tensors.compressors.quantized_compressors.nvfp4_quantized import (
pack_fp4_to_uint8,
unpack_fp4_from_uint8,
)


def test_pack_unpack():
x = torch.Tensor(
[
[-0.5000, -6.0000, -0.5000, -1.5000, -1.0000, 6.0000, 0.0000, -0.0000],
[-1.0000, -6.0000, -0.5000, -0.0000, 0.5000, 0.5000, -0.0000, 0.0000],
[-3.0000, -6.0000, -0.5000, -2.0000, -0.5000, -1.5000, -0.0000, -0.0000],
[1.5000, 6.0000, -0.0000, -0.5000, 1.0000, 1.0000, -0.0000, 0.0000],
]
)

dense_dtype = torch.bfloat16
x = x.to(dense_dtype)
m, n = x.shape
packed = pack_fp4_to_uint8(x)
assert packed.dtype == torch.uint8
unpacked = unpack_fp4_from_uint8(packed, m, n, dtype=dense_dtype)
assert unpacked.dtype == dense_dtype

    # torch.equal treats -0.0 and 0.0 as equal, so sign bits are checked separately below
    assert torch.equal(unpacked, x)
sign_bitx = torch.signbit(x)
sign_bitout = torch.signbit(unpacked)
assert torch.equal(sign_bitout, sign_bitx)
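
One detail worth noting: pack_fp4_to_uint8 looks values up with torch.isclose against the E2M1 table, so it assumes its input has already been quantized onto the FP4 grid (as compress_weight guarantees by calling quantize first). Off-grid magnitudes silently fall back to index 0, as this sketch (not part of the test) illustrates:

import torch
from compressed_tensors.compressors.quantized_compressors.nvfp4_quantized import (
    pack_fp4_to_uint8,
    unpack_fp4_from_uint8,
)

on_grid = torch.tensor([[0.5, -6.0]], dtype=torch.bfloat16)
off_grid = torch.tensor([[0.7, -6.0]], dtype=torch.bfloat16)  # 0.7 is not an E2M1 value

# On-grid values round-trip exactly.
assert torch.equal(
    unpack_fp4_from_uint8(pack_fp4_to_uint8(on_grid), 1, 2, dtype=torch.bfloat16),
    on_grid,
)

# 0.7 matches no table entry, so its magnitude index falls back to 0 and it decodes as 0.0.
decoded = unpack_fp4_from_uint8(pack_fp4_to_uint8(off_grid), 1, 2, dtype=torch.bfloat16)
assert decoded[0, 0] == 0.0
assert decoded[0, 1] == -6.0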