@@ -298,7 +298,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
298298 return last_out_var;
299299}
300300
301- static int BuildFusion (Graph* graph, const std::string& name_scope) {
301+ static int BuildFusion (Graph* graph, const std::string& name_scope,
302+ const SquaredMatSubFusePass* pass) {
302303 GraphPatternDetector gpd;
303304 auto * pattern = gpd.mutable_pattern ();
304305
@@ -320,6 +321,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
320321 auto handler = [&](const GraphPatternDetector::subgraph_t & subgraph,
321322 Graph* g) {
322323 LOG (INFO) << " handle sqaure mat sub fuse" ;
324+ if (!pass->IsAcceptable (subgraph, g)) {
325+ LOG (WARNING) << " Pass in op compat failed." ;
326+ return ;
327+ }
328+
323329 auto & fused_pattern = gpd.pattern ();
324330
325331 auto * matx = retrieve_node (name_scope + " /x" , subgraph, fused_pattern);
@@ -368,14 +374,109 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
368374 GraphSafeRemoveNodes (graph, marked_nodes);
369375 ++fusion_count;
370376 };
371-
372377 gpd (graph, handler);
373378 return fusion_count;
374379}
375380
381+ SquaredMatSubFusePass::SquaredMatSubFusePass () {
382+ AddOpCompat (OpCompat (" square" ))
383+ .AddInput (" X" )
384+ .IsTensor ()
385+ .End ()
386+ .AddOutput (" Out" )
387+ .IsTensor ()
388+ .End ();
389+
390+ AddOpCompat (OpCompat (" matmul" ))
391+ .AddInput (" X" )
392+ .IsTensor ()
393+ .End ()
394+ .AddInput (" Y" )
395+ .IsTensor ()
396+ .End ()
397+ .AddOutput (" Out" )
398+ .IsTensor ()
399+ .End ()
400+ .AddAttr (" alpha" )
401+ .IsNumGE (0 .99f )
402+ .IsNumLE (1 .01f )
403+ .End ()
404+ .AddAttr (" transpose_X" )
405+ .IsBoolEQ (false )
406+ .End ()
407+ .AddAttr (" transpose_Y" )
408+ .IsBoolEQ (false )
409+ .End ();
410+
411+ AddOpCompat (OpCompat (" matmul_v2" ))
412+ .AddInput (" X" )
413+ .IsTensor ()
414+ .End ()
415+ .AddInput (" Y" )
416+ .IsTensor ()
417+ .End ()
418+ .AddOutput (" Out" )
419+ .IsTensor ()
420+ .End ()
421+ .AddAttr (" trans_x" )
422+ .IsBoolEQ (false )
423+ .End ()
424+ .AddAttr (" trans_y" )
425+ .IsBoolEQ (false )
426+ .End ();
427+
428+ AddOpCompat (OpCompat (" elementwise_sub" ))
429+ .AddInput (" X" )
430+ .IsTensor ()
431+ .End ()
432+ .AddInput (" Y" )
433+ .IsTensor ()
434+ .End ()
435+ .AddOutput (" Out" )
436+ .IsTensor ()
437+ .End ()
438+ .AddAttr (" axis" )
439+ .IsNumEQ (-1 )
440+ .End ();
441+
442+ AddOpCompat (OpCompat (" elementwise_mul" ))
443+ .AddInput (" X" )
444+ .IsTensor ()
445+ .End ()
446+ .AddInput (" Y" )
447+ .IsTensor ()
448+ .End ()
449+ .AddOutput (" Out" )
450+ .IsTensor ()
451+ .End ()
452+ .AddAttr (" axis" )
453+ .IsNumEQ (-1 )
454+ .End ();
455+
456+ AddOpCompat (OpCompat (" fill_constant" ))
457+ .AddOutput (" Out" )
458+ .IsTensor ()
459+ .End ()
460+ .AddAttr (" dtype" )
461+ .IsNumGE (0 )
462+ .IsNumLE (25 )
463+ .End ()
464+ .AddAttr (" shape" )
465+ .End ()
466+ // type:float,there is no restriction
467+ .AddAttr (" value" )
468+ .End ();
469+ }
470+
471+ // to use IsCompat
472+ bool SquaredMatSubFusePass::IsAcceptable (
473+ const GraphPatternDetector::subgraph_t & subgraph, Graph* g) const {
474+ return IsCompat (subgraph, g);
475+ }
476+
376477void SquaredMatSubFusePass::ApplyImpl (ir::Graph* graph) const {
377478 FusePassBase::Init (name_scope_, graph);
378- int fusion_count = BuildFusion (graph, name_scope_);
479+ int fusion_count = BuildFusion (graph, name_scope_, this );
379480 AddStatis (fusion_count);
380481}
381482
0 commit comments