Commit 62ef5a2

YifanShenSZ and yifan_shen3 authored
[PyTorch] Add API to Query if a PyTorch Op is Supported in Core ML (#2081)
add API to query if a PyTorch op is supported in Core ML

---------

Co-authored-by: yifan_shen3 <[email protected]>
1 parent 600fda7 commit 62ef5a2
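
A minimal usage sketch (not part of this commit; the toy module and variable names are illustrative only) of how the new query API added here can be called on a torch.fx-traced model:

import torch
import coremltools as ct

class TinyModel(torch.nn.Module):  # illustrative toy module, not from the commit
    def forward(self, x, y):
        return torch.mm(x, y) + 1.0

traced = torch.fx.symbolic_trace(TinyModel())

# Ask Core ML, node by node, whether it knows how to convert the op.
# "call_function" nodes for mm/add return True; placeholder and output nodes return False.
for node in traced.graph.nodes:
    supported = ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
    print(node.op, node.target, "->", supported)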

File tree

3 files changed: +114 -2 lines changed


coremltools/converters/mil/frontend/torch/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -11,4 +11,4 @@
 from . import ops, quantization_ops
 from .dialect_ops import (torch_tensor_assign, torch_upsample_bilinear,
                           torch_upsample_nearest_neighbor)
-from .torch_op_registry import register_torch_op
+from .torch_op_registry import register_torch_op, is_torch_fx_node_supported

coremltools/converters/mil/frontend/torch/test/test_torch_conversion_api.py

Lines changed: 67 additions & 0 deletions

@@ -5,6 +5,7 @@

 import itertools
 import os
+from unittest.mock import patch

 import numpy as np
 import pytest
@@ -19,6 +20,7 @@
 )
 from coremltools.converters.mil.frontend.torch.test.testing_utils import _copy_input_data
 from coremltools.converters.mil.frontend.torch.torch_op_registry import (
+    TorchOpsRegistry,
     _TORCH_OPS_REGISTRY,
     register_torch_op,
 )
@@ -264,6 +266,71 @@ def test_func_dummy(context, inputs):
     # Cleanup the test
     del _TORCH_OPS_REGISTRY.name_to_func_mapping["test_func_dummy"]

+
+@pytest.mark.skipif(not _HAS_TORCH, reason=MSG_TORCH_NOT_FOUND)
+class TestFxNodeSupport:
+    """
+    The API ``ct.converters.mil.frontend.torch.is_torch_fx_node_supported`` is used
+    by 3rd-party code ExecuTorch: https://github.com/pytorch/executorch/pull/1415,
+    so we cannot break it
+    """
+
+    @staticmethod
+    def test_simple_case():
+        class Model(torch.nn.Module):
+            def forward(self, a, x, b):
+                y = torch.mm(a, x)
+                z = y + b
+                a.sub_(z)
+                y = torch.mm(a, x)
+                z = y + b
+                return z
+
+        model = Model()
+        model.eval()
+        symbolic_traced = torch.fx.symbolic_trace(model)
+
+        for node in symbolic_traced.graph.nodes:
+            # There are many types of torch fx node,
+            # we only support "call_function" node for now
+            if node.op == "call_function":
+                # All PyTorch ops in the example model are supported, so they should all return true
+                assert ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
+            # Other types of torch fx node are not supported
+            else:
+                assert not ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
+
+    @staticmethod
+    def test_unsupported_op():
+        class Model(torch.nn.Module):
+            def forward(self, x, y):
+                z = x + y
+                return torch.nn.functional.softmax(z)
+
+        model = Model()
+        model.eval()
+        symbolic_traced = torch.fx.symbolic_trace(model)
+
+        # Mock our torch ops registry, pretending that only "add" is supported
+        with patch.object(
+            TorchOpsRegistry,
+            "__contains__",
+            side_effect=(lambda op_name: op_name == "add"),
+        ):
+            for node in symbolic_traced.graph.nodes:
+                # There are many types of torch fx node,
+                # we only support "call_function" node for now
+                if node.op == "call_function":
+                    # Only "add" is supported
+                    assert (
+                        (node.target.__name__.lower() == "add")
+                        == ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
+                    )
+                # Other types of torch fx node are not supported
+                else:
+                    assert not ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
+
+
 #################################################################################
 # Note: Starting from here, all of the following tests are also used as examples
 # in https://coremltools.readme.io/docs as a reference.

coremltools/converters/mil/frontend/torch/torch_op_registry.py

Lines changed: 46 additions & 1 deletion

@@ -3,8 +3,11 @@
 # Use of this source code is governed by a BSD-3-clause license that can be
 # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

-from typing import Callable
+from typing import Callable, List

+import torch
+
+from coremltools import _logger as logger
 from coremltools.models._deprecation import deprecated as _deprecated


@@ -127,3 +130,45 @@ def func_wrapper(func):
         # decorator called without argument
         return func_wrapper
     return func_wrapper(_func)
+
+
+def is_torch_fx_node_supported(torch_fx_node: torch.fx.Node) -> bool:
+    # There are many types of torch fx node:
+    # 1. call_function
+    # 2. call_module
+    # 3. call_method
+    # 4. get_attr
+    # 5. placeholder
+    # 6. output
+    # ...
+    # Only "call_*" nodes contain PyTorch ops,
+    # among them we only support "call_function" node for now
+    if torch_fx_node.op != "call_function":
+        logger.warning(
+            "For now, among all types of torch fx nodes, CoreML only supports call_function node"
+        )
+        return False
+
+    # Get the target in torch fx node, and canonicalize it to lower-case string
+    torch_fx_node_target = torch_fx_node.target
+    if isinstance(torch_fx_node_target, str):
+        torch_fx_node_target_name = torch_fx_node_target.lower()
+    else:
+        torch_fx_node_target_name = torch_fx_node.target.__name__.lower()
+    # Since we are only dealing with "call_function" node,
+    # the contained PyTorch op must be functional, i.e. not in-place
+    assert (
+        not torch_fx_node_target_name.endswith("_")
+    ), (
+        "For now, since CoreML only supports call_function torch fx node, "
+        "all ops should be functional, i.e. there should not be any in-place op"
+    )
+    # Target name may or may not contain prefix "aten.":
+    # 1. For usual fx node, target is a PyTorch function, i.e. no prefix
+    # 2. For executorch exported fx node, target is executorch.exir.dialects.edge._ops.EdgeOp,
+    #    whose name has format "aten.xx.yy"
+    _ATEN_NODE_PREFIX = "aten."
+    if torch_fx_node_target_name.startswith(_ATEN_NODE_PREFIX):
+        torch_fx_node_target_name = torch_fx_node_target_name[len(_ATEN_NODE_PREFIX):]
+
+    return torch_fx_node_target_name in _TORCH_OPS_REGISTRY
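
As a side note, a small standalone sketch (hypothetical helper, not part of the commit; the EdgeOp-style string below is illustrative) of the target-name canonicalization the new function performs, i.e. lower-casing the name and stripping an optional "aten." prefix produced by ExecuTorch edge-dialect ops:

def _normalize_target_name(name: str) -> str:
    # Hypothetical helper mirroring the canonicalization in is_torch_fx_node_supported.
    name = name.lower()
    prefix = "aten."
    return name[len(prefix):] if name.startswith(prefix) else name

print(_normalize_target_name("mm"))               # -> "mm" (plain fx target name, unchanged)
print(_normalize_target_name("aten.mm.default"))  # -> "mm.default" (EdgeOp-style name, prefix stripped)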
