Commit 5c6fd5d

[Compressor][NVFP4] Support FP4 Compression (#311)

* add nvfp4 compressor
* add docstring
* update docstring

1 parent: 4759a86
4 files changed (+235, -0)

src/compressed_tensors/compressors/quantized_compressors/__init__.py (+1)

@@ -15,4 +15,5 @@
 from .base import *
 from .naive_quantized import *
+from .nvfp4_quantized import *
 from .pack_quantized import *
src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py (new file, +190)

@@ -0,0 +1,190 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict, Optional, Tuple

import numpy
import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.compressors.quantized_compressors.base import (
    BaseQuantizationCompressor,
)
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
from torch import Tensor


__all__ = ["pack_fp4_to_uint8", "unpack_fp4_from_uint8"]

FLOAT_TO_E2M1 = [
    0.0,
    0.5,
    1.0,
    1.5,
    2.0,
    3.0,
    4.0,
    6.0,
]


@BaseCompressor.register(name=CompressionFormat.nvfp4_pack_quantized.value)
class NVFP4PackedCompressor(BaseQuantizationCompressor):
    """
    Implements compression of FP4 values. Weights of each quantized layer
    are packed into uint8. Only supports symmetric weight compression for now.
    """

    @property
    def compression_param_names(self) -> Tuple[str]:
        """
        Returns a tuple of compression parameter names introduced by
        the compressor during compression
        """
        return (
            "weight_packed",
            "weight_scale",
            "weight_zero_point",
            "weight_global_scale",
        )

    def compress_weight(
        self,
        weight: Tensor,
        scale: Tensor,
        global_scale: Tensor,
        quantization_args: QuantizationArgs,
        device: Optional[torch.device] = None,
        zero_point: Optional[torch.Tensor] = None,
        g_idx: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:

        quantized_weight = quantize(
            x=weight,
            scale=scale,
            global_scale=global_scale,
            zero_point=zero_point,
            args=quantization_args,
        )
        compressed_dict = {}
        weight_packed = pack_fp4_to_uint8(quantized_weight)
        if device is not None:
            weight_packed = weight_packed.to(device)
        compressed_dict["weight_packed"] = weight_packed
        return compressed_dict

    def decompress_weight(
        self,
        compressed_data: Dict[str, Tensor],
        quantization_args: Optional[QuantizationArgs] = None,
    ) -> torch.Tensor:

        weight = compressed_data["weight_packed"]
        scale = compressed_data["weight_scale"]
        global_scale = compressed_data["weight_global_scale"]
        m, n = weight.shape
        # TODO: use a user provided dequant dtype
        unpacked = unpack_fp4_from_uint8(weight, m, n * 2)
        decompressed_weight = dequantize(
            x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype
        )

        return decompressed_weight


def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
    """
    Packs a tensor with values in the fp4 range into uint8.
    As there are 16 valid fp4 values, two fp4 values can be
    packed into one uint8. Each fp4 value is mapped to its
    particular index (e.g. 0.5 is mapped to index 1, 6.0 is mapped
    to index 7) which is then represented using 4 bits. Consecutive
    pairs of 4 bits are then packed into a uint8.

    :param x: tensor to pack
    :return: a packed tensor in uint8
    """

    m, n = x.shape
    device = x.device

    # Create lookup table for FP4 values to indices
    # Map the absolute values to 0-7 indices
    kE2M1 = torch.tensor(FLOAT_TO_E2M1, device=device, dtype=x.dtype)

    # Find closest valid FP4 value index for each element
    abs_x = torch.abs(x)
    abs_indices = torch.zeros_like(abs_x, dtype=torch.long)
    for i, val in enumerate(kE2M1):
        abs_indices = torch.where(torch.isclose(abs_x, val), i, abs_indices)

    # Apply sign bit (bit 3) to get final 4-bit representation
    indices = abs_indices + (torch.signbit(x) << 3).to(torch.long)

    # Reshape to prepare for packing pairs of values
    indices = indices.reshape(-1)

    # Handle odd length by padding if necessary
    if indices.numel() % 2 != 0:
        indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)])

    # Reshape to pair consecutive elements
    indices = indices.reshape(-1, 2)

    # Pack pairs of 4-bit values into 8-bit values
    packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8)

    return packed.reshape(m, n // 2)


kE2M1ToFloat = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)


# reference: https://github.com/vllm-project/vllm/pull/16362
def unpack_fp4_from_uint8(
    a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
) -> torch.Tensor:
    """
    Unpacks uint8 values into fp4. Each uint8 consists of two fp4 values
    (i.e. the first four bits correspond to one fp4 value, the last four
    correspond to a consecutive fp4 value). The bits represent an index,
    which is mapped to an fp4 value.

    :param a: tensor to unpack
    :param m: original dim 0 size of the unpacked tensor
    :param n: original dim 1 size of the unpacked tensor
    :param dtype: dense dtype to cast the unpacked tensor to
    """
    assert a.dtype == torch.uint8

    # Vectorized nibble processing
    a_flat = a.flatten()
    high = (a_flat & 0xF0) >> 4  # Upper nibbles
    low = a_flat & 0x0F  # Lower nibbles

    # Combine nibbles for batch processing
    combined = torch.stack((low, high), dim=1).flatten()

    # Vectorized sign and magnitude extraction
    signs = (combined & 0x08).to(torch.bool)  # Sign bits
    abs_vals = (combined & 0x07).to(torch.long)  # Magnitude indices

    # Device-aware lookup and sign application
    kE2M1 = kE2M1ToFloat.to(device=a.device)
    values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)

    # Reshape to final form
    return values.reshape(m, n).to(dtype=dtype)
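
As a sanity check on the bit layout (not part of the commit): E2M1 has 1 sign
bit, 2 exponent bits, and 1 mantissa bit, which yields exactly the eight
magnitudes in FLOAT_TO_E2M1; the sign bit doubles these to 16 codes, so two
codes fit in one uint8. A minimal sketch tracing a single pair of values
through the helpers above:

import torch
from compressed_tensors.compressors.quantized_compressors.nvfp4_quantized import (
    pack_fp4_to_uint8,
    unpack_fp4_from_uint8,
)

#  0.5 -> magnitude index 1, sign clear -> nibble 0b0001
# -6.0 -> magnitude index 7, sign set   -> nibble 0b1111
# The first value fills the low nibble, the second the high nibble:
# 0b0001 | (0b1111 << 4) = 0xF1
x = torch.tensor([[0.5, -6.0]], dtype=torch.bfloat16)
packed = pack_fp4_to_uint8(x)
assert packed.shape == (1, 1) and packed.item() == 0xF1

# Unpacking restores the pair; note n is the *unpacked* column count.
assert torch.equal(unpack_fp4_from_uint8(packed, 1, 2, dtype=torch.bfloat16), x)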

src/compressed_tensors/config/base.py (+1)
@@ -32,6 +32,7 @@ class CompressionFormat(Enum):
     naive_quantized = "naive-quantized"
     pack_quantized = "pack-quantized"
     marlin_24 = "marlin-24"
+    nvfp4_pack_quantized = "nvfp4-pack-quantized"


 @unique
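
The new enum value is the same string the register() decorator above passes, so
a serialized format name round-trips back to the compressor class. A small
sketch, assuming the load_from_registry helper that BaseCompressor inherits
from RegistryMixin in this repo:

from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.config import CompressionFormat

# "nvfp4-pack-quantized" stored in a model's config resolves to the
# NVFP4PackedCompressor registered under that name
compressor = BaseCompressor.load_from_registry(
    CompressionFormat.nvfp4_pack_quantized.value
)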
tests (new file, +43)

@@ -0,0 +1,43 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from compressed_tensors.compressors.quantized_compressors.nvfp4_quantized import (
    pack_fp4_to_uint8,
    unpack_fp4_from_uint8,
)


def test_pack_unpack():
    x = torch.Tensor(
        [
            [-0.5000, -6.0000, -0.5000, -1.5000, -1.0000, 6.0000, 0.0000, -0.0000],
            [-1.0000, -6.0000, -0.5000, -0.0000, 0.5000, 0.5000, -0.0000, 0.0000],
            [-3.0000, -6.0000, -0.5000, -2.0000, -0.5000, -1.5000, -0.0000, -0.0000],
            [1.5000, 6.0000, -0.0000, -0.5000, 1.0000, 1.0000, -0.0000, 0.0000],
        ]
    )

    dense_dtype = torch.bfloat16
    x = x.to(dense_dtype)
    m, n = x.shape
    packed = pack_fp4_to_uint8(x)
    assert packed.dtype == torch.uint8
    unpacked = unpack_fp4_from_uint8(packed, m, n, dtype=dense_dtype)
    assert unpacked.dtype == dense_dtype

    assert torch.equal(unpacked, x)  # misleading as -0 and 0 are considered equal
    sign_bitx = torch.signbit(x)
    sign_bitout = torch.signbit(unpacked)
    assert torch.equal(sign_bitout, sign_bitx)
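
The trailing signbit assertions exist because torch.equal follows IEEE-754
comparison, under which -0.0 compares equal to 0.0; the first assertion alone
could not detect a lost sign of zero. A quick illustration (not part of the
commit):

import torch

zeros = torch.tensor([0.0])
neg_zeros = torch.tensor([-0.0])
assert torch.equal(zeros, neg_zeros)  # IEEE-754: -0.0 == 0.0
assert not torch.equal(torch.signbit(zeros), torch.signbit(neg_zeros))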
