[Frontend][ONNX] Add If operator support to Relax ONNX frontend#18946
tlopex merged 2 commits into apache:main
Conversation
Signed-off-by: OmarAzizi <oalazizi75@gmail.com>
Code Review
This pull request implements support for the ONNX If operator in the Relax frontend. Key changes include the addition of subgraph conversion logic, handling control flow by conditionally disabling dataflow blocks in the main function, and updating attribute parsing to support nested graphs. Feedback suggests improving code conciseness by using enumerate in loops, removing redundant imports in the new test cases, and refactoring similar tests into a single parameterized test to reduce duplication.
```python
for k, i in zip(outputs, range(len(outputs))):
    self._nodes[k] = self.bb.emit(relax.TupleGetItem(if_result, i))
```
For better readability and conciseness, you can use `enumerate` here instead of `zip` with `range`. This also applies to a similar loop in `_convert_subgraph` on line 4615.
```diff
-for k, i in zip(outputs, range(len(outputs))):
-    self._nodes[k] = self.bb.emit(relax.TupleGetItem(if_result, i))
+for i, k in enumerate(outputs):
+    self._nodes[k] = self.bb.emit(relax.TupleGetItem(if_result, i))
```
The similar loop in `_convert_subgraph`:

```python
for k, i in zip(outputs, range(len(outputs))):
    self._nodes[k] = op[i]
```
```python
import numpy as np
from onnx import TensorProto, helper


def test_if():
    """Test ONNX If operator with scalar bool condition."""
    import numpy as np
    from onnx import TensorProto, helper

    x_info = helper.make_tensor_value_info("x", TensorProto.FLOAT, [3])
    cond_info = helper.make_tensor_value_info("cond", TensorProto.BOOL, [])
    result_info = helper.make_tensor_value_info("result", TensorProto.FLOAT, [3])

    # then branch: x * 2.0
    two = helper.make_tensor("two", TensorProto.FLOAT, [1], [2.0])
    then_mul = helper.make_node("Mul", ["x", "two"], ["then_out"])
    then_out_info = helper.make_tensor_value_info("then_out", TensorProto.FLOAT, [3])
    then_graph = helper.make_graph(
        [then_mul], "then_graph", [], [then_out_info], initializer=[two]
    )

    # else branch: x * 3.0
    three = helper.make_tensor("three", TensorProto.FLOAT, [1], [3.0])
    else_mul = helper.make_node("Mul", ["x", "three"], ["else_out"])
    else_out_info = helper.make_tensor_value_info("else_out", TensorProto.FLOAT, [3])
    else_graph = helper.make_graph(
        [else_mul], "else_graph", [], [else_out_info], initializer=[three]
    )

    if_node = helper.make_node(
        "If",
        inputs=["cond"],
        outputs=["result"],
        then_branch=then_graph,
        else_branch=else_graph,
    )
    main_graph = helper.make_graph([if_node], "if_test", [cond_info, x_info], [result_info])
    model = helper.make_model(main_graph, opset_imports=[helper.make_opsetid("", 13)])

    x_data = np.array([1.0, 2.0, 3.0], dtype=np.float32)

    check_correctness(model, inputs={"cond": np.array(True), "x": x_data})
    check_correctness(model, inputs={"cond": np.array(False), "x": x_data})


def test_if_tensor_condition():
    """Test If operator where condition is a 1-element tensor, not a scalar."""
    import numpy as np
    from onnx import TensorProto, helper

    cond_info = helper.make_tensor_value_info("cond", TensorProto.BOOL, [1])
    x_info = helper.make_tensor_value_info("x", TensorProto.FLOAT, [3])
    result_info = helper.make_tensor_value_info("result", TensorProto.FLOAT, [3])

    two = helper.make_tensor("two", TensorProto.FLOAT, [1], [2.0])
    then_mul = helper.make_node("Mul", ["x", "two"], ["then_out"])
    then_out_info = helper.make_tensor_value_info("then_out", TensorProto.FLOAT, [3])
    then_graph = helper.make_graph(
        [then_mul], "then_graph", [], [then_out_info], initializer=[two]
    )

    three = helper.make_tensor("three", TensorProto.FLOAT, [1], [3.0])
    else_mul = helper.make_node("Mul", ["x", "three"], ["else_out"])
    else_out_info = helper.make_tensor_value_info("else_out", TensorProto.FLOAT, [3])
    else_graph = helper.make_graph(
        [else_mul], "else_graph", [], [else_out_info], initializer=[three]
    )

    if_node = helper.make_node(
        "If", inputs=["cond"], outputs=["result"], then_branch=then_graph, else_branch=else_graph
    )
    main_graph = helper.make_graph([if_node], "if_test", [cond_info, x_info], [result_info])
    model = helper.make_model(main_graph, opset_imports=[helper.make_opsetid("", 13)])

    x_data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    check_correctness(model, inputs={"cond": np.array([True]), "x": x_data})
    check_correctness(model, inputs={"cond": np.array([False]), "x": x_data})
```
The tests `test_if` and `test_if_tensor_condition` are very similar and contain a lot of duplicated code. You can merge them into a single parameterized test using `pytest.mark.parametrize` to improve maintainability and reduce redundancy. Here's a suggested implementation:
```python
@pytest.mark.parametrize(
    "cond_info, cond_true, cond_false",
    [
        (
            helper.make_tensor_value_info("cond", TensorProto.BOOL, []),
            np.array(True),
            np.array(False),
        ),
        (
            helper.make_tensor_value_info("cond", TensorProto.BOOL, [1]),
            np.array([True]),
            np.array([False]),
        ),
    ],
    ids=["scalar_condition", "tensor_condition"],
)
def test_if(cond_info, cond_true, cond_false):
    """Test ONNX If operator with scalar and tensor bool conditions."""
    x_info = helper.make_tensor_value_info("x", TensorProto.FLOAT, [3])
    result_info = helper.make_tensor_value_info("result", TensorProto.FLOAT, [3])
    # then branch: x * 2.0
    two = helper.make_tensor("two", TensorProto.FLOAT, [1], [2.0])
    then_mul = helper.make_node("Mul", ["x", "two"], ["then_out"])
    then_out_info = helper.make_tensor_value_info("then_out", TensorProto.FLOAT, [3])
    then_graph = helper.make_graph(
        [then_mul], "then_graph", [], [then_out_info], initializer=[two]
    )
    # else branch: x * 3.0
    three = helper.make_tensor("three", TensorProto.FLOAT, [1], [3.0])
    else_mul = helper.make_node("Mul", ["x", "three"], ["else_out"])
    else_out_info = helper.make_tensor_value_info("else_out", TensorProto.FLOAT, [3])
    else_graph = helper.make_graph(
        [else_mul], "else_graph", [], [else_out_info], initializer=[three]
    )
    if_node = helper.make_node(
        "If",
        inputs=["cond"],
        outputs=["result"],
        then_branch=then_graph,
        else_branch=else_graph,
    )
    main_graph = helper.make_graph([if_node], "if_test", [cond_info, x_info], [result_info])
    model = helper.make_model(main_graph, opset_imports=[helper.make_opsetid("", 13)])
    x_data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    check_correctness(model, inputs={"cond": cond_true, "x": x_data})
    check_correctness(model, inputs={"cond": cond_false, "x": x_data})
```

This would replace both `test_if` and `test_if_tensor_condition`.
tlopex
left a comment
A few follow-up comments here:
- Nested `If` does not seem to be supported yet. `_convert_subgraph` calls `_convert_operator(...)` directly for each node, but unlike `_construct_nodes`, it does not special-case `op_name == "If"`. So a nested `If` would end up in `If._impl_v1` and raise `NotImplementedError`. Maybe we should either handle this recursively or document that nested `If` is currently unsupported.
- In `_convert_subgraph`, this `except Exception` feels too broad for control flow:
  ```python
  try:
      _ = op.struct_info
  except Exception:
      op = bb.normalize(op)
  ```
  It can hide unrelated errors. It seems safer to check `hasattr(op, "struct_info")` / `op.struct_info is not None`, or catch a more specific exception type.
- The `If` converter class also looks like dead code right now. It seems to exist only to get through the "supported ops" check, but `_impl_v1` always raises `NotImplementedError`. If that is the only purpose, maybe it would be cleaner to add `"If"` directly to the supported ops set instead of registering a stub converter.
- Minor style nit: `contextlib` is imported inside `from_onnx`, but it is a stdlib import, so it would be better to move it to the file-level imports.
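The exception-handling concern above can be illustrated with plain Python. `hasattr` only suppresses `AttributeError`, so if a property raises any other exception type it propagates anyway, which is why a narrow `except` around the property access is the safer middle ground. A minimal sketch, with `FakeExpr` as a hypothetical stand-in for a Relax expression (not the frontend's actual class):

```python
class FakeExpr:
    """Hypothetical stand-in for an expression whose struct_info
    property raises when the information is not yet populated."""

    def __init__(self, populated):
        self._populated = populated

    @property
    def struct_info(self):
        if not self._populated:
            # A non-AttributeError, as checked-access properties may raise.
            raise ValueError("struct_info is not populated")
        return "TensorStructInfo"


def needs_normalize(op):
    # Catch only the specific error the property is known to raise,
    # rather than a blanket `except Exception` that hides unrelated bugs.
    try:
        _ = op.struct_info
        return False
    except ValueError:
        return True


# Note: hasattr() would not help here, because it only suppresses
# AttributeError; hasattr(FakeExpr(False), "struct_info") would
# propagate the ValueError instead of returning False.
assert needs_normalize(FakeExpr(True)) is False
assert needs_normalize(FakeExpr(False)) is True
```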
```diff
-with self.bb.dataflow() as df:  # pylint: disable=invalid-name, unused-variable
+with contextlib.ExitStack() as stack:
+    if not has_if:
+        stack.enter_context(self.bb.dataflow())
```
This turns off dataflow for the entire model whenever there is an If, which seems a bit too conservative. Maybe we can just break the dataflow block around the If node instead of disabling it globally.
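The conditional-context pattern under discussion can be sketched in plain Python with `contextlib.ExitStack`, using a hypothetical `dataflow()` context manager in place of `self.bb.dataflow()`:

```python
import contextlib

events = []


@contextlib.contextmanager
def dataflow():
    # Hypothetical stand-in for self.bb.dataflow().
    events.append("enter dataflow")
    yield
    events.append("exit dataflow")


def build_function(has_if):
    """Enter the dataflow block only when the graph has no If node."""
    events.clear()
    with contextlib.ExitStack() as stack:
        if not has_if:
            stack.enter_context(dataflow())
        events.append("emit bindings")
    return list(events)


assert build_function(has_if=False) == ["enter dataflow", "emit bindings", "exit dataflow"]
assert build_function(has_if=True) == ["emit bindings"]
```

`ExitStack` keeps the `with` statement in place unconditionally while making the dataflow context optional, which is what allows a single code path to serve both cases.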
```python
for k, i in zip(outputs, range(len(outputs))):
    self._nodes[k] = self.bb.emit(relax.TupleGetItem(if_result, i))
```
```python
        Outer-scope nodes are visible because we copy self._nodes into the
        local lookup table before processing.
        """
        outer_nodes = dict(self._nodes)
```
`self._nodes` is restored only on the normal path here. If subgraph conversion throws, the outer state is left mutated. It seems safer to wrap this in `try`/`finally` so `self._nodes` always gets restored.
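The snapshot-and-restore pattern being suggested can be sketched in isolation. This is a hypothetical simplification (a plain dict standing in for `self._nodes`), showing that `finally` restores the outer scope on both the normal and the exceptional path:

```python
def convert_subgraph(nodes, subgraph_bindings, fail=False):
    """Hypothetical sketch: mutate a shared `nodes` dict during subgraph
    conversion, restoring the outer scope even if conversion raises."""
    outer_nodes = dict(nodes)  # snapshot the outer scope
    try:
        nodes.update(subgraph_bindings)  # subgraph-local bindings
        if fail:
            raise RuntimeError("subgraph conversion failed")
        return dict(nodes)
    finally:
        # Runs on both the return path and the exception path.
        nodes.clear()
        nodes.update(outer_nodes)


state = {"x": 1}
convert_subgraph(state, {"local": 2})
assert state == {"x": 1}  # restored on the normal path

try:
    convert_subgraph(state, {"local": 2}, fail=True)
except RuntimeError:
    pass
assert state == {"x": 1}  # restored on the error path too
```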
```diff
 convert_class = convert_map[op_name]
 op_function = convert_class.get_converter(opset)
-sym = op_function(self.bb, inputs, attrs, [self._nodes, self._params])
+sym = op_function(self.bb, inputs, attrs, [self._nodes, self._params, self])
```
The third element self is never accessed by any converter. If you need the importer accessible to converters, pass it as an explicit keyword argument or a named container, not hidden in a list.
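A hedged sketch of the alternative calling convention: passing the importer as an explicit keyword argument keeps the positional list's two-element shape stable for existing converters and makes the dependency visible at the call site. All names here are illustrative, not the frontend's actual signature:

```python
def op_function(bb, inputs, attrs, state, *, importer=None):
    """Hypothetical converter entry point: `state` keeps its original
    (nodes, params) shape, and the importer travels as a named argument."""
    nodes, params = state  # unchanged unpacking for existing converters
    if importer is not None:
        # Converters that need the importer (e.g. for subgraphs) opt in.
        return ("used-importer", importer.name)
    return ("no-importer", len(nodes) + len(params))


class FakeImporter:
    """Illustrative stand-in for the ONNX graph importer object."""

    name = "onnx-frontend"


assert op_function(None, [], {}, ({}, {})) == ("no-importer", 0)
assert op_function(None, [], {}, ({}, {}), importer=FakeImporter()) == (
    "used-importer",
    "onnx-frontend",
)
```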
Signed-off-by: OmarAzizi <oalazizi75@gmail.com>
@tlopex Thanks for the review! Addressed all the smaller feedback points in the follow-up commit. Additionally, I included the handling of nested `If` as requested, along with a test for it.
Note: I tried the suggested `hasattr(op, "struct_info")` check, but the property itself raises before returning, so the `try`/`except` stays.

Still open: disabling dataflow globally when an `If` is present. I attempted to close the dataflow block locally around the `If` node instead, but that produces invalid IR:

```python
with R.dataflow():
    if cond:  # If inside dataflow block (invalid)
        with R.dataflow():  # branch nodes incorrectly emitted as dataflow vars
            lv = R.multiply(...)
            R.output()
```

The current approach of skipping the dataflow block for the entire function when an `If` is present therefore remains.
Summary
This PR implements the ONNX `If` operator in the Relax ONNX frontend. The `If` operator enables conditional branching in ONNX models, where a boolean condition selects between two subgraph branches (`then_branch` and `else_branch`) at runtime. This is required for any model with runtime-dependent execution paths.

Closes #18945 (Tier 1: `If` operator)

Implementation Notes
The main challenge is that `relax.If` cannot be emitted inside a dataflow block, which is how the ONNX frontend normally builds the entire graph. To handle this, when the graph contains an `If` node, the function body is built as a regular binding block instead, matching the approach used by the PyTorch Relax frontend for `torch.cond`.

Each branch is an ONNX subgraph that can reference values from the outer graph. A new `_convert_subgraph` method handles converting these subgraphs into Relax expressions, making outer-scope values available to the branch while ensuring branch-local bindings don't leak back to the parent graph.

Why `relax.If` cannot live inside a dataflow block

Dataflow blocks in Relax carry a semantic guarantee: every operation inside them must be pure and side-effect-free with no control flow. This allows the compiler to treat the entire block as a static computational graph for optimizations like operator fusion and constant folding. An `If` node breaks this guarantee by introducing runtime-dependent branching, so Relax's well-formedness checker explicitly forbids it. I discovered this when the checker raised an error on the initial implementation.

The fix, skipping the dataflow block when the graph contains an `If` node, mirrors exactly how the PyTorch Relax frontend handles `torch.cond`.

Known Limitations
Dataflow block: Models whose top-level graph contains an `If` node are built without a dataflow block, which may affect downstream optimization passes that rely on dataflow block structure.

Tests

Four new tests covering: scalar and tensor conditions, condition computed from another op, and multiple branch outputs. All verified against onnxruntime via `check_correctness`.