@@ -43,7 +43,7 @@ def find_final_nodes(program):
43
43
return final_nodes
44
44
45
45
46
- def _is_mha (pattern_ops , pattern_ops_type ):
46
+ def _is_mha (pattern_ops , pattern_ops_type , skip_quant_tensor_list = [] ):
47
47
""" judge whether this pattern is multihead attention """
48
48
if pattern_ops_type .count ('softmax' ) != 1 or pattern_ops_type .count (
49
49
'fetch' ) > 0 :
@@ -53,6 +53,7 @@ def _is_mha(pattern_ops, pattern_ops_type):
53
53
for op in pattern_ops :
54
54
if op .type () in ['matmul' , 'matmul_v2' ]:
55
55
if not is_dynamic_weight_op (op ):
56
+ skip_quant_tensor_list .extend (op ._op .input ('X' ))
56
57
matmul_num += 1
57
58
if matmul_num == 2 :
58
59
return True
@@ -81,6 +82,7 @@ def _is_ffn(pattern_ops, pattern_ops_type):
81
82
def get_patterns (program , only_final_node = True ):
82
83
""" distinguish the pattern in the program and get distillation node """
83
84
distill_node = []
85
+ skip_quant_tensor_list = []
84
86
patterns = {}
85
87
graph = GraphWrapper (program )
86
88
block_num = 0
@@ -110,7 +112,8 @@ def get_patterns(program, only_final_node=True):
110
112
pattern_name = shortcut_start_op .type () + '$' + str (op .idx (
111
113
))
112
114
113
- if _is_mha (pattern_ops , pattern_ops_type ):
115
+ if _is_mha (pattern_ops , pattern_ops_type ,
116
+ skip_quant_tensor_list ):
114
117
model_type = 'transformer'
115
118
pattern_name = 'MHA$' + str (block_num )
116
119
@@ -145,4 +148,12 @@ def get_patterns(program, only_final_node=True):
145
148
distill_node .append ('teacher_' + out_var .name ())
146
149
distill_node .append (out_var .name ())
147
150
151
+ #### skip quant matmul in attention
152
+ if model_type == 'transformer' :
153
+ for block_id in range (len (program .blocks )):
154
+ for op in program .blocks [block_id ].ops :
155
+ for inp_name in op .input_arg_names :
156
+ if inp_name in skip_quant_tensor_list :
157
+ op ._set_attr ("op_namescope" , "skip_quant" )
158
+
148
159
return patterns , distill_node , model_type
0 commit comments