@@ -65,22 +65,6 @@ def _get_layer_attributes(
6565 )
6666 return None
6767
68- def _map_fx_unique_metatypes (node : torch .fx .Node , metatype : om .OperatorMetatype ) -> om .OperatorMetatype :
69- """
70- Attempts to retrieve correct subtype for the given node.
71-
72- :param node: Given node.
73- :param metatype: Given node metatype.
74- :param model: Target GraphModule instance.
75- :return: Correct FX metatype of the given node if it is exist or the original node metatype otherwise.
76- """
77- if metatype in [om .PTEmbeddingMetatype ]:
78- weight_node = node .args [0 ]
79- if weight_node .op == "get_attr" :
80- return om .PTAtenEmbeddingMetatype
81-
82- return metatype
83-
8468 @staticmethod
8569 def get_node_type_and_metatype (node : torch .fx .Node , model : torch .fx .GraphModule ) -> Tuple [str , om .OperatorMetatype ]:
8670 """
@@ -118,7 +102,8 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
118102 layer_attrs = GraphConverter ._get_layer_attributes (node , node_metatype , model )
119103 node_subtype = node_metatype .determine_subtype (layer_attrs )
120104 node_metatype = node_subtype or node_metatype
121- return node_type , node_metatype
105+ node_type_name = node_type_name or node_type
106+ return node_type_name , node_metatype
122107
123108 @staticmethod
124109 def create_nncf_graph (model : torch .fx .GraphModule ) -> PTNNCFGraph :
@@ -135,7 +120,6 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
135120 const_targets_counter = Counter ([node .target for node in model .graph .nodes if node .op == "get_attr" ])
136121 for source_node in model .graph .nodes :
137122 node_type , node_metatype = GraphConverter .get_node_type_and_metatype (source_node , model )
138- node_metatype = GraphConverter ._map_fx_unique_metatypes (source_node , node_metatype )
139123 is_shared_node = source_node .op in ("get_attr" ,) and (
140124 const_targets_counter [source_node .target ] > 1 or len (source_node .users ) > 1
141125 )
0 commit comments