@@ -45,8 +45,8 @@ Tensor* Compute(
45
45
std::vector<const Var*> args;
46
46
unpack_dim_args (dim_args, &dims, &args);
47
47
const Expr* body = body_func (VarHandle (args[0 ])).node ();
48
- Function* func =
49
- new Function ( func_name, std::move (dims), std::move (args), std::move (body));
48
+ Function* func = new Function (
49
+ func_name, std::move (dims), std::move (args), std::move (body));
50
50
return new Tensor (func, 0 );
51
51
}
52
52
@@ -67,12 +67,16 @@ Tensor* Compute(
67
67
Tensor* Compute (
68
68
const std::string& func_name,
69
69
const std::vector<DimArg>& dim_args,
70
- std::function<ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)> body_func) {
70
+ std::function<
71
+ ExprHandle (const VarHandle&, const VarHandle&, const VarHandle&)>
72
+ body_func) {
71
73
CHECK_EQ (dim_args.size (), 3ULL );
72
74
std::vector<const Expr*> dims;
73
75
std::vector<const Var*> args;
74
76
unpack_dim_args (dim_args, &dims, &args);
75
- const Expr* body = body_func (VarHandle (args[0 ]), VarHandle (args[1 ]), VarHandle (args[2 ])).node ();
77
+ const Expr* body =
78
+ body_func (VarHandle (args[0 ]), VarHandle (args[1 ]), VarHandle (args[2 ]))
79
+ .node ();
76
80
Function* func = new Function (
77
81
func_name, std::move (dims), std::move (args), std::move (body));
78
82
return new Tensor (func, 0 );
@@ -81,8 +85,11 @@ Tensor* Compute(
81
85
Tensor* Compute (
82
86
const std::string& func_name,
83
87
const std::vector<DimArg>& dim_args,
84
- std::function<ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&, const VarHandle&)>
85
- body_func) {
88
+ std::function<ExprHandle(
89
+ const VarHandle&,
90
+ const VarHandle&,
91
+ const VarHandle&,
92
+ const VarHandle&)> body_func) {
86
93
CHECK_EQ (dim_args.size (), 4ULL );
87
94
std::vector<const Expr*> dims;
88
95
std::vector<const Var*> args_nodes;
@@ -96,6 +103,21 @@ Tensor* Compute(
96
103
97
104
Stmt* Function::ElementStmt (size_t index) {
98
105
std::vector<ExprHandle> strides (dims_.size ());
106
+ auto * ce = dynamic_cast <const CallExternal*>(body (index ));
107
+ if (ce != nullptr ) {
108
+ std::vector<const Var*> input_vars;
109
+ std::vector<const Expr*> input_args;
110
+ for (auto p : ce->params ()) {
111
+ auto fc = dynamic_cast <const FunctionCall*>(p);
112
+ if (fc) {
113
+ input_vars.emplace_back (fc->tensor ()->function ()->func_var (index ));
114
+ } else {
115
+ input_args.emplace_back (p);
116
+ }
117
+ }
118
+ return OpaqueCall::make (
119
+ ce->name (), func_var (index ), input_vars, input_args);
120
+ }
99
121
for (size_t i = 0 ; i < strides.size (); i++) {
100
122
if (i == strides.size () - 1 ) {
101
123
strides[i] = ExprHandle (1 );
@@ -120,7 +142,8 @@ Stmt* Function::ElementStmt(size_t index) {
120
142
121
143
const Expr* mask = new IntImm (1 );
122
144
123
- Stmt* update_stmt = new Store (func_var (index ), total_index.node (), body (index ), mask);
145
+ Stmt* update_stmt =
146
+ new Store (func_var (index ), total_index.node (), body (index ), mask);
124
147
return update_stmt;
125
148
}
126
149
0 commit comments