diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 13957a96deca..b5eb50806d7d 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -6,8 +6,7 @@
 from torch.nn import Module
 from torch.nn.parameter import Parameter

-from vllm._custom_ops import (cutlass_scaled_fp4_mm,
-                              cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
+from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
@@ -18,8 +17,9 @@
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
-    apply_fp4_marlin_linear, is_fp4_marlin_supported,
-    prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
+    is_fp4_marlin_supported, prepare_moe_fp4_layer_for_marlin)
+from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
+    dequantize_to_dtype, ref_nvfp4_quant)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -387,27 +387,10 @@ def process_weights_after_loading(self, layer: Module) -> None:
         weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
         layer.weight_scale_2 = Parameter(weight_scale_2,
                                          requires_grad=False)

-        layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
-                                requires_grad=False)
-
-        # Swizzle the weight blockscale.
-        # contracting dimension is input dimension
-        # block_size = 16;
-        assert (layer.weight_scale.shape[1] % 16 == 0), (
-            "Expected weight_scale.dim(1) to be divisible by 16")
-        assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
-            "Weight Block scale must be represented as FP8-E4M3")
         swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
         layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                 requires_grad=False)
-        layer.weight = Parameter(layer.weight.data, requires_grad=False)
-
-        if self.use_marlin:
-            prepare_fp4_layer_for_marlin(layer)
-            del layer.alpha
-            del layer.input_scale
-            del layer.weight_scale_swizzled

     def apply(
         self,
@@ -415,38 +398,33 @@ def apply(
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if self.use_marlin:
-            return apply_fp4_marlin_linear(
-                input=x,
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
-                weight_scale_2=layer.weight_scale_2,
-                workspace=layer.workspace,
-                size_n=layer.output_size_per_partition,
-                size_k=layer.input_size_per_partition,
-                bias=bias)
+        # for input, only the contracting dimension has a constraint
+        x_m, x_k = x.shape
+        block_size = 16
         output_dtype = x.dtype
-        output_shape = [x.shape[0], layer.weight.shape[0]]
-
-        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
-        s_quant = 1 / layer.input_scale
-        x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant)
-
-        # validate dtypes of quantized input, input block scale,
-        # weight and weight_blockscale
-        assert (x_fp4.dtype == torch.uint8)
-        assert (layer.weight.dtype == torch.uint8)
-        assert (x_blockscale.dtype == torch.float8_e4m3fn)
-        assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
-        assert (layer.alpha.dtype == torch.float32)
-
-        out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
-                                    layer.weight_scale_swizzled, layer.alpha,
-                                    output_dtype)
-        if bias is not None:
-            out = out + bias
-        return out.view(*output_shape)
+
+        # quantize input to (FP4 and interleaved block scale)
+        x_global_scale = 1 / layer.input_scale
+        x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale, block_size)
+
+        # dequantize input
+        x_fp4 = x_fp4.reshape(x_m, x_k // block_size, block_size)
+        x_blockscale = x_blockscale.unsqueeze(-1) / x_global_scale
+        x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
+        del x_fp4, x_blockscale
+
+        # dequantize weight
+        w_fp4 = layer.weight.data.view(torch.uint8)
+        w_blockscale = layer.weight_scale_swizzled.data
+        w_global_scale = 1 / layer.weight_scale_2
+        w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
+                                   output_dtype, x.device, block_size)
+
+        # matmul
+        out = torch.matmul(x_dq, w_dq.t())
+        del w_dq, x_dq
+        return out


 class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
index f292208311e2..61393a2c32f1 100644
--- a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 import torch

-__all__ = [
-    "break_fp4_bytes",
-    "dequantize_to_dtype",
-]
+from vllm.scalar_type import scalar_types
+
+__all__ = ["break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant"]
+
+FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()

 kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
                             dtype=torch.float32)
@@ -59,3 +60,44 @@ def dequantize_to_dtype(tensor_fp4,
     # scale the tensor
     out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
     return out.to(dtype)
+
+
+def get_reciprocal(x):
+    if isinstance(x, torch.Tensor):
+        return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x)
+    elif isinstance(x, (float, int)):
+        return 0.0 if x == 0 else 1.0 / x
+    else:
+        raise TypeError("Input must be a float, int, or a torch.Tensor.")
+
+
+def cast_to_fp4(x):
+    sign = torch.sign(x)
+    x = torch.abs(x)
+    x[(x >= 0.0) & (x <= 0.25)] = 0.0
+    x[(x > 0.25) & (x < 0.75)] = 0.5
+    x[(x >= 0.75) & (x <= 1.25)] = 1.0
+    x[(x > 1.25) & (x < 1.75)] = 1.5
+    x[(x >= 1.75) & (x <= 2.5)] = 2.0
+    x[(x > 2.5) & (x < 3.5)] = 3.0
+    x[(x >= 3.5) & (x <= 5.0)] = 4.0
+    x[x > 5.0] = 6.0
+    return x * sign
+
+
+def ref_nvfp4_quant(x, global_scale, block_size):
+    assert global_scale.dtype == torch.float32
+    assert x.ndim == 2
+    m, n = x.shape
+    x = torch.reshape(x, (m, n // block_size, block_size))
+    vec_max = torch.max(torch.abs(x), dim=-1,
+                        keepdim=True)[0].to(torch.float32)
+    scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
+    scale = torch.clamp(scale, max=448, min=-448)
+    scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
+    output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
+
+    scaled_x = x.to(torch.float32) * output_scale
+    clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
+    # both outputs are float32
+    return cast_to_fp4(clipped_x), scale.squeeze(-1)
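
For reference, a minimal standalone sketch (not part of the diff) of the emulated quantize/dequantize round trip that the new apply() path performs, using ref_nvfp4_quant added above. The global scale used here (FP8-E4M3 max 448 times FP4-E2M1 max 6, divided by the tensor amax) is an assumption standing in for 1 / layer.input_scale, which normally comes from the checkpoint.

import torch

from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
    ref_nvfp4_quant)

block_size = 16
x = torch.randn(8, 64, dtype=torch.float32)

# Assumed per-tensor global scale; in apply() it is 1 / layer.input_scale.
x_global_scale = (448 * 6) / x.abs().max().to(torch.float32)

# Quantize to FP4 values (stored as float32) plus per-block scales.
x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale, block_size)

# Dequantize the same way the emulated apply() does.
m, k = x.shape
x_dq = (x_fp4.reshape(m, k // block_size, block_size) *
        (x_blockscale.unsqueeze(-1) / x_global_scale)).reshape(m, k)

print("max abs round-trip error:", (x - x_dq).abs().max().item())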