2121
2222from .._utils import get_output_names
2323from .internal_graph import InternalTorchIRGraph , InternalTorchIRNode
24- from .ops import convert_nodes
24+ from .ops import TorchFrontend , convert_nodes
2525from .quantization_ops import _dequantized_weight
2626from .torch_op_registry import _TORCH_OPS_REGISTRY
2727from .torchir_passes import (
@@ -194,8 +194,13 @@ class TranscriptionContext:
194194 context when stepping out.
195195 """
196196
197- def __init__ (self , name : Optional [str ] = None ) -> None :
197+ def __init__ (
198+ self ,
199+ name : Optional [str ] = None ,
200+ frontend : TorchFrontend = TorchFrontend .TORCHSCRIPT ,
201+ ) -> None :
198202 self .name = name if name else ""
203+ self .frontend = frontend
199204 self ._current_graph = [{}]
200205 self ._torch_graph = None
201206 self ._quant_context = QuantizationContext (self )
@@ -346,6 +351,7 @@ def __init__(
346351 self ._prog = Program ()
347352
348353 if isinstance (loaded_model , torch .jit .ScriptModule ):
354+ self .context .frontend = TorchFrontend .TORCHSCRIPT
349355 self .graph , self .params_dict , self .buffer_dict = InternalTorchIRGraph .from_torchscript (
350356 torchscript = loaded_model , input_values = self .inputs , cut_at_symbols = cut_at_symbols
351357 )
@@ -363,6 +369,7 @@ def __init__(
363369 p (self .graph )
364370
365371 elif _HAS_TORCH_EXPORT_API and isinstance (loaded_model , ExportedProgram ):
372+ self .context .frontend = TorchFrontend .EDGEIR
366373 self .graph = InternalTorchIRGraph .from_edgeir (edgeir = loaded_model )
367374 self .params_dict , self .buffer_dict = None , None
368375 else :
0 commit comments