@@ -5,6 +5,7 @@
 
 import itertools
 import os
+from unittest.mock import patch
 
 import numpy as np
 import pytest
@@ -19,6 +20,7 @@
 )
 from coremltools.converters.mil.frontend.torch.test.testing_utils import _copy_input_data
 from coremltools.converters.mil.frontend.torch.torch_op_registry import (
+    TorchOpsRegistry,
     _TORCH_OPS_REGISTRY,
     register_torch_op,
 )
@@ -264,6 +266,71 @@ def test_func_dummy(context, inputs):
         # Cleanup the test
         del _TORCH_OPS_REGISTRY.name_to_func_mapping["test_func_dummy"]
 
+
+@pytest.mark.skipif(not _HAS_TORCH, reason=MSG_TORCH_NOT_FOUND)
+class TestFxNodeSupport:
+    """
+    The API ``ct.converters.mil.frontend.torch.is_torch_fx_node_supported`` is used
+    by third-party code (ExecuTorch: https://github.com/pytorch/executorch/pull/1415),
+    so we must not break it.
+    """
+
+    @staticmethod
+    def test_simple_case():
+        class Model(torch.nn.Module):
+            def forward(self, a, x, b):
+                y = torch.mm(a, x)
+                z = y + b
+                a.sub_(z)
+                y = torch.mm(a, x)
+                z = y + b
+                return z
+
+        model = Model()
+        model.eval()
+        symbolic_traced = torch.fx.symbolic_trace(model)
+
+        for node in symbolic_traced.graph.nodes:
+            # There are many types of torch fx nodes;
+            # we only support "call_function" nodes for now.
+            if node.op == "call_function":
+                # All PyTorch ops in the example model are supported, so they should all return True.
+                assert ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
+            # Other types of torch fx nodes are not supported.
+            else:
+                assert not ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
+
+    @staticmethod
+    def test_unsupported_op():
+        class Model(torch.nn.Module):
+            def forward(self, x, y):
+                z = x + y
+                return torch.nn.functional.softmax(z)
+
+        model = Model()
+        model.eval()
+        symbolic_traced = torch.fx.symbolic_trace(model)
+
+        # Mock our torch ops registry, pretending that only "add" is supported.
+        with patch.object(
+            TorchOpsRegistry,
+            "__contains__",
+            side_effect=(lambda op_name: op_name == "add"),
+        ):
+            for node in symbolic_traced.graph.nodes:
+                # There are many types of torch fx nodes;
+                # we only support "call_function" nodes for now.
+                if node.op == "call_function":
+                    # Only "add" is supported.
+                    assert (
+                        (node.target.__name__.lower() == "add")
+                        == ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
+                    )
+                # Other types of torch fx nodes are not supported.
+                else:
+                    assert not ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
+
+
 #################################################################################
 # Note: Starting from here, all of the following tests are also used as examples
 # in https://coremltools.readme.io/docs as a reference.
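For context (not part of the commit above): a minimal sketch of how a third-party caller such as ExecuTorch might query the API these tests protect, assuming only that `torch` and `coremltools` are installed. `SmallModel` is a made-up example module, not something from the source.

```python
# Hypothetical usage sketch of the API exercised by TestFxNodeSupport.
import torch
import coremltools as ct


class SmallModel(torch.nn.Module):  # made-up example module
    def forward(self, x, y):
        return torch.add(x, y)


traced = torch.fx.symbolic_trace(SmallModel())
for node in traced.graph.nodes:
    # Only "call_function" nodes carry an op whose support can be queried.
    if node.op == "call_function":
        supported = ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
        print(node.target, "supported:", supported)
```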