[Frontend][ONNX] Add MatMulInteger support to Relax ONNX frontend #18951

tlopex merged 3 commits into `apache:main`
Conversation
Code Review
This pull request implements the MatMulInteger operator for the ONNX frontend in Relax, enabling INT8/UINT8 quantized matrix multiplication with support for optional zero points. The changes include input widening to int32 and broadcasting logic for zero-point tensors. Feedback identifies a dynamic shape safety issue in the zero-point handling and a logic error in the test helper's input placeholder generation. Additionally, it is recommended to add test coverage for per-channel zero-point configurations to verify the 1-D reshaping logic.
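For context, the operator's semantics can be sketched with a small NumPy reference (a hypothetical helper, not code from this PR): both operands are widened to int32, the optional zero points are subtracted, and the product is accumulated in int32.

```python
import numpy as np

def matmul_integer_ref(a, b, a_zp=0, b_zp=0):
    """Hypothetical NumPy reference for MatMulInteger semantics:
    widen to int32, subtract zero points, multiply with int32 accumulation."""
    a32 = a.astype(np.int32) - np.int32(a_zp)
    b32 = b.astype(np.int32) - np.int32(b_zp)
    return a32 @ b32  # int32 output, as the ONNX spec requires

a = np.array([[1, 2], [3, 4]], dtype=np.int8)
b = np.array([[5, 6], [7, 8]], dtype=np.int8)
out = matmul_integer_ref(a, b, a_zp=1, b_zp=2)  # -> [[5, 6], [21, 26]]
```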
```python
if a_zero_point is not None:
    a_zp = relax.op.astype(
        a_zero_point, "int32"
    )  # Ensure zero point is int32 for subtraction
    a_zp = bb.normalize(a_zp)  # Normalize the expr so struct_info gets populated
    a_zp_shape = [dim.value for dim in a_zp.struct_info.shape]

    # Per-row case: [M] -> [M, 1] so it broadcasts over [M, K] row-wise
    # N-D case: spec says shape is [D1, D2, M, 1], which already broadcasts correctly (no need to reshape)
    if len(a_zp_shape) == 1:
        a_zp = relax.op.reshape(a_zp, [a_zp_shape[0], 1])

    a = relax.op.subtract(a, a_zp)

# Per-column zero point handling is analogous to per-row, but we reshape to [1, N] for broadcasting over columns of [K, N]
# N-D case: spec says shape is [D1, D2, 1, N], which already broadcasts correctly (no need to reshape)
if b_zero_point is not None:
    b_zp = relax.op.astype(b_zero_point, "int32")
    b = relax.op.subtract(b, b_zp)
```
There are a couple of issues in the zero-point handling logic:

- **Dynamic shape safety:** the expression `[dim.value for dim in a_zp.struct_info.shape]` on line 4063 is not safe for dynamic shapes, as `.value` is not valid for `tir.SizeVar`. It should be `a_zp.struct_info.shape.values`.
- **Inconsistent `b_zero_point` handling:** the logic for `b_zero_point` is missing the reshape that is present for `a_zero_point`. While broadcasting might work in some cases, explicitly reshaping `b_zp` improves clarity, safety, and consistency with the comment on line 4072.
Here is a suggested refactoring that addresses both points for `a_zero_point` and `b_zero_point`:
if a_zero_point is not None:
a_zp = relax.op.astype(
a_zero_point, "int32"
) # Ensure zero point is int32 for subtraction
a_zp = bb.normalize(a_zp) # Normalize the expr so struct_info gets populated
a_zp_shape = a_zp.struct_info.shape.values
# Per-row case: [M] -> [M, 1] so it broadcasts over [M, K] row-wise
# N-D case: spec says shape is [D1, D2, M, 1], which already broadcasts correctly (no need to reshape)
if len(a_zp_shape) == 1:
a_zp = relax.op.reshape(a_zp, [a_zp_shape[0], 1])
a = relax.op.subtract(a, a_zp)
# Per-column zero point handling is analogous to per-row, but we reshape to [1, N] for broadcasting over columns of [K, N]
# N-D case: spec says shape is [D1, D2, 1, N], which already broadcasts correctly (no need to reshape)
if b_zero_point is not None:
b_zp = relax.op.astype(b_zero_point, "int32")
b_zp = bb.normalize(b_zp)
b_zp_shape = b_zp.struct_info.shape.values
if len(b_zp_shape) == 1:
b_zp = relax.op.reshape(b_zp, [1, b_zp_shape[0]])
b = relax.op.subtract(b, b_zp)| if a_zp_array is not None: | ||
| _add_zp("a_zero_point", a_zp_array, A_dtype) | ||
| else: | ||
| node_inputs.append("") | ||
|
|
||
| if b_zp_array is not None: | ||
| _add_zp("b_zero_point", b_zp_array, B_dtype) |
There's a logic error here when creating `node_inputs`. When both `a_zp_array` and `b_zp_array` are `None`, an unnecessary empty string is appended for `a_zero_point`, resulting in a 3-input `MatMulInteger` node instead of a 2-input one.

Optional inputs should only be added as placeholders if a later optional input is present. The logic can be simplified to correctly handle all cases of optional inputs.
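The corrected placeholder rule can be sketched as a standalone function (the input names and the `build_node_inputs` helper are assumed for illustration; the real helper builds the node via `_add_zp`):

```python
def build_node_inputs(a_zp_array, b_zp_array):
    """Sketch of the corrected optional-input handling: an empty string
    stands in for an absent optional input, and is only needed when a
    LATER optional input is present."""
    node_inputs = ["A", "B"]
    if a_zp_array is not None:
        node_inputs.append("a_zero_point")
    elif b_zp_array is not None:
        node_inputs.append("")  # placeholder so b_zero_point lands in slot 3
    if b_zp_array is not None:
        node_inputs.append("b_zero_point")
    return node_inputs
```

With both zero points absent, this yields a clean 2-input node rather than a spurious 3-input one.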
```diff
-if a_zp_array is not None:
-    _add_zp("a_zero_point", a_zp_array, A_dtype)
-else:
-    node_inputs.append("")
-if b_zp_array is not None:
-    _add_zp("b_zero_point", b_zp_array, B_dtype)
+if a_zp_array is not None:
+    _add_zp("a_zero_point", a_zp_array, A_dtype)
+elif b_zp_array is not None:
+    node_inputs.append("")  # Placeholder for a_zero_point
+if b_zp_array is not None:
+    _add_zp("b_zero_point", b_zp_array, B_dtype)
```
```python
        (np.int8, np.int8, np.int8(1), np.int8(2)),
    ],
)
def test_matmulinteger(A_dtype, B_dtype, a_zp, b_zp):
```
The tests for `MatMulInteger` are great for scalar zero points, but they don't cover the per-channel (per-row/per-column) quantization cases. The implementation in `onnx_frontend.py` includes logic to handle 1-D zero-point tensors, but this is currently untested.

Please consider adding a new test case to cover per-channel zero points to ensure the reshape and broadcasting logic is correct. For example:
```python
def test_matmulinteger_per_channel_zp():
    A = np.random.randint(-5, 5, (4, 8)).astype(np.int8)
    B = np.random.randint(-5, 5, (8, 6)).astype(np.int8)
    a_zp = np.arange(4, dtype=np.int8)  # per-row
    b_zp = np.arange(6, dtype=np.int8)  # per-column
    model = _make_matmulinteger_model(
        [4, 8],
        [8, 6],
        np.int8,
        np.int8,
        a_zp_array=a_zp,
        b_zp_array=b_zp,
    )
    check_correctness(model, inputs={"A": A, "B": B}, opset=10)
```
tlopex left a comment:
A couple of follow-ups:

- Please avoid using `dim.value` here. It will break on symbolic shapes. `relax.op.expand_dims(a_zp, axis=-1)` would be simpler and more robust than extracting concrete shape values for the reshape. The same applies to `b_zp` if it has a similar path.
- Please add a test for 1-D zero points as well. Right now all tests use scalar zero points, so the reshape/broadcast path for per-row or per-column zero points is not covered.
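The `expand_dims` suggestion can be illustrated with a NumPy analogy (NumPy stands in for the Relax ops here): unlike a `reshape`, appending an axis never needs to read a concrete dimension size, so it stays valid when the shape is symbolic.

```python
import numpy as np

a_zp = np.array([1, 2, 3], dtype=np.int32)    # per-row zero point, shape [M]
via_reshape = a_zp.reshape(a_zp.shape[0], 1)  # needs the concrete size M
via_expand = np.expand_dims(a_zp, axis=-1)    # shape-agnostic: [M] -> [M, 1]
```

Both produce the same `[M, 1]` tensor, but only the `expand_dims` form avoids materializing `M`.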
…er, add per-channel zp tests

Changes based on reviewer feedback:
- Replace `dim.value` shape extraction with `relax.op.expand_dims` for zero-point reshaping, which is safe for symbolic/dynamic shapes, unlike extracting concrete dimension values via `struct_info.shape`.
- Fix the `_make_matmulinteger_model` helper: zero points were incorrectly added to both graph inputs and initializers, causing duplicate Relax variable names and malformed IR. They are now added as initializers only.
- Fix the `node_inputs` placeholder logic: the empty string for an absent `a_zero_point` is now only appended when `b_zero_point` is present, avoiding a spurious 3-input node when both zero points are absent.
- Add `test_matmulinteger_per_channel_zp`: verifies a 1-D per-row `a_zero_point` of shape [M] and a per-column `b_zero_point` of shape [N], exercising the `expand_dims` path in the converter. Runs TVM-only against a NumPy reference, since ORT CPU does not support per-row `a_zero_point` despite the ONNX spec allowing it.
- Add `test_matmulinteger_per_channel_zp_ort_limitation`: an xfail test that documents the ORT CPU kernel limitation (`strict=True`, so we are alerted if ORT ever fixes this).
- Mark the int8/uint8 mixed-dtype test as `xfail(strict=False)`, since some older ORT versions do not implement this combination but newer versions do.

Signed-off-by: OmarAzizi <oalazizi75@gmail.com>
Thank you @tlopex for the review, I addressed everything in the follow-up commit. One thing worth noting: two ORT limitations surfaced during testing:

1. Per-row `a_zero_point` is rejected by ORT, even though the ONNX spec allows it.
2. Mixed `int8`/`uint8` dtype combinations are not implemented by some older ORT versions.

Edit: I mentioned that the ORT limitation is in the CPU kernel in the commit message, but after testing, it fails on the GPU as well. The error message looks something like this: […]
tlopex left a comment:
LGTM! Could you resolve the conflicts?
Thank you, done!
Summary

Implements the `MatMulInteger` operator (opset 10) in the Relax ONNX frontend — INT8 matrix multiplication. Required for quantized model inference (e.g. ONNX QDQ models).

Closes #18945 (Tier 1 — MatMulInteger operator)

Tests
- `int8`/`uint8` dtype combinations, with and without scalar zero points