Skip to content

Commit f392a70

Browse files
YifanShenSZyifan_shen3
andauthored
[PyTorch] [ExecuTorch] Support New ExecuTorch Models (#2078)
* support new executorch models except llama2; inception v3 is now supported as well * add frontend info to context, so different input schema of torch script and exir can be handled properly * revert string type hint: caused CI failure in https://gitlab.com/coremltools1/coremltools/-/jobs/5790626756 --------- Co-authored-by: yifan_shen3 <[email protected]>
1 parent f8947af commit f392a70

File tree

7 files changed

+353
-78
lines changed

7 files changed

+353
-78
lines changed

coremltools/converters/mil/frontend/torch/converter.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from .._utils import get_output_names
2323
from .internal_graph import InternalTorchIRGraph, InternalTorchIRNode
24-
from .ops import convert_nodes
24+
from .ops import TorchFrontend, convert_nodes
2525
from .quantization_ops import _dequantized_weight
2626
from .torch_op_registry import _TORCH_OPS_REGISTRY
2727
from .torchir_passes import (
@@ -194,8 +194,13 @@ class TranscriptionContext:
194194
context when stepping out.
195195
"""
196196

197-
def __init__(self, name: Optional[str] = None) -> None:
197+
def __init__(
198+
self,
199+
name: Optional[str] = None,
200+
frontend: TorchFrontend = TorchFrontend.TORCHSCRIPT,
201+
) -> None:
198202
self.name = name if name else ""
203+
self.frontend = frontend
199204
self._current_graph = [{}]
200205
self._torch_graph = None
201206
self._quant_context = QuantizationContext(self)
@@ -346,6 +351,7 @@ def __init__(
346351
self._prog = Program()
347352

348353
if isinstance(loaded_model, torch.jit.ScriptModule):
354+
self.context.frontend = TorchFrontend.TORCHSCRIPT
349355
self.graph, self.params_dict, self.buffer_dict = InternalTorchIRGraph.from_torchscript(
350356
torchscript=loaded_model, input_values=self.inputs, cut_at_symbols=cut_at_symbols
351357
)
@@ -363,6 +369,7 @@ def __init__(
363369
p(self.graph)
364370

365371
elif _HAS_TORCH_EXPORT_API and isinstance(loaded_model, ExportedProgram):
372+
self.context.frontend = TorchFrontend.EDGEIR
366373
self.graph = InternalTorchIRGraph.from_edgeir(edgeir=loaded_model)
367374
self.params_dict, self.buffer_dict = None, None
368375
else:

0 commit comments

Comments
 (0)