Skip to content

Commit a59d776

Browse files
authored Jul 4, 2022
skip quant matmul in mha (#1232)
1 parent 171f5cf commit a59d776

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed
 

‎paddleslim/auto_compression/compressor.py

+11
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,17 @@ def _get_model_type(self, exe, model_dir, model_filename, params_filename):
313313
model_filename=model_filename, params_filename=params_filename,
314314
executor=exe))
315315
_, _, model_type = get_patterns(inference_program)
316+
if self.model_filename is None:
317+
new_model_filename = '__new_model__'
318+
else:
319+
new_model_filename = 'new_' + self.model_filename
320+
program_bytes = inference_program._remove_training_info(
321+
clip_extra=False).desc.serialize_to_string()
322+
with open(os.path.join(self.model_dir, new_model_filename), "wb") as f:
323+
f.write(program_bytes)
324+
shutil.move(
325+
os.path.join(self.model_dir, new_model_filename),
326+
os.path.join(self.model_dir, self.model_filename))
316327
_logger.info(f"Detect model type: {model_type}")
317328
return model_type
318329

‎paddleslim/common/patterns.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def find_final_nodes(program):
4343
return final_nodes
4444

4545

46-
def _is_mha(pattern_ops, pattern_ops_type):
46+
def _is_mha(pattern_ops, pattern_ops_type, skip_quant_tensor_list=[]):
4747
""" judge whether this pattern is multihead attention """
4848
if pattern_ops_type.count('softmax') != 1 or pattern_ops_type.count(
4949
'fetch') > 0:
@@ -53,6 +53,7 @@ def _is_mha(pattern_ops, pattern_ops_type):
5353
for op in pattern_ops:
5454
if op.type() in ['matmul', 'matmul_v2']:
5555
if not is_dynamic_weight_op(op):
56+
skip_quant_tensor_list.extend(op._op.input('X'))
5657
matmul_num += 1
5758
if matmul_num == 2:
5859
return True
@@ -81,6 +82,7 @@ def _is_ffn(pattern_ops, pattern_ops_type):
8182
def get_patterns(program, only_final_node=True):
8283
""" distinguish the pattern in the program and get distillation node """
8384
distill_node = []
85+
skip_quant_tensor_list = []
8486
patterns = {}
8587
graph = GraphWrapper(program)
8688
block_num = 0
@@ -110,7 +112,8 @@ def get_patterns(program, only_final_node=True):
110112
pattern_name = shortcut_start_op.type() + '$' + str(op.idx(
111113
))
112114

113-
if _is_mha(pattern_ops, pattern_ops_type):
115+
if _is_mha(pattern_ops, pattern_ops_type,
116+
skip_quant_tensor_list):
114117
model_type = 'transformer'
115118
pattern_name = 'MHA$' + str(block_num)
116119

@@ -145,4 +148,12 @@ def get_patterns(program, only_final_node=True):
145148
distill_node.append('teacher_' + out_var.name())
146149
distill_node.append(out_var.name())
147150

151+
#### skip quant matmul in attention
152+
if model_type == 'transformer':
153+
for block_id in range(len(program.blocks)):
154+
for op in program.blocks[block_id].ops:
155+
for inp_name in op.input_arg_names:
156+
if inp_name in skip_quant_tensor_list:
157+
op._set_attr("op_namescope", "skip_quant")
158+
148159
return patterns, distill_node, model_type

0 commit comments

Comments (0)
Please sign in to comment.