@@ -65,22 +65,6 @@ def _get_layer_attributes(
6565 )
6666 return None
6767
68- def _map_fx_unique_metatypes (node : torch .fx .Node , metatype : om .OperatorMetatype ) -> om .OperatorMetatype :
69- """
70- Attempts to retrieve correct subtype for the given node.
71-
72- :param node: Given node.
73- :param metatype: Given node metatype.
74- :param model: Target GraphModule instance.
75- :return: Correct FX metatype of the given node if it is exist or the original node metatype otherwise.
76- """
77- if metatype in [om .PTEmbeddingMetatype ]:
78- weight_node = node .args [0 ]
79- if weight_node .op == "get_attr" :
80- return om .PTAtenEmbeddingMetatype
81-
82- return metatype
83-
8468 @staticmethod
8569 def get_node_type_and_metatype (node : torch .fx .Node , model : torch .fx .GraphModule ) -> Tuple [str , om .OperatorMetatype ]:
8670 """
@@ -121,8 +105,7 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
121105 layer_attrs = GraphConverter ._get_layer_attributes (node , node_metatype , model )
122106 node_subtype = node_metatype .determine_subtype (layer_attrs )
123107 node_metatype = node_subtype or node_metatype
124- if not node_type_name :
125- node_type_name = node_type
108+ node_type_name = node_type_name or node_type
126109 return node_type_name , node_metatype
127110
128111 @staticmethod
@@ -140,7 +123,6 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
140123 const_targets_counter = Counter ([node .target for node in model .graph .nodes if node .op == "get_attr" ])
141124 for source_node in model .graph .nodes :
142125 node_type , node_metatype = GraphConverter .get_node_type_and_metatype (source_node , model )
143- node_metatype = GraphConverter ._map_fx_unique_metatypes (source_node , node_metatype )
144126 is_shared_node = source_node .op in ("get_attr" ,) and (
145127 const_targets_counter [source_node .target ] > 1 or len (source_node .users ) > 1
146128 )
0 commit comments