2323from .. import DeviceType
2424from torch .fx .immutable_collections import immutable_list
2525
# Registry of classic fusion patterns: maps the pattern name produced by
# classic_fuse_check to the fused-op class that the matching fusion routine
# instantiates (via classicfuse_register.get(pattern)()).
classicfuse_register = {
    "transpose_matmul_fusion": TransposeMatmulFusedOp,
    "silu_fusion": SiluOp,
}
2728
2829# TODO: classify op type for op fusion
2930# OP_TYPE_FUSABLE = [OpType.BroadcastType, OpType.ElementwiseType, OpType.ReshapeType]
@@ -52,10 +53,25 @@ def classic_fuse_check(graph: Graph):
5253 1
5354 ] == immutable_list ([1 , 0 ]):
5455 pattern = target , parentop , "transpose_matmul_fusion"
56+ elif isinstance (op , MulOp ):
57+ # Check for mul + sigmoid fusion pattern: mul(x, sigmoid(x))
58+ parentop = [graph .node_table [str (i )] for i in op ._parents ]
59+ for target in parentop :
60+ if isinstance (target , SigmoidOp ):
61+ # Check if the sigmoid input is also an input to the mul operation
62+ sigmoid_input = target ._parents [0 ] if target ._parents else None
63+ if sigmoid_input and sigmoid_input in op ._parents :
64+ pattern = target , parentop , "silu_fusion"
65+ break
5566 if pattern :
56- transpose_matmul_fusion (
57- graph , op , pattern [0 ], pattern [1 ], pattern [2 ]
58- )
67+ if pattern [2 ] == "transpose_matmul_fusion" :
68+ transpose_matmul_fusion (
69+ graph , op , pattern [0 ], pattern [1 ], pattern [2 ]
70+ )
71+ elif pattern [2 ] == "silu_fusion" :
72+ silu_fusion (
73+ graph , op , pattern [0 ], pattern [1 ], pattern [2 ]
74+ )
5975
6076
6177def transpose_matmul_fusion (
@@ -91,6 +107,44 @@ def transpose_matmul_fusion(
91107 graph .delete_node (target , targets_parent )
92108
93109
def silu_fusion(
    graph: Graph, node, target: Op, parents: List[Op], pattern: str
):
    """
    Function to fuse mul and sigmoid operations into one operation.
    Such as mul(x, sigmoid(x)) -> fused_mul_sigmoid(x)

    The mul node is displaced by a freshly created fused op, the sigmoid is
    removed from the fused op's inputs, and the sigmoid's own inputs are
    attached to the fused op instead. The sigmoid node is deleted from the
    graph only if graph.check_delete_node allows it.

    Args:
        - graph (Graph): The input graph to be simplified.
        - node (Op): The mul operation to be fused.
        - target (Op): The sigmoid operation to be fused.
        - parents (List[Op]): The parents of the node to be fused. Not used
          by this function; accepted because the caller passes it
          positionally.
        - pattern (str): The pattern of the fusion; used as the key into
          classicfuse_register to pick the fused op class.
    Returns:
        - None: Modifies the input graph in place.
    """
    fused_op = classicfuse_register.get(pattern)()
    # mulop -> fusedmulopnode
    fused_op.name = "fused" + node.name
    graph.displace_node(node, fused_op)

    # NOTE(review): the lookups below rely on displace_node having given
    # fused_op the mul node's args/_parents (they must contain target.name).
    # Drop the sigmoid from the fused op's inputs ...
    fused_op.args.remove(target.name)
    fused_op._parents.remove(target.name)
    # ... and take over the sigmoid's own inputs in its place.
    fused_op.args.extend(target.args)
    fused_op._parents.extend(target._parents)

    # mul(x, sigmoid(x)) now lists x twice. Deduplicate while keeping
    # first-occurrence order: dict.fromkeys is order-preserving, whereas
    # list(set(...)) yields an arbitrary order that can differ between runs.
    fused_op.args[:] = list(dict.fromkeys(fused_op.args))
    fused_op._parents[:] = list(dict.fromkeys(fused_op._parents))

    # Register the fused op as a child of every input of the sigmoid.
    targets_parent = [graph.node_table[name] for name in target._parents]
    for parent_op in targets_parent:
        parent_op.add_children(fused_op.name)
    # The sigmoid no longer feeds the fused op.
    target._children.remove(fused_op.name)

    # Remove the sigmoid entirely only when the graph says it is safe to.
    if graph.check_delete_node(target):
        graph.delete_node(target, targets_parent)
147+
94148def apply_classic_fusion (graph : Graph ):
95149 """
96150 Function to fuse some typical operations into one operation and fuse
@@ -134,3 +188,4 @@ def simply_fuse(graph: Graph):
134188 graph .op_groups = {}
135189 graph .op_groups ["subgraph0" ] = new_op_group
136190 graph .group_map_device = {"subgraph0" : device }
191+
0 commit comments