
Commit 561bcc3

[frontend] Implement SiLU fusion with mul + sigmoid pattern
1 parent 5c8c88f commit 561bcc3
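
For context: SiLU (also called swish) is defined as silu(x) = x * sigmoid(x), and this commit teaches the frontend to recognize the explicit mul + sigmoid spelling of that identity and fuse it. A minimal PyTorch sketch of the equivalence the pattern matcher relies on (plain torch calls, not code from this commit):

import torch

x = torch.randn(4, 4)

# The explicit pattern this commit matches: mul(x, sigmoid(x)).
pattern_out = torch.mul(x, torch.sigmoid(x))

# The built-in op it is numerically equivalent to.
reference_out = torch.nn.functional.silu(x)

assert torch.allclose(pattern_out, reference_out)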

3 files changed: +104 −14 lines changed


frontend/Python/graph/transform/fuse_ops.py

Lines changed: 59 additions & 4 deletions
@@ -23,7 +23,8 @@
 from .. import DeviceType
 from torch.fx.immutable_collections import immutable_list
 
-classicfuse_register = {"transpose_matmul_fusion": TransposeMatmulFusedOp}
+classicfuse_register = {"transpose_matmul_fusion": TransposeMatmulFusedOp,
+                        "silu_fusion": SiluOp}
 
 # TODO: classify op type for op fusion
 # OP_TYPE_FUSABLE = [OpType.BroadcastType, OpType.ElementwiseType, OpType.ReshapeType]
@@ -52,10 +53,25 @@ def classic_fuse_check(graph: Graph):
                     1
                 ] == immutable_list([1, 0]):
                     pattern = target, parentop, "transpose_matmul_fusion"
+        elif isinstance(op, MulOp):
+            # Check for mul + sigmoid fusion pattern: mul(x, sigmoid(x))
+            parentop = [graph.node_table[str(i)] for i in op._parents]
+            for target in parentop:
+                if isinstance(target, SigmoidOp):
+                    # Check if the sigmoid input is also an input to the mul operation
+                    sigmoid_input = target._parents[0] if target._parents else None
+                    if sigmoid_input and sigmoid_input in op._parents:
+                        pattern = target, parentop, "silu_fusion"
+                        break
         if pattern:
-            transpose_matmul_fusion(
-                graph, op, pattern[0], pattern[1], pattern[2]
-            )
+            if pattern[2] == "transpose_matmul_fusion":
+                transpose_matmul_fusion(
+                    graph, op, pattern[0], pattern[1], pattern[2]
+                )
+            elif pattern[2] == "silu_fusion":
+                silu_fusion(
+                    graph, op, pattern[0], pattern[1], pattern[2]
+                )
 
 
 def transpose_matmul_fusion(
@@ -91,6 +107,44 @@ def transpose_matmul_fusion(
     graph.delete_node(target, targets_parent)
 
 
+def silu_fusion(
+    graph: Graph, node, target: Op, parents: List[Op], pattern: str
+):
+    """
+    Function to fuse mul and sigmoid operations into one operation.
+    Such as mul(x, sigmoid(x)) -> fused_mul_sigmoid(x)
+
+    Args:
+    - graph (Graph): The input graph to be simplified.
+    - node (Op): The mul operation to be fused.
+    - target (Op): The sigmoid operation to be fused.
+    - parents (List[Op]): The parents of the node to be fused.
+    - pattern (str): The pattern of the fusion.
+    Returns:
+    - None: Modifies the input graph in place.
+    """
+    fused_op = classicfuse_register.get(pattern)()
+    # mulop -> fusedmulopnode
+    fused_op.name = "fused" + node.name
+    graph.displace_node(node, fused_op)
+    fused_op.args.pop(fused_op.args.index(target.name))
+    fused_op._parents.pop(fused_op._parents.index(target.name))
+    fused_op.args.extend(target.args)
+
+    fused_op._parents.extend(target._parents)
+
+    fused_op.args[:] = list(set(fused_op.args))
+    fused_op._parents[:] = list(set(fused_op._parents))
+
+    targets_parent = [graph.node_table[i] for i in target._parents]
+    for i in targets_parent:
+        i.add_children(fused_op.name)
+    target._children.pop(target._children.index(fused_op.name))
+
+    if graph.check_delete_node(target):
+        graph.delete_node(target, targets_parent)
+
+
 def apply_classic_fusion(graph: Graph):
     """
     Function to fuse some typical operations into one operation and fuse
@@ -134,3 +188,4 @@ def simply_fuse(graph: Graph):
     graph.op_groups = {}
     graph.op_groups["subgraph0"] = new_op_group
     graph.group_map_device = {"subgraph0": device}
+
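
The detection added to classic_fuse_check above walks each MulOp's parents and looks for a SigmoidOp whose own input also feeds the mul. A self-contained toy sketch of the same matching idea (the Node class and node list below are illustrative stand-ins, not buddy's Graph API):

from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class Node:
    name: str
    op: str                       # e.g. "mul", "sigmoid", "input"
    parents: List[str] = field(default_factory=list)

def find_silu_patterns(nodes: List[Node]) -> List[Tuple[str, str]]:
    """Return (mul, sigmoid) name pairs where mul(x, sigmoid(x)) holds."""
    table = {n.name: n for n in nodes}
    matches = []
    for n in nodes:
        if n.op != "mul":
            continue
        for parent_name in n.parents:
            cand = table[parent_name]
            # The sigmoid's input must also be a direct input of the mul.
            if cand.op == "sigmoid" and cand.parents and cand.parents[0] in n.parents:
                matches.append((n.name, cand.name))
                break
    return matches

toy_graph = [
    Node("x", "input"),
    Node("sig", "sigmoid", ["x"]),
    Node("out", "mul", ["x", "sig"]),
]
print(find_silu_patterns(toy_graph))  # [('out', 'sig')]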

frontend/Python/ops/tosa.py

Lines changed: 22 additions & 0 deletions
@@ -64,6 +64,7 @@
     ArgMaxOp,
     ScaledDotProductFlashAttentionForCpuOp,
     MatmulOp,
+    SiluOp,
 )
 from .utils import *

@@ -1449,6 +1450,26 @@ def sigmoid_op(node: SigmoidOp, symbol_table):
     return op
 
 
+def silu_op(node: SiluOp, symbol_table):
+    """
+    Import the buddy SiluOp.
+    Implements SiLU fusion: x * sigmoid(x) using tosa.sigmoid and tosa.mul.
+    """
+    input_tensor = symbol_table.get((str(node.args[0]), 0))
+    if input_tensor is None:
+        return
+
+    output_shape = list(node.tensor_meta["shape"])
+    dtype = node.tensor_meta["dtype"]
+    mlir_dtype = mlir_element_type_get(dtype)
+    tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype)
+
+    sigmoid_op = tosa.SigmoidOp(tensor_type, input_tensor)
+    mul_op = tosa.MulOp(tensor_type, input_tensor, sigmoid_op.result)
+
+    return mul_op
+
+
 def reciprocal_op(node: ReciprocalOp, symbol_table):
     """
     Import the buddy ReciprocalOp.
@@ -1859,6 +1880,7 @@ def scaled_dot_product_flash_attention_for_cpu_op(
     "ReluOp": relu_op,
     "IotaOp": iota_op,
     "SigmoidOp": sigmoid_op,
+    "SiLUOp": silu_op,
     "ReciprocalOp": reciprocal_op,
     "MeanOp": mean_op,
     "ClampMinOp": clamp_min_op,

tests/Python/test_silu.py

Lines changed: 23 additions & 10 deletions
@@ -7,29 +7,42 @@
 
 from buddy.compiler.frontend import DynamoCompiler
 from buddy.compiler.ops import linalg
+from buddy.compiler.graph.transform import simply_fuse, apply_classic_fusion
 
+def silu_pattern(x):
+    sigmoid_x = torch.sigmoid(x)
+    return torch.mul(x, sigmoid_x)
 
 def foo(x):
-    return torch.nn.functional.silu(x)
+    return silu_pattern(x)
 
-
-in1 = torch.ones([13, 13], dtype=torch.float32)
+x = torch.ones([4, 4], dtype=torch.float32)
 # Initialize the dynamo compiler.
 dynamo_compiler = DynamoCompiler(
     primary_registry=linalg.ops_registry,
     aot_autograd_decomposition=aot_autograd_decompositions,
 )
 
-graphs = dynamo_compiler.importer(foo, in1)
+graphs = dynamo_compiler.importer(foo, x)
 assert len(graphs) == 1
 graph = graphs[0]
+pattern_list = [apply_classic_fusion]
+graphs[0].fuse_ops(pattern_list)
+
 graph.lower_to_top_level_ir()
 print(graph._imported_module)
 
-# CHECK: module {
+# CHECK: module {
 # CHECK-LABEL: func.func @forward
-# CHECK: %{{.*}} = tensor.empty
-# CHECK: %{{.*}} = linalg.generic
-# CHECK: return %{{.*}}
-# CHECK: }
-# CHECK: }
+# CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<4x4xf32>
+# CHECK: %[[RES:.*]] = linalg.generic {.*} ins(%arg0 : tensor<4x4xf32>) outs(%[[EMPTY]] : tensor<4x4xf32>) {
+# CHECK: ^bb0(%in: f32, %out: f32):
+# CHECK: %[[NEG:.*]] = arith.negf %in : f32
+# CHECK: %[[EXP:.*]] = math.exp %[[NEG]] : f32
+# CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+# CHECK: %[[ADD:.*]] = arith.addf %[[EXP]], %[[ONE]] : f32
+# CHECK: %[[DIV:.*]] = arith.divf %in, %[[ADD]] : f32
+# CHECK: linalg.yield %[[DIV]] : f32
+# CHECK: } -> tensor<4x4xf32>
+# CHECK: return %[[RES]] : tensor<4x4xf32>
+
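
The updated test constructs the mul + sigmoid pattern directly, applies apply_classic_fusion, and its CHECK lines pin the lowered linalg.generic body to negf / exp / add 1 / div, i.e. x / (1 + exp(-x)). A quick eager-mode sanity check of that arithmetic against torch's built-in silu (independent of the compiler pipeline):

import torch

x = torch.ones(4, 4, dtype=torch.float32)

# The arithmetic the CHECK lines describe: x / (1 + exp(-x)).
lowered_form = x / (1.0 + torch.exp(-x))

assert torch.allclose(lowered_form, torch.nn.functional.silu(x))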

0 commit comments