
Commit e8c545a: Add external calling functionality
Author: Bram Wasti
Parent: aacebaf

22 files changed: +763 -238 lines
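
In brief: this change lets the TensorExpr fuser hand selected ops off to external ("native") implementations. Ops whose qualified name appears in a native function registry become eligible for fusion, a new CallExternal IR node represents such a call in a tensor's body, and Function::ElementStmt lowers it to an OpaqueCall rather than an element-wise Store. aten::matmul is the first op wired through this path, covered by a new test. (Only 6 of the 22 changed files are shown below.)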

caffe2/CMakeLists.txt (+1)

@@ -466,6 +466,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/kernel.cpp
     ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/llvm_codegen.cpp
     ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/llvm_jit.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/native.cpp
     ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/types.cpp
     ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_printer.cpp
     ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_mutator.cpp

test/test_tensorexpr.py (+19)

@@ -250,6 +250,25 @@ def np_easy(x, y, z):
         npr = np_easy(a.numpy(), b.numpy(), c.numpy())
         np.testing.assert_allclose(npr, x.numpy())

+    def test_matmul(self):
+        llvm = LLVMCodeGenExecuted()
+        def easy(x, y):
+            aaa, bbb = torch.chunk(y, 2)
+            y = torch.cat([aaa, bbb], dim=0)
+            aaa = torch.matmul(x, y) * 3
+            return aaa
+
+        shape = (128, 128)
+        a = torch.rand(shape)
+        b = torch.rand(shape)
+        traced = torch.jit.trace(
+            easy, (a, b)
+        )
+
+        x = traced(a, b)
+        y = 3 * (a @ b)
+        np.testing.assert_allclose(y.numpy(), x.numpy(), rtol=1e-5, atol=1e-3)
+        assert llvm.elapsed_value() == 1

     def test_broadcast(self):
         def easy(x, y, z):
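
The closing assertion, llvm.elapsed_value() == 1, verifies that the LLVM backend ran exactly once for the traced function, i.e. the chunk/cat/matmul/multiply sequence was compiled into a single fused TensorExpr kernel rather than falling back to the interpreter.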

torch/csrc/jit/passes/guard_elimination.cpp (+2 -1)

@@ -1,8 +1,8 @@
-#include <torch/csrc/jit/passes/guard_elimination.h>
 #include <torch/csrc/jit/graph_executor.h>
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/alias_analysis.h>
 #include <torch/csrc/jit/passes/constant_propagation.h>
+#include <torch/csrc/jit/passes/guard_elimination.h>
 #include <torch/csrc/jit/passes/peephole.h>
 #include <memory>
 #include <unordered_set>

@@ -243,6 +243,7 @@ struct GuardElimination {
     case aten::rsqrt:
     case aten::remainder:
     case aten::mm:
+    case aten::matmul:
     case aten::min:
     case aten::max:
     case aten::type_as:

torch/csrc/jit/passes/tensorexpr_fuser.cpp (+12 -9)

@@ -8,6 +8,7 @@
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
 #include <torch/csrc/jit/tensorexpr/kernel.h>
+#include <torch/csrc/jit/tensorexpr/native.h>

 using namespace torch::jit;
 using namespace torch::jit::tensorexpr;

@@ -119,7 +120,12 @@ bool isSupported(Node* node) {
     case aten::__rshift__:
     case aten::where:
       return true;
-    default:
+    default: {
+      auto& nfr = getNativeFunctionRegistry();
+      if (nfr.count(node->kind().toQualString())) {
+        return true;
+      }
+    }
       return false;
   }
 }

@@ -140,10 +146,7 @@ bool canHandle(Node* node, AliasDb& aliasDb) {
     return false; \
   }

-bool canMerge(
-    Node* consumer,
-    Node* producer,
-    AliasDb& aliasDb) {
+bool canMerge(Node* consumer, Node* producer, AliasDb& aliasDb) {
   // Only handle complete tensor types
   for (torch::jit::Value* output : consumer->outputs()) {
     REQ(output->isCompleteTensor());

@@ -162,8 +165,7 @@ bool canMerge(
   REQ(aliasDb.couldMoveAfterTopologically(consumer, producer));

   // Ops that return aliases can only be folded if this is the only use.
-  if (producer->kind() == aten::slice ||
-      producer->kind() == aten::unsqueeze ||
+  if (producer->kind() == aten::slice || producer->kind() == aten::unsqueeze ||
       producer->kind() == prim::ConstantChunk) {
     for (auto& use : producer->output(0)->uses()) {
       REQ(use.user == consumer);

@@ -196,11 +198,12 @@ bool canMerge(
 }
 #undef REQ

-Node *getOrCreateTensorExprSubgraph(Node *n) {
+Node* getOrCreateTensorExprSubgraph(Node* n) {
   if (n->hasAttribute(attr::Subgraph) && n->kind() == getTensorExprSymbol()) {
     return n;
   }
-  auto te_group = SubgraphUtils::createSingletonSubgraph(n, getTensorExprSymbol());
+  auto te_group =
+      SubgraphUtils::createSingletonSubgraph(n, getTensorExprSymbol());
   GRAPH_UPDATE("getOrCreateTensorExprSubgraph: ", *te_group);
   return te_group;
 }
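
The default case above now consults the native function registry: an op whose qualified name (node->kind().toQualString()) is registered becomes fusable, while anything else still falls through to return false. native.h itself is not part of this excerpt, so the registry's exact shape is unknown; a minimal sketch of what it might look like, with an assumed call convention (only the lookup by name is confirmed by the code above):

// Minimal sketch, not the actual native.h from this commit. The call
// convention (output buffer, input buffers, scalar args) is an assumption;
// only the name-keyed lookup is confirmed by the fuser code above.
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

using NativeFunction = void (*)(
    void* out,
    const std::vector<void*>& inputs,
    const std::vector<int64_t>& scalar_args);

std::unordered_map<std::string, NativeFunction>& getNativeFunctionRegistry() {
  // Meyers singleton so registration can run from static initializers.
  static std::unordered_map<std::string, NativeFunction> registry;
  return registry;
}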

torch/csrc/jit/tensorexpr/function.cpp (+30 -7)

@@ -45,8 +45,8 @@ Tensor* Compute(
   std::vector<const Var*> args;
   unpack_dim_args(dim_args, &dims, &args);
   const Expr* body = body_func(VarHandle(args[0])).node();
-  Function* func =
-      new Function(func_name, std::move(dims), std::move(args), std::move(body));
+  Function* func = new Function(
+      func_name, std::move(dims), std::move(args), std::move(body));
   return new Tensor(func, 0);
 }

@@ -67,12 +67,16 @@ Tensor* Compute(
 Tensor* Compute(
     const std::string& func_name,
     const std::vector<DimArg>& dim_args,
-    std::function<ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)> body_func) {
+    std::function<
+        ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>
+        body_func) {
   CHECK_EQ(dim_args.size(), 3ULL);
   std::vector<const Expr*> dims;
   std::vector<const Var*> args;
   unpack_dim_args(dim_args, &dims, &args);
-  const Expr* body = body_func(VarHandle(args[0]), VarHandle(args[1]), VarHandle(args[2])).node();
+  const Expr* body =
+      body_func(VarHandle(args[0]), VarHandle(args[1]), VarHandle(args[2]))
+          .node();
   Function* func = new Function(
       func_name, std::move(dims), std::move(args), std::move(body));
   return new Tensor(func, 0);

@@ -81,8 +85,11 @@ Tensor* Compute(
 Tensor* Compute(
     const std::string& func_name,
     const std::vector<DimArg>& dim_args,
-    std::function<ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&, const VarHandle&)>
-        body_func) {
+    std::function<ExprHandle(
+        const VarHandle&,
+        const VarHandle&,
+        const VarHandle&,
+        const VarHandle&)> body_func) {
   CHECK_EQ(dim_args.size(), 4ULL);
   std::vector<const Expr*> dims;
   std::vector<const Var*> args_nodes;

@@ -96,6 +103,21 @@ Tensor* Compute(

 Stmt* Function::ElementStmt(size_t index) {
   std::vector<ExprHandle> strides(dims_.size());
+  auto* ce = dynamic_cast<const CallExternal*>(body(index));
+  if (ce != nullptr) {
+    std::vector<const Var*> input_vars;
+    std::vector<const Expr*> input_args;
+    for (auto p : ce->params()) {
+      auto fc = dynamic_cast<const FunctionCall*>(p);
+      if (fc) {
+        input_vars.emplace_back(fc->tensor()->function()->func_var(index));
+      } else {
+        input_args.emplace_back(p);
+      }
+    }
+    return OpaqueCall::make(
+        ce->name(), func_var(index), input_vars, input_args);
+  }
   for (size_t i = 0; i < strides.size(); i++) {
     if (i == strides.size() - 1) {
       strides[i] = ExprHandle(1);

@@ -120,7 +142,8 @@ Stmt* Function::ElementStmt(size_t index) {

   const Expr* mask = new IntImm(1);

-  Stmt* update_stmt = new Store(func_var(index), total_index.node(), body(index), mask);
+  Stmt* update_stmt =
+      new Store(func_var(index), total_index.node(), body(index), mask);
   return update_stmt;
 }
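
The new branch in Function::ElementStmt is the core of the change: when a tensor's body is a CallExternal, the usual strided Store is skipped and a single OpaqueCall is emitted, whose FunctionCall parameters become input buffers (input_vars) while every other expression is passed through as a scalar argument (input_args). A hypothetical sketch of the runtime side of such a call, reusing the assumed registry and call convention from the sketch above:

// Hypothetical runtime handler for an OpaqueCall: resolve the callee by
// name and invoke it on raw buffers. Only the name-keyed lookup is
// confirmed by the diff; the signature below is an assumption.
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

void dispatchOpaqueCall(
    const std::string& name,
    void* out,
    const std::vector<void*>& inputs,
    const std::vector<int64_t>& scalar_args) {
  auto& registry = getNativeFunctionRegistry(); // registry sketch above
  auto it = registry.find(name);
  if (it == registry.end()) {
    throw std::runtime_error("no native function registered for " + name);
  }
  it->second(out, inputs, scalar_args);
}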

torch/csrc/jit/tensorexpr/ir.cpp (+16 -5)

@@ -59,6 +59,14 @@ Dtype Intrinsics::IntrinsicsDtype(
   return params[0]->dtype();
 }

+Dtype CallExternal::CallExternalDtype(
+    std::string name,
+    const std::vector<const Expr*>& params) {
+  // TODO: check the op_type and make a real decision
+  CHECK_GE(params.size(), 1ULL);
+  return params[0]->dtype();
+}
+
 int Intrinsics::OpArgCount(IntrinsicsOp op_type) {
   switch (op_type) {
     case kSin:

@@ -100,39 +108,42 @@ int Intrinsics::OpArgCount(IntrinsicsOp op_type) {
   }
 }

-std::vector<const Expr*> ExprHandleVectorToExprVector(const std::vector<ExprHandle>& v) {
+std::vector<const Expr*> ExprHandleVectorToExprVector(
+    const std::vector<ExprHandle>& v) {
   std::vector<const Expr*> result(v.size());
   for (size_t i = 0; i < v.size(); i++) {
     result[i] = v[i].node();
   }
   return result;
 }

-std::vector<ExprHandle> ExprVectorToExprHandleVector(const std::vector<const Expr*>& v) {
+std::vector<ExprHandle> ExprVectorToExprHandleVector(
+    const std::vector<const Expr*>& v) {
   std::vector<ExprHandle> result(v.size());
   for (size_t i = 0; i < v.size(); i++) {
     result[i] = ExprHandle(v[i]);
   }
   return result;
 }

-std::vector<const Var*> VarHandleVectorToVarVector(const std::vector<VarHandle>& v) {
+std::vector<const Var*> VarHandleVectorToVarVector(
+    const std::vector<VarHandle>& v) {
   std::vector<const Var*> result(v.size());
   for (size_t i = 0; i < v.size(); i++) {
     result[i] = v[i].node();
   }
   return result;
 }

-std::vector<VarHandle> VarVectorToVarHandleVector(const std::vector<const Var*>& v) {
+std::vector<VarHandle> VarVectorToVarHandleVector(
+    const std::vector<const Var*>& v) {
   std::vector<VarHandle> result(v.size());
   for (size_t i = 0; i < v.size(); i++) {
     result[i] = VarHandle(v[i]);
   }
   return result;
 }

-
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch
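
CallExternalDtype currently just copies the dtype of the first parameter, as the TODO admits. The new native.cpp (added to the build in the CMake change above but not shown here) is presumably where implementations get registered; a hypothetical registration, again reusing the assumed registry sketch (only the "aten::matmul" key is corroborated, by the fuser lookup and the matmul test):

// Hypothetical registration of a native matmul. The key matches
// node->kind().toQualString() for aten::matmul; the body and call
// convention are assumptions.
static const bool matmul_registered = [] {
  getNativeFunctionRegistry()["aten::matmul"] =
      [](void* out,
         const std::vector<void*>& inputs,
         const std::vector<int64_t>& scalar_args) {
        // ... invoke an ATen/BLAS matmul over the raw buffers here ...
      };
  return true;
}();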
