
[Frontend][ONNX] Add If operator support to Relax ONNX frontend #18946

Merged
tlopex merged 2 commits into apache:main from OmarAzizi:onnx-frontend-if-operator
Mar 28, 2026

Conversation

OmarAzizi (Contributor) commented Mar 28, 2026

Summary

This PR implements the ONNX If operator in the Relax ONNX frontend. The If operator enables conditional branching in ONNX models, where a boolean condition selects between two subgraph branches (then_branch and else_branch) at runtime. This is required for any model with runtime-dependent execution paths.
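Semantically, If evaluates exactly one of two subgraphs selected by a boolean input. A plain-Python sketch of the runtime behavior (onnx_if and the branch lambdas below are illustrative stand-ins, not frontend code):

```python
def onnx_if(cond, then_branch, else_branch):
    # Only the selected branch is evaluated; both branches must
    # produce outputs of the same arity and type, as ONNX requires.
    return then_branch() if cond else else_branch()

x = [1.0, 2.0, 3.0]
assert onnx_if(True, lambda: [v * 2.0 for v in x], lambda: [v * 3.0 for v in x]) == [2.0, 4.0, 6.0]
assert onnx_if(False, lambda: [v * 2.0 for v in x], lambda: [v * 3.0 for v in x]) == [3.0, 6.0, 9.0]
```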

Closes #18945 (Tier 1 — If operator)

Implementation Notes

  • The main challenge is that relax.If cannot be emitted inside a dataflow block, which is how the ONNX frontend normally builds the entire graph. To handle this, when the graph contains an If node, the function body is built as a regular binding block instead — matching the approach used by the PyTorch Relax frontend for torch.cond.

  • Each branch is an ONNX subgraph that can reference values from the outer graph. A new _convert_subgraph method handles converting these subgraphs into Relax expressions, making outer-scope values available to the branch while ensuring branch-local bindings don't leak back to the parent graph.

Why relax.If cannot live inside a dataflow block

Dataflow blocks in Relax carry a semantic guarantee: every operation inside them must be pure and side-effect-free with no control flow. This allows the compiler to treat the entire block as a static computational graph for optimizations like operator fusion and constant folding. An If node breaks this guarantee by introducing runtime-dependent branching, so Relax's well-formedness checker explicitly forbids it. I discovered this when the checker raised:

This IR is not well-formed: If nodes are not allowed to appear in dataflow blocks.

The fix — skipping the dataflow block when the graph contains an If node — mirrors exactly how the PyTorch Relax frontend handles torch.cond.

Known Limitations

Dataflow block: Models whose top-level graph contains an If node are built without a dataflow block, which may affect downstream optimization passes that rely on dataflow block structure.

Tests

Four new tests covering: scalar and tensor conditions, condition computed from another op, and multiple branch outputs. All verified against onnxruntime via check_correctness.

Signed-off-by: OmarAzizi <oalazizi75@gmail.com>
gemini-code-assist (bot, Contributor) left a comment


Code Review

This pull request implements support for the ONNX If operator in the Relax frontend. Key changes include the addition of subgraph conversion logic, handling control flow by conditionally disabling dataflow blocks in the main function, and updating attribute parsing to support nested graphs. Feedback suggests improving code conciseness by using enumerate in loops, removing redundant imports in the new test cases, and refactoring similar tests into a single parameterized test to reduce duplication.

Comment on lines +4428 to +4429
for k, i in zip(outputs, range(len(outputs))):
    self._nodes[k] = self.bb.emit(relax.TupleGetItem(if_result, i))
Severity: medium

For better readability and conciseness, you can use enumerate here instead of zip with range. This also applies to a similar loop in _convert_subgraph on line 4615.

Suggested change
for k, i in zip(outputs, range(len(outputs))):
    self._nodes[k] = self.bb.emit(relax.TupleGetItem(if_result, i))
for i, k in enumerate(outputs):
    self._nodes[k] = self.bb.emit(relax.TupleGetItem(if_result, i))

tlopex (Member) replied:

This makes sense

Comment on lines +4615 to +4616
for k, i in zip(outputs, range(len(outputs))):
    self._nodes[k] = op[i]
Severity: medium

For better readability and conciseness, you can use enumerate here instead of zip with range.

Suggested change
for k, i in zip(outputs, range(len(outputs))):
    self._nodes[k] = op[i]
for i, k in enumerate(outputs):
    self._nodes[k] = op[i]
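The two loop forms produce identical bindings; a quick stdlib-only sketch (the outputs list and the dicts below are toy stand-ins for the frontend's variables):

```python
outputs = ["a", "b", "c"]

# Original form: zip the items with an explicit index range.
nodes = {}
for k, i in zip(outputs, range(len(outputs))):
    nodes[k] = i

# Suggested form: enumerate yields the same (index, item) pairs directly.
nodes_enum = {}
for i, k in enumerate(outputs):
    nodes_enum[k] = i

assert nodes == nodes_enum == {"a": 0, "b": 1, "c": 2}
```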

Comment on lines +4207 to +4208
import numpy as np
from onnx import TensorProto, helper
Severity: medium

These imports are already present at the top of the file. You can remove these redundant imports from this and the other new test functions (test_if_tensor_condition, test_if_computed_condition, test_if_multiple_outputs).

Comment on lines +4205 to +4277
def test_if():
    """Test ONNX If operator with scalar bool condition."""
    import numpy as np
    from onnx import TensorProto, helper

    x_info = helper.make_tensor_value_info("x", TensorProto.FLOAT, [3])
    cond_info = helper.make_tensor_value_info("cond", TensorProto.BOOL, [])
    result_info = helper.make_tensor_value_info("result", TensorProto.FLOAT, [3])

    # then branch: x * 2.0
    two = helper.make_tensor("two", TensorProto.FLOAT, [1], [2.0])
    then_mul = helper.make_node("Mul", ["x", "two"], ["then_out"])
    then_out_info = helper.make_tensor_value_info("then_out", TensorProto.FLOAT, [3])
    then_graph = helper.make_graph(
        [then_mul], "then_graph", [], [then_out_info], initializer=[two]
    )

    # else branch: x * 3.0
    three = helper.make_tensor("three", TensorProto.FLOAT, [1], [3.0])
    else_mul = helper.make_node("Mul", ["x", "three"], ["else_out"])
    else_out_info = helper.make_tensor_value_info("else_out", TensorProto.FLOAT, [3])
    else_graph = helper.make_graph(
        [else_mul], "else_graph", [], [else_out_info], initializer=[three]
    )

    if_node = helper.make_node(
        "If",
        inputs=["cond"],
        outputs=["result"],
        then_branch=then_graph,
        else_branch=else_graph,
    )
    main_graph = helper.make_graph([if_node], "if_test", [cond_info, x_info], [result_info])
    model = helper.make_model(main_graph, opset_imports=[helper.make_opsetid("", 13)])

    x_data = np.array([1.0, 2.0, 3.0], dtype=np.float32)

    check_correctness(model, inputs={"cond": np.array(True), "x": x_data})
    check_correctness(model, inputs={"cond": np.array(False), "x": x_data})


def test_if_tensor_condition():
    """Test If operator where condition is a 1-element tensor, not a scalar."""
    import numpy as np
    from onnx import TensorProto, helper

    cond_info = helper.make_tensor_value_info("cond", TensorProto.BOOL, [1])
    x_info = helper.make_tensor_value_info("x", TensorProto.FLOAT, [3])
    result_info = helper.make_tensor_value_info("result", TensorProto.FLOAT, [3])

    two = helper.make_tensor("two", TensorProto.FLOAT, [1], [2.0])
    then_mul = helper.make_node("Mul", ["x", "two"], ["then_out"])
    then_out_info = helper.make_tensor_value_info("then_out", TensorProto.FLOAT, [3])
    then_graph = helper.make_graph(
        [then_mul], "then_graph", [], [then_out_info], initializer=[two]
    )

    three = helper.make_tensor("three", TensorProto.FLOAT, [1], [3.0])
    else_mul = helper.make_node("Mul", ["x", "three"], ["else_out"])
    else_out_info = helper.make_tensor_value_info("else_out", TensorProto.FLOAT, [3])
    else_graph = helper.make_graph(
        [else_mul], "else_graph", [], [else_out_info], initializer=[three]
    )

    if_node = helper.make_node(
        "If", inputs=["cond"], outputs=["result"], then_branch=then_graph, else_branch=else_graph
    )
    main_graph = helper.make_graph([if_node], "if_test", [cond_info, x_info], [result_info])
    model = helper.make_model(main_graph, opset_imports=[helper.make_opsetid("", 13)])

    x_data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    check_correctness(model, inputs={"cond": np.array([True]), "x": x_data})
    check_correctness(model, inputs={"cond": np.array([False]), "x": x_data})
Severity: medium

The tests test_if and test_if_tensor_condition are very similar and contain a lot of duplicated code. You can merge them into a single parameterized test using pytest.mark.parametrize to improve maintainability and reduce redundancy. Here's a suggested implementation:

@pytest.mark.parametrize(
    "cond_info, cond_true, cond_false",
    [
        (
            helper.make_tensor_value_info("cond", TensorProto.BOOL, []),
            np.array(True),
            np.array(False),
        ),
        (
            helper.make_tensor_value_info("cond", TensorProto.BOOL, [1]),
            np.array([True]),
            np.array([False]),
        ),
    ],
    ids=["scalar_condition", "tensor_condition"],
)
def test_if(cond_info, cond_true, cond_false):
    """Test ONNX If operator with scalar and tensor bool conditions."""
    x_info = helper.make_tensor_value_info("x", TensorProto.FLOAT, [3])
    result_info = helper.make_tensor_value_info("result", TensorProto.FLOAT, [3])

    # then branch: x * 2.0
    two = helper.make_tensor("two", TensorProto.FLOAT, [1], [2.0])
    then_mul = helper.make_node("Mul", ["x", "two"], ["then_out"])
    then_out_info = helper.make_tensor_value_info("then_out", TensorProto.FLOAT, [3])
    then_graph = helper.make_graph(
        [then_mul], "then_graph", [], [then_out_info], initializer=[two]
    )

    # else branch: x * 3.0
    three = helper.make_tensor("three", TensorProto.FLOAT, [1], [3.0])
    else_mul = helper.make_node("Mul", ["x", "three"], ["else_out"])
    else_out_info = helper.make_tensor_value_info("else_out", TensorProto.FLOAT, [3])
    else_graph = helper.make_graph(
        [else_mul], "else_graph", [], [else_out_info], initializer=[three]
    )

    if_node = helper.make_node(
        "If",
        inputs=["cond"],
        outputs=["result"],
        then_branch=then_graph,
        else_branch=else_graph,
    )
    main_graph = helper.make_graph([if_node], "if_test", [cond_info, x_info], [result_info])
    model = helper.make_model(main_graph, opset_imports=[helper.make_opsetid("", 13)])

    x_data = np.array([1.0, 2.0, 3.0], dtype=np.float32)

    check_correctness(model, inputs={"cond": cond_true, "x": x_data})
    check_correctness(model, inputs={"cond": cond_false, "x": x_data})

This would replace both test_if and test_if_tensor_condition.

tlopex (Member) left a comment

A few follow-up comments here:

  • Nested If does not seem to be supported yet. _convert_subgraph calls _convert_operator(...) directly for each node, but unlike _construct_nodes, it does not special-case op_name == "If". So a nested If would end up in If._impl_v1 and raise NotImplementedError. Maybe we should either handle this recursively or document that nested If is currently unsupported.

  • In _convert_subgraph, this except Exception feels too broad for control flow:

  try:
      _ = op.struct_info
  except Exception:
      op = bb.normalize(op)

It can hide unrelated errors. It seems safer to check hasattr(op, "struct_info") / op.struct_info is not None, or catch a more specific exception type.

  • The If converter class also looks like dead code right now. It seems to exist only to get through the "supported ops" check, but _impl_v1 always raises NotImplementedError. If that is the only purpose, maybe it would be cleaner to add "If" directly to the supported ops set instead of registering a stub converter.

  • Minor style nit: contextlib is imported inside from_onnx, but it is a stdlib import, so it would be better to move it to the file-level imports.

with self.bb.dataflow() as df:  # pylint: disable=invalid-name, unused-variable
with contextlib.ExitStack() as stack:
    if not has_if:
        stack.enter_context(self.bb.dataflow())
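contextlib.ExitStack makes it possible to enter a context manager only when a condition holds, which is the mechanism the snippet above relies on. A minimal stdlib sketch, with dataflow_scope as a toy stand-in for self.bb.dataflow():

```python
import contextlib

events = []

@contextlib.contextmanager
def dataflow_scope():
    # Toy stand-in for self.bb.dataflow(): records enter/exit events.
    events.append("enter")
    yield
    events.append("exit")

def build(has_if):
    events.clear()
    # Enter the dataflow scope only when the graph has no If node;
    # ExitStack guarantees the scope is exited when the block ends.
    with contextlib.ExitStack() as stack:
        if not has_if:
            stack.enter_context(dataflow_scope())
        events.append("body")
    return list(events)

assert build(has_if=False) == ["enter", "body", "exit"]
assert build(has_if=True) == ["body"]
```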

This turns off dataflow for the entire model whenever there is an If, which seems a bit too conservative. Maybe we can just break the dataflow block around the If node instead of disabling it globally.


Outer-scope nodes are visible because we copy self._nodes into the
local lookup table before processing.
"""
outer_nodes = dict(self._nodes)

self._nodes is restored only on the normal path here. If subgraph conversion throws, the outer state is left mutated. Seems safer to wrap this in try/finally so self._nodes always gets restored.
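The restore-on-error pattern suggested here can be sketched in plain Python; convert_subgraph and its arguments are hypothetical names, not the frontend's actual API:

```python
def convert_subgraph(nodes, convert):
    """Run `convert` against `nodes`, always restoring the outer
    mapping afterward, even if conversion raises."""
    outer = dict(nodes)  # snapshot the outer scope
    try:
        return convert(nodes)  # may add branch-local bindings to `nodes`
    finally:
        # Runs on both the normal and the error path, so branch-local
        # bindings never leak back into the parent graph.
        nodes.clear()
        nodes.update(outer)

scope = {"x": 1}

def ok(n):
    n["local"] = 2  # branch-local binding
    return n["x"] + n["local"]

assert convert_subgraph(scope, ok) == 3
assert scope == {"x": 1}  # local binding did not leak

try:
    convert_subgraph(scope, lambda n: 1 / 0)
except ZeroDivisionError:
    pass
assert scope == {"x": 1}  # restored even though conversion raised
```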

convert_class = convert_map[op_name]
op_function = convert_class.get_converter(opset)
sym = op_function(self.bb, inputs, attrs, [self._nodes, self._params])
sym = op_function(self.bb, inputs, attrs, [self._nodes, self._params, self])

The third element self is never accessed by any converter. If you need the importer accessible to converters, pass it as an explicit keyword argument or a named container, not hidden in a list.

Signed-off-by: OmarAzizi <oalazizi75@gmail.com>
OmarAzizi (Author) commented Mar 28, 2026

@tlopex Thanks for the review! I addressed all the smaller feedback points in the follow-up commit. Additionally, I added handling for nested If as requested, along with a test for it.

In _convert_subgraph, this except Exception feels too broad for control flow. It seems safer to check hasattr(op, "struct_info") / op.struct_info is not None, or catch a more specific exception type.

Note: I tried hasattr(op, "struct_info") but it doesn't work here. struct_info is a property on the class, so hasattr always returns True, but accessing it still raises when unpopulated:

python/tvm/ir/expr.py:62: in struct_info
    return _ffi_api.ExprStructInfo(self)
include/tvm/relax/struct_info.h:398: in tvm::relax::StructInfo tvm::relax::GetStructInfo(const Expr&)
    TVM_FFI_ICHECK(ptr) << "The struct_info is not populated, check if you have normalized the expr";
E   tvm.error.InternalError: Check failed: (ptr) is false: The struct_info is not populated, check if you have normalized the expr

The property itself raises before returning None, so the is None check is never reached. Instead, I now catch tvm.error.InternalError specifically, which is exactly what TVM raises in this case.
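Catching one specific error type can be illustrated with a toy property that raises a custom error when its value is missing; UnpopulatedError and the Expr class below stand in for tvm.error.InternalError and a Relax expression, and are not TVM APIs:

```python
class UnpopulatedError(RuntimeError):
    """Toy stand-in for tvm.error.InternalError."""

class Expr:
    def __init__(self):
        self._struct_info = None

    @property
    def struct_info(self):
        # The property exists on the class but raises until populated.
        if self._struct_info is None:
            raise UnpopulatedError("struct_info is not populated")
        return self._struct_info

def normalize(op):
    op._struct_info = "TensorStructInfo"  # toy normalization
    return op

op = Expr()
try:
    _ = op.struct_info
except UnpopulatedError:  # catch only the specific error, not bare Exception
    op = normalize(op)

assert op.struct_info == "TensorStructInfo"
```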


Still open: Disabling dataflow globally when If is present

I attempted to close the dataflow block locally around the If node and reopen it afterward, but hit a fundamental issue: _convert_subgraph needs to call bb.normalize on branch expressions to populate struct_info. If the dataflow block is closed before converting the branches, normalize fails; if it is left open, the branch nodes get emitted into the active dataflow block, producing invalid IR like this:

with R.dataflow():
    if cond:               # If inside dataflow block (invalid)
        with R.dataflow(): # branch nodes incorrectly emitted as dataflow vars
            lv = R.multiply(...)
            R.output()

The current approach of skipping the dataflow block for the entire function when If is present avoids this and produces correct IR. It matches the same pattern the PyTorch Relax frontend uses for torch.cond. Happy to explore alternatives if you have a suggestion.

@OmarAzizi OmarAzizi requested a review from tlopex March 28, 2026 11:50
tlopex (Member) left a comment

LGTM Thank you!

@tlopex tlopex merged commit 52b5d55 into apache:main Mar 28, 2026
9 checks passed

Successfully merging this pull request may close these issues:

[Tracking Issue][ONNX] Complete missing and limited operators in ONNX frontend