From 54cac384d9415d4874cc247fde8f3fb84821a766 Mon Sep 17 00:00:00 2001
From: ZhiweiYan-96
Date: Thu, 27 Nov 2025 18:04:52 +0000
Subject: [PATCH 1/4] [bugfix mxfp4] Infer mxfp4 quant method from layer

---
 vllm/model_executor/layers/fused_moe/config.py             | 2 +-
 vllm/model_executor/layers/quantization/quark/quark_moe.py | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 1826fafa8c4f..26d56731abef 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -343,7 +343,7 @@ def ocp_mx_scheme(self) -> str | None:
 
     @property
     def use_mxfp4_w4a16(self) -> bool:
-        return self._a1.dtype is None and self._w1.dtype == "mxfp4"
+        return self._w1.dtype == "mxfp4"
 
     @property
     def use_nvfp4_w4a4(self) -> bool:

diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 6d9646642be3..6f746d9b7ab1 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -623,6 +623,7 @@ def apply(
         if not self.emulate:
             from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
                 rocm_aiter_fused_experts,
+                QuantMethod
             )
 
             if hasattr(torch, "float4_e2m1fn_x2"):
@@ -639,7 +640,7 @@
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 activation=activation,
-                quant_config=self.moe_quant_config,
+                quant_config=self.get_fused_moe_quant_config(layer),
             )
         else:
             from vllm.model_executor.layers.fused_moe import fused_experts

From dd265e8974fd287c91d29fa5cae0b746ce110268 Mon Sep 17 00:00:00 2001
From: ZhiweiYan-96
Date: Fri, 28 Nov 2025 03:56:41 +0000
Subject: [PATCH 2/4] Use w4a16 config

---
 .../model_executor/layers/fused_moe/config.py |  3 +--
 .../layers/quantization/quark/quark_moe.py    | 27 +++++--------------
 2 files changed, 7 insertions(+), 23 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 26d56731abef..b93bc10a3108 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -343,7 +343,7 @@ def ocp_mx_scheme(self) -> str | None:
 
     @property
     def use_mxfp4_w4a16(self) -> bool:
-        return self._w1.dtype == "mxfp4"
+        return self._a1.dtype is None and self._w1.dtype == "mxfp4"
 
     @property
     def use_nvfp4_w4a4(self) -> bool:
@@ -458,7 +458,6 @@ def make(
             "mxfp6_e3m2",
             "mxfp6_e2m3",
         }
-
         if weight_dtype is None:
             weight_dtype = quant_dtype

diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 6f746d9b7ab1..293b9055e887 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -19,7 +19,7 @@
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
     fp8_w8a8_moe_quant_config,
-    ocp_mx_moe_quant_config,
+    mxfp4_w4a16_moe_quant_config,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -27,7 +27,6 @@
 )
 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
     OCP_MX_BLOCK_SIZE,
-    OCP_MX_Scheme,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -435,13 +434,9 @@ def __init__(
         self.static_input_scales = not self.input_quant.get("is_dynamic")
         self.weight_dtype = self.weight_quant["dtype"].replace("fp", "mxfp")
-        self.input_dtype = self.input_quant["dtype"].replace("fp", "mxfp")
+        self.input_dtype = self.input_quant["dtype"]
         self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
 
-        self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
-            self.input_dtype, self.weight_dtype
-        )
-
         if self.static_input_scales:
             raise NotImplementedError(
                 "QuarkOCP_MX_MoEMethod with static input scales is currently "
@@ -450,9 +445,7 @@ def __init__(
 
         self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled()
 
-        self.emulate = not current_platform.supports_mx() or not (
-            self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
-        )
+        self.emulate = not current_platform.supports_mx() or not self.use_rocm_aiter_moe
         if self.emulate:
             logger.warning_once(
                 f"The current mode (supports_mx={current_platform.supports_mx()}, "
@@ -578,15 +571,8 @@ def process_weights_after_loading(self, layer):
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        return ocp_mx_moe_quant_config(
-            quant_dtype=self.input_dtype,
-            weight_dtype=self.weight_dtype,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
-            a1_scale=None,
-            a2_scale=None,
-            block_shape=None,
-        )
+        return mxfp4_w4a16_moe_quant_config(layer.w13_weight_scale,
+                                            layer.w2_weight_scale)
 
     @property
     def allow_inplace(self) -> bool:
@@ -622,8 +608,7 @@ def apply(
 
         if not self.emulate:
             from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-                rocm_aiter_fused_experts,
-                QuantMethod
+                rocm_aiter_fused_experts
             )
 
             if hasattr(torch, "float4_e2m1fn_x2"):

From 760ad2d4234814ce43e9554fba872eea0be112b8 Mon Sep 17 00:00:00 2001
From: ZhiweiYan-96
Date: Fri, 28 Nov 2025 03:59:28 +0000
Subject: [PATCH 3/4] lint

---
 vllm/model_executor/layers/fused_moe/config.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index b93bc10a3108..1826fafa8c4f 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -458,6 +458,7 @@ def make(
             "mxfp6_e3m2",
             "mxfp6_e2m3",
         }
+
         if weight_dtype is None:
             weight_dtype = quant_dtype

From 7da5e5f7ca5c6980c1252b146e6e3c8facfe9461 Mon Sep 17 00:00:00 2001
From: ZhiweiYan-96
Date: Fri, 28 Nov 2025 04:40:06 +0000
Subject: [PATCH 4/4] Add comments

Signed-off-by: ZhiweiYan-96
---
 .../layers/quantization/quark/quark_moe.py | 14 ++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 293b9055e887..4b9e33a44d14 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -449,12 +449,11 @@ def __init__(
         if self.emulate:
             logger.warning_once(
                 f"The current mode (supports_mx={current_platform.supports_mx()}, "
-                f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, "
-                f"ocp_mx_scheme={self.ocp_mx_scheme}) "
+                f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}) "
                 "does not support native MXFP4/MXFP6 "
                 "computation. Simulated weight dequantization and activation "
                 "QDQ (quantize and dequantize) will be used, with the linear "
-                "layers computed in high precision."
+                "layers computed in high precision.",
             )
         else:
             logger.warning_once(
@@ -571,8 +570,11 @@ def process_weights_after_loading(self, layer):
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        return mxfp4_w4a16_moe_quant_config(layer.w13_weight_scale,
-                                            layer.w2_weight_scale)
+        # The default mxfp4 recipe keeps activations in high precision
+        # (a16) and quantizes only the weights to mxfp4.
+        return mxfp4_w4a16_moe_quant_config(
+            layer.w13_weight_scale, layer.w2_weight_scale
+        )
 
     @property
     def allow_inplace(self) -> bool:
@@ -608,7 +610,7 @@ def apply(
 
         if not self.emulate:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-                rocm_aiter_fused_experts
+                rocm_aiter_fused_experts,
             )
 
             if hasattr(torch, "float4_e2m1fn_x2"):
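
Note for reviewers: a minimal standalone sketch of the w4a16 predicate this
series settles on. FakeQuantDesc and the free function are illustrative
stand-ins, not vLLM's actual classes; the point is that mxfp4 weights only
count as w4a16 when the activations stay unquantized (dtype is None), which
is the check patch 2 restores in FusedMoEQuantConfig.use_mxfp4_w4a16.

    # Illustrative sketch only; not vLLM code.
    from __future__ import annotations
    from dataclasses import dataclass

    @dataclass
    class FakeQuantDesc:
        # "mxfp4" for a quantized tensor, None for unquantized (a16) activations.
        dtype: str | None

    def use_mxfp4_w4a16(a1: FakeQuantDesc, w1: FakeQuantDesc) -> bool:
        # Unquantized activations AND mxfp4 weights -> w4a16 path.
        return a1.dtype is None and w1.dtype == "mxfp4"

    assert use_mxfp4_w4a16(FakeQuantDesc(None), FakeQuantDesc("mxfp4"))
    assert not use_mxfp4_w4a16(FakeQuantDesc("mxfp4"), FakeQuantDesc("mxfp4"))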