[Frontend][ONNX] Add MatMulInteger support to Relax ONNX frontend#18951

Merged
tlopex merged 3 commits into apache:main from OmarAzizi:onnx-frontend-matmulinteger
Mar 29, 2026

Conversation

@OmarAzizi
Contributor

Summary

Implements the MatMulInteger operator (opset 10) in the Relax ONNX frontend — INT8 matrix multiplication. Required for quantized model inference (e.g. ONNX QDQ models).

Closes #18945 (Tier 1 — MatMulInteger operator)

Tests

  • All 4 int8/uint8 dtype combinations, with and without scalar zero points
  • 3-D and 4-D batched matmul
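
For reference, the MatMulInteger semantics exercised by these tests can be sketched with a small NumPy model (the `matmul_integer_ref` helper here is illustrative, not part of the PR):

```python
import numpy as np

def matmul_integer_ref(a, b, a_zp=0, b_zp=0):
    """NumPy reference for ONNX MatMulInteger: widen both inputs to
    int32, subtract the zero points, then matrix-multiply."""
    a32 = a.astype(np.int32) - np.asarray(a_zp, dtype=np.int32)
    b32 = b.astype(np.int32) - np.asarray(b_zp, dtype=np.int32)
    return a32 @ b32

A = np.array([[1, 2], [3, 4]], dtype=np.int8)
B = np.array([[5, 6], [7, 8]], dtype=np.uint8)
# Scalar zero points, mirroring the parametrized test cases
out = matmul_integer_ref(A, B, a_zp=np.int8(1), b_zp=np.uint8(2))
# → [[5, 6], [21, 26]], dtype int32
```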

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request implements the MatMulInteger operator for the ONNX frontend in Relax, enabling INT8/UINT8 quantized matrix multiplication with support for optional zero points. The changes include input widening to int32 and broadcasting logic for zero-point tensors. Feedback identifies a dynamic shape safety issue in the zero-point handling and a logic error in the test helper's input placeholder generation. Additionally, it is recommended to add test coverage for per-channel zero-point configurations to verify the 1-D reshaping logic.

Comment on lines +4058 to +4076
        if a_zero_point is not None:
            a_zp = relax.op.astype(
                a_zero_point, "int32"
            )  # Ensure zero point is int32 for subtraction
            a_zp = bb.normalize(a_zp)  # Normalize the expr so struct_info gets populated
            a_zp_shape = [dim.value for dim in a_zp.struct_info.shape]

            # Per-row case: [M] -> [M, 1] so it broadcasts over [M, K] row-wise
            # N-D case: spec says shape is [D1, D2, M, 1], which already broadcasts correctly (no need to reshape)
            if len(a_zp_shape) == 1:
                a_zp = relax.op.reshape(a_zp, [a_zp_shape[0], 1])

            a = relax.op.subtract(a, a_zp)

        # Per-column zero point handling is analogous to per-row, but we reshape to [1, N] for broadcasting over columns of [K, N]
        # N-D case: spec says shape is [D1, D2, 1, N], which already broadcasts correctly (no need to reshape)
        if b_zero_point is not None:
            b_zp = relax.op.astype(b_zero_point, "int32")
            b = relax.op.subtract(b, b_zp)
high

There are a couple of issues in the zero-point handling logic:

  1. Dynamic Shape-Safety: The expression [dim.value for dim in a_zp.struct_info.shape] on line 4063 is not safe for dynamic shapes, as .value is not valid for tir.SizeVar. It should be a_zp.struct_info.shape.values.
  2. Inconsistent b_zero_point Handling: The logic for b_zero_point is missing the reshape that is present for a_zero_point. While broadcasting might work in some cases, explicitly reshaping b_zp improves clarity, safety, and consistency with the comment on line 4072.

Here is a suggested refactoring that addresses both points for a_zero_point and b_zero_point:

        if a_zero_point is not None:
            a_zp = relax.op.astype(
                a_zero_point, "int32"
            )  # Ensure zero point is int32 for subtraction
            a_zp = bb.normalize(a_zp)  # Normalize the expr so struct_info gets populated
            a_zp_shape = a_zp.struct_info.shape.values

            # Per-row case: [M] -> [M, 1] so it broadcasts over [M, K] row-wise
            # N-D case: spec says shape is [D1, D2, M, 1], which already broadcasts correctly (no need to reshape)
            if len(a_zp_shape) == 1:
                a_zp = relax.op.reshape(a_zp, [a_zp_shape[0], 1])

            a = relax.op.subtract(a, a_zp)

        # Per-column zero point handling is analogous to per-row, but we reshape to [1, N] for broadcasting over columns of [K, N]
        # N-D case: spec says shape is [D1, D2, 1, N], which already broadcasts correctly (no need to reshape)
        if b_zero_point is not None:
            b_zp = relax.op.astype(b_zero_point, "int32")
            b_zp = bb.normalize(b_zp)
            b_zp_shape = b_zp.struct_info.shape.values

            if len(b_zp_shape) == 1:
                b_zp = relax.op.reshape(b_zp, [1, b_zp_shape[0]])
            b = relax.op.subtract(b, b_zp)

Comment on lines +4438 to +4444
    if a_zp_array is not None:
        _add_zp("a_zero_point", a_zp_array, A_dtype)
    else:
        node_inputs.append("")

    if b_zp_array is not None:
        _add_zp("b_zero_point", b_zp_array, B_dtype)
high

There's a logic error here when creating the node_inputs. When both a_zp_array and b_zp_array are None, an unnecessary empty string is appended for a_zero_point, resulting in a 3-input MatMulInteger node instead of a 2-input one.

Optional inputs should only be added as placeholders if a later optional input is present. The logic can be simplified to correctly handle all cases of optional inputs.

Suggested change

    - if a_zp_array is not None:
    -     _add_zp("a_zero_point", a_zp_array, A_dtype)
    - else:
    -     node_inputs.append("")
    - if b_zp_array is not None:
    -     _add_zp("b_zero_point", b_zp_array, B_dtype)
    + if a_zp_array is not None:
    +     _add_zp("a_zero_point", a_zp_array, A_dtype)
    + elif b_zp_array is not None:
    +     node_inputs.append("")  # Placeholder for a_zero_point
    + if b_zp_array is not None:
    +     _add_zp("b_zero_point", b_zp_array, B_dtype)

(np.int8, np.int8, np.int8(1), np.int8(2)),
],
)
def test_matmulinteger(A_dtype, B_dtype, a_zp, b_zp):
medium

The tests for MatMulInteger are great for scalar zero points, but they don't cover the per-channel (per-row/per-column) quantization cases. The implementation in onnx_frontend.py includes logic to handle 1-D zero-point tensors, but this is currently untested.

Please consider adding a new test case to cover per-channel zero points to ensure the reshape and broadcasting logic is correct. For example:

def test_matmulinteger_per_channel_zp():
    A = np.random.randint(-5, 5, (4, 8)).astype(np.int8)
    B = np.random.randint(-5, 5, (8, 6)).astype(np.int8)
    a_zp = np.arange(4, dtype=np.int8) # per-row
    b_zp = np.arange(6, dtype=np.int8) # per-column
    model = _make_matmulinteger_model(
        [4, 8],
        [8, 6],
        np.int8,
        np.int8,
        a_zp_array=a_zp,
        b_zp_array=b_zp,
    )
    check_correctness(model, inputs={"A": A, "B": B}, opset=10)

Member

@tlopex left a comment


A couple of follow-ups:

  • Please avoid using dim.value here. It will break on symbolic shapes. relax.op.expand_dims(a_zp, axis=-1) would be simpler and more robust than extracting concrete shape values for the reshape. The same applies to b_zp if it has a similar path.

  • Please add a test for 1-D zero points as well. Right now all tests use scalar zero points, so the reshape/broadcast path for per-row or per-column zero points is not covered.
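
The broadcasting behind the suggestion can be illustrated in plain NumPy (a stand-in here; the converter itself would use `relax.op.expand_dims`):

```python
import numpy as np

a = np.arange(12, dtype=np.int32).reshape(4, 3)   # activation, shape [M, K]
a_zp = np.array([1, 2, 3, 4], dtype=np.int32)     # per-row zero point, shape [M]

# expand_dims along the last axis turns [M] into [M, 1], which broadcasts
# row-wise over [M, K] without ever reading concrete dimension values --
# so it also works when M is a symbolic shape variable.
a_zp_col = np.expand_dims(a_zp, axis=-1)
diff = a - a_zp_col

# Equivalent to the reshape-based version, minus the shape extraction
same = np.array_equal(diff, a - a_zp.reshape(4, 1))
```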

…er, add per-channel zp tests

Changes based on reviewer feedback:

- Replace dim.value shape extraction with relax.op.expand_dims for zero-point
  reshaping, which is safe for symbolic/dynamic shapes unlike extracting
  concrete dimension values via struct_info.shape

- Fix _make_matmulinteger_model helper: zero-points were incorrectly added to
  both graph_inputs and initializers, causing duplicate Relax variable names
  and malformed IR. Now added as initializers only.

- Fix node_inputs placeholder logic: empty string for absent a_zero_point is
  now only appended when b_zero_point is present, avoiding a spurious 3-input
  node when both zero-points are absent.

- Add test_matmulinteger_per_channel_zp: verifies 1-D per-row a_zero_point
  [M] and per-col b_zero_point [N], exercising the expand_dims path in the
  converter. Runs TVM-only against a NumPy reference since ORT CPU does not
  support per-row a_zero_point despite the ONNX spec allowing it.

- Add test_matmulinteger_per_channel_zp_ort_limitation: xfail test that
  documents the ORT CPU kernel limitation (strict=True so we are alerted
  if ORT ever fixes this).

- Mark int8/uint8 mixed dtype test as xfail(strict=False) since some older
  ORT versions do not implement this combination, but newer versions do.

Signed-off-by: OmarAzizi <oalazizi75@gmail.com>
@OmarAzizi
Contributor Author

OmarAzizi commented Mar 29, 2026

Thank you @tlopex for the review, I addressed everything in the follow-up commit.

One thing worth noting: two ORT limitations surfaced during testing.

1. Per-row a_zero_point (shape [M])
The ONNX spec allows this, but ORT's kernel unconditionally rejects it. Our converter handles it correctly, so the test verifies TVM against a NumPy reference directly. Also, I added a separate test marked xfail(strict=True) to document the ORT gap and ensure that we get alerted if ORT ever fixes it.

2. Mixed int8/uint8 dtype — xfail(strict=False)
The CI's ORT version doesn't support this combination, but newer ORT does (shows as XPASS locally). strict=False lets it pass on both old and new ORT without breaking CI.

Edit: I mentioned that the ORT limitation is in the CPU kernel in the commit message, but after testing, it fails on the GPU as well. The error message looks something like this:

FAILED test_matmulinteger_per_channel_zp

tests/python/relax/test_frontend_onnx.py:126: in check_correctness
    ort_output = ort_session.run([], inputs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^

E   onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException:
    [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned
    while running MatMulInteger node.

    MatmulInteger: input1 zero point must be a scalar or 1D tensor of size 1

@OmarAzizi OmarAzizi requested a review from tlopex March 29, 2026 15:34
Member

@tlopex left a comment


LGTM! Could you resolve the conflicts?

@OmarAzizi
Contributor Author

> LGTM! Could you resolve the conflicts?

Thank you, Done

@tlopex tlopex merged commit 8597d21 into apache:main Mar 29, 2026
8 checks passed


Development

Successfully merging this pull request may close these issues.

[Tracking Issue][ONNX] Complete missing and limited operators in ONNX frontend
