@@ -124,11 +124,8 @@ def _transformLayoutPermutation(dims: int, spatialDims: int, targetChannelsFirst
124124
125125
126126# Calculate permutation q = p^(-1) s.t. q(p(i)) = i
127- def _invertPermutation (permutation : List [int ]) -> List [int ]:
128- inverse = [0 ] * len (permutation )
129- for idx , permIdx in enumerate (permutation ):
130- inverse [permIdx ] = idx
131- return inverse
127+ def _invertPermutation (permutation : Sequence [int ]) -> List [int ]:
128+ return [permutation .index (i ) for i in range (len (permutation ))]
132129
133130
134131T = TypeVar ('T' )
@@ -283,31 +280,10 @@ def __init__(self, default_channels_first: bool = True):
283280 super ().__init__ (graph , partial (_NCHWtoNHWC_fun , default_channels_first = default_channels_first ), name )
284281
285282
286- @contextagnostic
287- class NCHWtoNHWCPass (SequentialPass ):
288-
289- def __init__ (self , default_channels_first : bool = True ):
290- passes = [
291- NCHWtoNHWCPadPass (default_channels_first ),
292- NCHWtoNHWCMaxPoolPass (default_channels_first ),
293- NCHWtoNHWCConvPass (default_channels_first ),
294- NCHWtoNHWCRequantizedConvPass (default_channels_first ),
295- ]
296- super ().__init__ (* passes )
297-
298-
299- def _PULPDWNCHWtoNHWC_fun (graph : gs .Graph , match : Match , name : str , default_channels_first : bool = True ):
300-
301- matched_nodes = [m for k , m in match .nodes_map .items ()]
302- opNode = matched_nodes [0 ]
303- node_op = opNode .op
304-
305- if 'group' in opNode .attrs and opNode .attrs ['group' ] == 1 :
306283def _NCWHtoNHWC_dw_fun (graph : gs .Graph , match : Match , name : str , default_channels_first : bool ) -> gs .Graph :
307284 node = next (iter ((match .nodes_map .values ())))
308285
309286 if not _isDepthwise (node ):
310- if opNode .attrs .get ('group' , 1 ) == 1 :
311287 return graph
312288
313289 channels_first = node .attrs .get ("channels_first" , True )
@@ -340,47 +316,10 @@ def _NCWHtoNHWC_dw_fun(graph: gs.Graph, match: Match, name: str, default_channel
340316class NCHWtoNHWCDwConvPass (ReplaceSequentialPatternPass ):
341317
342318 def __init__ (self , default_channels_first : bool = True ):
343- # Define pattern graph
344- graph = gs .Graph ()
345-
346- _input = gs .Variable (name = 'input_1' )
347- output = graph .layer (inputs = [_input ], outputs = ['convOut' ], op = 'RequantizedConv' , name = 'requantizedConv' )
348-
349- graph .outputs .append (output )
350- graph .inputs .append (_input )
351-
352- # Define name
319+ graph = _singleNodePattern (op = "Conv|RequantizedConv" )
353320 name = "_NCHW_TO_NHWC_DW_CONV_PASS"
354-
355- # Initialize Pass
356- super ().__init__ (pattern = graph ,
357- replacement_fn = partial (_PULPDWNCHWtoNHWC_fun ,
358- default_channels_first = default_channels_first ),
359- name = name )
360-
361-
362- # Float DW Conv
363- @contextagnostic
364- class PULPFPDWConvPass (ReplaceSequentialPatternPass ):
365-
366- def __init__ (self , default_channels_first : bool = True ):
367- # Define pattern graph
368- graph = gs .Graph ()
369-
370- _input = gs .Variable (name = 'input_1' )
371- output = graph .layer (inputs = [_input ], outputs = ['convOut' ], op = 'Conv' , name = 'conv' )
372-
373- graph .outputs .append (output )
374- graph .inputs .append (_input )
375-
376- # Define name
377- name = "_NCHW_TO_NHWC_FP_DW_CONV_PASS"
378-
379- # Initialize Pass
380- super ().__init__ (pattern = graph ,
381- replacement_fn = partial (_PULPDWNCHWtoNHWC_fun ,
382- default_channels_first = default_channels_first ),
383- name = name )
321+ super ().__init__ (graph , partial (_NCWHtoNHWC_dw_fun , default_channels_first = default_channels_first ), name ,
322+ NonBranchingMatcher (regex_op = True ))
384323
385324
386325def _PULP_NCHWtoNHWC_dw_fun (graph : gs .Graph , match : Match , name : str , default_channels_first : bool = True ):
@@ -425,10 +364,8 @@ def __init__(self, default_channels_first: bool = True):
425364 passes = [
426365 NCHWtoNHWCPadPass (default_channels_first ),
427366 NCHWtoNHWCMaxPoolPass (default_channels_first ),
428- PULPDWConvPass (default_channels_first ),
429- PULPFPDWConvPass (default_channels_first ),
430- PULPNCHWtoNHWCDenseConvPass (default_channels_first ),
431- PULPNCHWtoNHWCDenseRequantizedConvPass (default_channels_first ),
367+ NCHWtoNHWCDwConvPass (default_channels_first ),
368+ NCHWtoNHWCConvPass (default_channels_first ),
432369 ]
433370 super ().__init__ (* passes )
434371
@@ -494,11 +431,6 @@ def _requantized_gemm_to_pw_fun(graph: gs.Graph, match: Match, name: str):
494431 matrixAExpandDimsNode , pwIn = _appendExpandDims (matrixA , name , axis = expandAxis )
495432 graph .nodes .append (matrixAExpandDimsNode )
496433
497- # If transB is set then the matrix is of shape [N x K] and it doesn't need to be transposed, otherwise its shape is [K x N] and it has to be transposed
498- if not 'transB' in requantizedGemm .attrs or requantizedGemm .attrs ['transB' ] == 0 :
499- # matrixBTransposed, shape [N x K]
500- matrixBTransposeNode , matrixB = _appendTransposeNode (matrixB , name , _permutationLastTwoDims (len (matrixB .shape )))
501- graph .nodes .append (matrixBTransposeNode )
502434 # pwWeight, shape [N x 1 x 1 x K]
503435 matrixBExpandDimsNode , pwWeight = _appendExpandDims (matrixB , name , axis = (1 , 2 ))
504436 graph .nodes .append (matrixBExpandDimsNode )
0 commit comments