32 changes: 10 additions & 22 deletions vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -19,15 +19,14 @@
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
     fp8_w8a8_moe_quant_config,
-    ocp_mx_moe_quant_config,
+    mxfp4_w4a16_moe_quant_config,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     prepare_moe_fp8_layer_for_marlin,
 )
 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
     OCP_MX_BLOCK_SIZE,
-    OCP_MX_Scheme,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -435,13 +434,9 @@ def __init__(
         self.static_input_scales = not self.input_quant.get("is_dynamic")
 
         self.weight_dtype = self.weight_quant["dtype"].replace("fp", "mxfp")
-        self.input_dtype = self.input_quant["dtype"].replace("fp", "mxfp")
+        self.input_dtype = self.input_quant["dtype"]
         self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
 
-        self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
-            self.input_dtype, self.weight_dtype
-        )
-
         if self.static_input_scales:
             raise NotImplementedError(
                 "QuarkOCP_MX_MoEMethod with static input scales is currently "
@@ -450,18 +445,15 @@
 
         self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled()
 
-        self.emulate = not current_platform.supports_mx() or not (
-            self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
-        )
+        self.emulate = not current_platform.supports_mx() or not self.use_rocm_aiter_moe
         if self.emulate:
             logger.warning_once(
                 f"The current mode (supports_mx={current_platform.supports_mx()}, "
-                f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, "
-                f"ocp_mx_scheme={self.ocp_mx_scheme}) "
+                f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}",
                 "does not support native MXFP4/MXFP6 "
                 "computation. Simulated weight dequantization and activation "
                 "QDQ (quantize and dequantize) will be used, with the linear "
-                "layers computed in high precision."
+                "layers computed in high precision.",
             )
         else:
             logger.warning_once(
@@ -578,14 +570,10 @@ def process_weights_after_loading(self, layer):
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        return ocp_mx_moe_quant_config(
-            quant_dtype=self.input_dtype,
-            weight_dtype=self.weight_dtype,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
-            a1_scale=None,
-            a2_scale=None,
-            block_shape=None,
+        # The default mxfp4 recipe uses dynamically quantized a16 activations
+        # and mxfp4 weights.
+        return mxfp4_w4a16_moe_quant_config(
+            layer.w13_weight_scale, layer.w2_weight_scale
         )
 
     @property
@@ -639,7 +627,7 @@ def apply(
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 activation=activation,
-                quant_config=self.moe_quant_config,
+                quant_config=self.get_fused_moe_quant_config(layer),
             )
         else:
             from vllm.model_executor.layers.fused_moe import fused_experts
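Taken together, the hunks above change QuarkOCP_MX_MoEMethod so that activations keep their original dtype, the fused-MoE quant config comes from mxfp4_w4a16_moe_quant_config(w13_scale, w2_scale), and the native MX kernel path is used only when the platform supports MX and the ROCm AITER fused MoE is enabled. The sketch below is a hypothetical, self-contained illustration of that decision flow only; pick_mx_moe_path is not part of vLLM, and nothing beyond the calls visible in the diff is assumed.

# Hypothetical stand-in for the post-change decision flow; it deliberately
# avoids importing vLLM, and the helper name and return shape are made up.

def pick_mx_moe_path(supports_mx: bool, aiter_fused_moe_enabled: bool) -> dict:
    """Mirror the emulate/native choice made in QuarkOCP_MX_MoEMethod.__init__.

    After this diff, weights are mxfp4 and activations stay in their original
    16-bit dtype (the w4a16 recipe). The native MX kernels are used only when
    the platform supports MX *and* the ROCm AITER fused-MoE path is enabled;
    otherwise the layer emulates: weights are dequantized and activations go
    through quantize/dequantize (QDQ), with the matmuls run in high precision.
    """
    emulate = not supports_mx or not aiter_fused_moe_enabled
    return {"recipe": "mxfp4 w4a16", "emulate": emulate}

# Example: without native MX support the layer always emulates.
assert pick_mx_moe_path(supports_mx=False, aiter_fused_moe_enabled=True)["emulate"]
# Only supports_mx=True together with the AITER fused MoE selects the native path.
assert not pick_mx_moe_path(supports_mx=True, aiter_fused_moe_enabled=True)["emulate"]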