|
| 1 | +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, |
| 10 | +# software distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | + |
| 16 | +from typing import Dict, Optional, Tuple |
| 17 | + |
| 18 | +import numpy |
| 19 | +import torch |
| 20 | +from compressed_tensors.compressors.base import BaseCompressor |
| 21 | +from compressed_tensors.compressors.quantized_compressors.base import ( |
| 22 | + BaseQuantizationCompressor, |
| 23 | +) |
| 24 | +from compressed_tensors.config import CompressionFormat |
| 25 | +from compressed_tensors.quantization import QuantizationArgs |
| 26 | +from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize |
| 27 | +from torch import Tensor |
| 28 | + |
| 29 | + |
| 30 | +__all__ = ["pack_fp4_to_uint8", "unpack_fp4_from_uint8"] |
| 31 | + |
| 32 | +FLOAT_TO_E2M1 = [ |
| 33 | + 0.0, |
| 34 | + 0.5, |
| 35 | + 1.0, |
| 36 | + 1.5, |
| 37 | + 2.0, |
| 38 | + 3.0, |
| 39 | + 4.0, |
| 40 | + 6.0, |
| 41 | +] |
| 42 | + |
| 43 | + |
| 44 | +@BaseCompressor.register(name=CompressionFormat.nvfp4_pack_quantized.value) |
| 45 | +class NVFP4PackedCompressor(BaseQuantizationCompressor): |
| 46 | + """ |
| 47 | + Implements compression of FP4 values. Weights of each quantized layer |
| 48 | + are packed into uint8. Only supports symmetric weight compression for now. |
| 49 | + """ |
| 50 | + |
| 51 | + @property |
| 52 | + def compression_param_names(self) -> Tuple[str]: |
| 53 | + """ |
| 54 | + Returns a tuple of compression parameter names introduced by |
| 55 | + the compressor during compression |
| 56 | + """ |
| 57 | + return ( |
| 58 | + "weight_packed", |
| 59 | + "weight_scale", |
| 60 | + "weight_zero_point", |
| 61 | + "weight_global_scale", |
| 62 | + ) |
| 63 | + |
| 64 | + def compress_weight( |
| 65 | + self, |
| 66 | + weight: Tensor, |
| 67 | + scale: Tensor, |
| 68 | + global_scale: Tensor, |
| 69 | + quantization_args: QuantizationArgs, |
| 70 | + device: Optional[torch.device] = None, |
| 71 | + zero_point: Optional[torch.Tensor] = None, |
| 72 | + g_idx: Optional[torch.Tensor] = None, |
| 73 | + ) -> Dict[str, torch.Tensor]: |
| 74 | + |
| 75 | + quantized_weight = quantize( |
| 76 | + x=weight, |
| 77 | + scale=scale, |
| 78 | + global_scale=global_scale, |
| 79 | + zero_point=zero_point, |
| 80 | + args=quantization_args, |
| 81 | + ) |
| 82 | + compressed_dict = {} |
| 83 | + weight_packed = pack_fp4_to_uint8(quantized_weight) |
| 84 | + if device is not None: |
| 85 | + weight_packed = weight_packed.to(device) |
| 86 | + compressed_dict["weight_packed"] = weight_packed |
| 87 | + return compressed_dict |
| 88 | + |
| 89 | + def decompress_weight( |
| 90 | + self, |
| 91 | + compressed_data: Dict[str, Tensor], |
| 92 | + quantization_args: Optional[QuantizationArgs] = None, |
| 93 | + ) -> torch.Tensor: |
| 94 | + |
| 95 | + weight = compressed_data["weight_packed"] |
| 96 | + scale = compressed_data["weight_scale"] |
| 97 | + global_scale = compressed_data["weight_global_scale"] |
| 98 | + m, n = weight.shape |
| 99 | + # TODO: use a user provided dequant dtype |
| 100 | + unpacked = unpack_fp4_from_uint8(weight, m, n * 2) |
| 101 | + decompressed_weight = dequantize( |
| 102 | + x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype |
| 103 | + ) |
| 104 | + |
| 105 | + return decompressed_weight |
| 106 | + |
| 107 | + |
| 108 | +def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor: |
| 109 | + """ |
| 110 | + Packs a tensor with values in the fp4 range into uint8. |
| 111 | + As there are 16 valid fp4 values, two fp4 values can be |
| 112 | + packed into one uint8. Each fp4 value is mapped to its |
| 113 | + particular index (e.g. 0.5 is mapped to index 1, 6.0 is mapped |
| 114 | + to index 7) which is then represented using 4 bits. Consecutive |
| 115 | + pairs of 4 bits are then packed into an uint8. |
| 116 | +
|
| 117 | + :param x: tensor to pack |
| 118 | + returns: a packed tensor in uint8 |
| 119 | + """ |
| 120 | + |
| 121 | + m, n = x.shape |
| 122 | + device = x.device |
| 123 | + |
| 124 | + # Create lookup table for FP4 values to indices |
| 125 | + # Map the absolute values to 0-7 indices |
| 126 | + kE2M1 = torch.tensor(FLOAT_TO_E2M1, device=device, dtype=x.dtype) |
| 127 | + |
| 128 | + # Find closest valid FP4 value index for each element |
| 129 | + abs_x = torch.abs(x) |
| 130 | + abs_indices = torch.zeros_like(abs_x, dtype=torch.long) |
| 131 | + for i, val in enumerate(kE2M1): |
| 132 | + abs_indices = torch.where(torch.isclose(abs_x, val), i, abs_indices) |
| 133 | + |
| 134 | + # Apply sign bit (bit 3) to get final 4-bit representation |
| 135 | + indices = abs_indices + (torch.signbit(x) << 3).to(torch.long) |
| 136 | + |
| 137 | + # Reshape to prepare for packing pairs of values |
| 138 | + indices = indices.reshape(-1) |
| 139 | + |
| 140 | + # Handle odd length by padding if necessary |
| 141 | + if indices.numel() % 2 != 0: |
| 142 | + indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)]) |
| 143 | + |
| 144 | + # Reshape to pair consecutive elements |
| 145 | + indices = indices.reshape(-1, 2) |
| 146 | + |
| 147 | + # Pack pairs of 4-bit values into 8-bit values |
| 148 | + packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8) |
| 149 | + |
| 150 | + return packed.reshape(m, n // 2) |
| 151 | + |
| 152 | + |
| 153 | +kE2M1ToFloat = torch.tensor( |
| 154 | + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 |
| 155 | +) |
| 156 | + |
| 157 | +# reference: : https://github.com/vllm-project/vllm/pull/16362 |
| 158 | +def unpack_fp4_from_uint8( |
| 159 | + a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16 |
| 160 | +) -> torch.Tensor: |
| 161 | + """ |
| 162 | + Unpacks uint8 values into fp4. Each uint8 consists of two fp4 values |
| 163 | + (i.e. first four bits correspond to one fp4 value, last four corresond to a consecutive |
| 164 | + fp4 value). The bits represent an index, which are mapped to an fp4 value. |
| 165 | +
|
| 166 | + :param a: tensor to unpack |
| 167 | + :param m: original dim 0 size of the unpacked tensor |
| 168 | + :param n: original dim 1 size of the unpacked tensor |
| 169 | + :param dtype: dense dtype to cast the unpacked tensor to |
| 170 | + """ |
| 171 | + assert a.dtype == torch.uint8 |
| 172 | + |
| 173 | + # Vectorized nibble processing |
| 174 | + a_flat = a.flatten() |
| 175 | + high = (a_flat & 0xF0) >> 4 # Upper nibbles |
| 176 | + low = a_flat & 0x0F # Lower nibbles |
| 177 | + |
| 178 | + # Combine nibbles for batch processing |
| 179 | + combined = torch.stack((low, high), dim=1).flatten() |
| 180 | + |
| 181 | + # Vectorized sign and magnitude extraction |
| 182 | + signs = (combined & 0x08).to(torch.bool) # Sign bits |
| 183 | + abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices |
| 184 | + |
| 185 | + # Device-aware lookup and sign application |
| 186 | + kE2M1 = kE2M1ToFloat.to(device=a.device) |
| 187 | + values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) |
| 188 | + |
| 189 | + # Reshape to final form |
| 190 | + return values.reshape(m, n).to(dtype=dtype) |
0 commit comments