Skip to content

Commit 6628d0f

Browse files
authored
Reduce paren explosion in IR printer. (#246)
1 parent 65d2e29 commit 6628d0f

File tree

4 files changed

+129
-40
lines changed

4 files changed

+129
-40
lines changed

test/cpp/tensorexpr/test_ir_printer.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ void testIRPrinterBasicValueTest() {
1818

1919
std::stringstream ss;
2020
ss << c;
21-
EXPECT_EQ(ss.str(), "(2 + 3)");
21+
EXPECT_EQ(ss.str(), "2 + 3");
2222
}
2323

2424
void testIRPrinterBasicValueTest02() {
@@ -31,7 +31,7 @@ void testIRPrinterBasicValueTest02() {
3131

3232
std::stringstream ss;
3333
ss << f;
34-
EXPECT_EQ(ss.str(), "((2.f + 3.f) - (4.f + 5.f))");
34+
EXPECT_EQ(ss.str(), "(2.f + 3.f) - (4.f + 5.f)");
3535
}
3636

3737
void testIRPrinterLetTest01() {
@@ -43,7 +43,7 @@ void testIRPrinterLetTest01() {
4343

4444
std::stringstream ss;
4545
ss << result;
46-
EXPECT_EQ(ss.str(), "(let x = 3.f in (2.f + ((x * 3.f) + 4.f)))");
46+
EXPECT_EQ(ss.str(), "let x = 3.f in 2.f + (x * 3.f + 4.f)");
4747
}
4848

4949
void testIRPrinterLetTest02() {
@@ -58,7 +58,7 @@ void testIRPrinterLetTest02() {
5858
std::stringstream ss;
5959
ss << e2;
6060
EXPECT_EQ(
61-
ss.str(), "(let y = 6.f in (let x = 3.f in (2.f + ((x * 3.f) + (4.f * y)))))");
61+
ss.str(), "let y = 6.f in (let x = 3.f in 2.f + (x * 3.f + 4.f * y))");
6262
}
6363

6464
void testIRPrinterCastTest() {
@@ -74,7 +74,7 @@ void testIRPrinterCastTest() {
7474
ss << e2;
7575
EXPECT_EQ(
7676
ss.str(),
77-
"(let y = 6.f in (let x = int(3.f) in (2.f + ((x * 3.f) + (4.f * y)))))");
77+
"let y = 6.f in (let x = int(3.f) in 2.f + (x * 3.f + 4.f * y))");
7878
}
7979
} // namespace jit
8080
} // namespace torch

torch/csrc/jit/tensorexpr/expr.h

+28-2
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,44 @@ namespace torch {
1414
namespace jit {
1515
namespace tensorexpr {
1616

17+
enum IRNodeType {
18+
kPrimitive,
19+
kAdd,
20+
kSub,
21+
kMul,
22+
kDiv,
23+
kMod,
24+
kMax,
25+
kMin,
26+
kAnd,
27+
kOr,
28+
kLshift,
29+
kRshift,
30+
kXor,
31+
kCompareSelect,
32+
kLet,
33+
kCast,
34+
kNone
35+
};
36+
1737
// The common base between all expression node.
1838
class Expr : public KernelScopedObject {
1939
public:
20-
explicit Expr(Dtype dtype) : dtype_(dtype) {}
40+
explicit Expr(Dtype dtype, IRNodeType expr_type = kNone)
41+
: dtype_(dtype), expr_type_(expr_type) {}
2142
Dtype dtype() const {
2243
return dtype_;
2344
}
2445
TORCH_API virtual void accept(IRVisitor* visitor) const = 0;
2546
virtual const Expr* accept_mutator(IRMutator* mutator) const = 0;
2647

48+
IRNodeType expr_type() const {
49+
return expr_type_;
50+
}
51+
2752
private:
2853
Dtype dtype_;
54+
IRNodeType expr_type_;
2955
};
3056

3157
// A CRTP pattern to accept visitors for children class,
@@ -121,7 +147,7 @@ class Var : public ExprNode<Var> {
121147
}
122148

123149
Var(const std::string& name_hint, Dtype dtype)
124-
: ExprNodeBase(dtype), name_hint_(name_hint) {}
150+
: ExprNodeBase(dtype, kPrimitive), name_hint_(name_hint) {}
125151

126152
private:
127153
std::string name_hint_;

torch/csrc/jit/tensorexpr/ir.h

+40-26
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,6 @@ namespace torch {
1010
namespace jit {
1111
namespace tensorexpr {
1212

13-
enum IRNodeType {
14-
kAdd,
15-
kSub,
16-
kMul,
17-
kDiv,
18-
kMod,
19-
kMax,
20-
kMin,
21-
kAnd,
22-
kOr,
23-
kLshift,
24-
kRshift,
25-
kXor,
26-
kCompareSelect,
27-
};
28-
2913
enum CompareSelectOperation {
3014
kEQ,
3115
kGT,
@@ -35,6 +19,41 @@ enum CompareSelectOperation {
3519
kNE,
3620
};
3721

22+
inline int getPrecedence(IRNodeType ty) {
23+
// Match C++ operator precedence rules, since some pretty-print expressions to C++.
24+
// SEE: https://en.cppreference.com/w/cpp/language/operator_precedence
25+
switch (ty) {
26+
case kPrimitive:
27+
return 0;
28+
case kCast:
29+
return 2;
30+
case kAdd:
31+
case kSub:
32+
return 6;
33+
case kMul:
34+
case kDiv:
35+
case kMod:
36+
return 5;
37+
case kMax:
38+
case kMin:
39+
return 99;
40+
case kAnd:
41+
return 11;
42+
case kOr:
43+
return 13;
44+
case kLshift:
45+
case kRshift:
46+
return 7;
47+
case kXor:
48+
return 12;
49+
case kCompareSelect:
50+
case kLet:
51+
return 16;
52+
default:
53+
return 99;
54+
}
55+
}
56+
3857
class Buffer;
3958

4059
class Cast : public ExprNode<Cast> {
@@ -46,7 +65,7 @@ class Cast : public ExprNode<Cast> {
4665
return ExprHandle(new Cast(dtype, src_value.node()));
4766
}
4867
Cast(Dtype dtype, const Expr* src_value)
49-
: ExprNodeBase(dtype), src_value_(src_value) {}
68+
: ExprNodeBase(dtype, kCast), src_value_(src_value) {}
5069

5170
private:
5271
const Expr* src_value_;
@@ -68,9 +87,6 @@ class BinaryOpNode : public ExprNode<Op> {
6887
const Expr* rhs() const {
6988
return this->rhs_;
7089
}
71-
IRNodeType expr_type() const {
72-
return expr_type_;
73-
}
7490

7591
static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
7692
return ExprHandle(new Op(lhs.node(), rhs.node()));
@@ -81,10 +97,9 @@ class BinaryOpNode : public ExprNode<Op> {
8197
const Expr* rhs_v,
8298
IRNodeType expr_type,
8399
ScalarType ret_type = ScalarType::None)
84-
: ExprNode<Op>(BinaryOpDtype(lhs_v->dtype(), rhs_v->dtype(), ret_type)),
100+
: ExprNode<Op>(BinaryOpDtype(lhs_v->dtype(), rhs_v->dtype(), ret_type), expr_type),
85101
lhs_(CastIfNeeded(lhs_v, ExprNode<Op>::dtype())),
86-
rhs_(CastIfNeeded(rhs_v, ExprNode<Op>::dtype())),
87-
expr_type_(expr_type) {}
102+
rhs_(CastIfNeeded(rhs_v, ExprNode<Op>::dtype())) { }
88103

89104
private:
90105
static const Expr* CastIfNeeded(const Expr* expr, Dtype dst_dtype) {
@@ -96,7 +111,6 @@ class BinaryOpNode : public ExprNode<Op> {
96111

97112
const Expr* lhs_;
98113
const Expr* rhs_;
99-
IRNodeType expr_type_;
100114
};
101115

102116
class Add : public BinaryOpNode<Add> {
@@ -216,7 +230,7 @@ class Min : public BinaryOpNode<Min> {
216230
#define IMM_DECLARE(Type, Name) \
217231
class Name##Imm : public ExprNode<Name##Imm> { \
218232
public: \
219-
Name##Imm(Type value) : ExprNodeBase(k##Name), value_(value) {} \
233+
Name##Imm(Type value) : ExprNodeBase(k##Name, kPrimitive), value_(value) {} \
220234
Type value() const { \
221235
return value_; \
222236
} \
@@ -248,7 +262,7 @@ class Let : public ExprNode<Let> {
248262
}
249263

250264
Let(const Expr* var, const Expr* value, const Expr* body)
251-
: ExprNodeBase(body->dtype()), var_(var), value_(value), body_(body) {}
265+
: ExprNodeBase(body->dtype(), kLet), var_(var), value_(value), body_(body) {}
252266

253267
private:
254268
const Expr* var_;

torch/csrc/jit/tensorexpr/ir_printer.cpp

+56-7
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,30 @@ template <typename Op>
2222
void visitBinaryOp(
2323
const BinaryOpNode<Op>* v,
2424
const std::string& op_str,
25-
IRPrinter* printer) {
25+
IRPrinter* printer,
26+
bool parens = true) {
2627
std::ostream& os = printer->os();
27-
os << "(";
28+
int self_prec = getPrecedence(v->expr_type());
29+
int lhs_prec = getPrecedence(v->lhs()->expr_type());
30+
int rhs_prec = getPrecedence(v->rhs()->expr_type());
31+
32+
if (lhs_prec >= self_prec) {
33+
os << "(";
34+
}
2835
v->lhs()->accept(printer);
36+
if (lhs_prec >= self_prec) {
37+
os << ")";
38+
}
39+
2940
os << " " << op_str << " ";
41+
42+
if (rhs_prec >= self_prec) {
43+
os << "(";
44+
}
3045
v->rhs()->accept(printer);
31-
os << ")";
46+
if (rhs_prec >= self_prec) {
47+
os << ")";
48+
}
3249
}
3350

3451
void IRPrinter::visit(const Add* v) {
@@ -95,8 +112,17 @@ void IRPrinter::visit(const Min* v) {
95112

96113
void IRPrinter::visit(const CompareSelect* v) {
97114
CompareSelectOperation cmp_op = v->compare_select_op();
98-
os() << "(";
115+
int self_prec = getPrecedence(v->expr_type());
116+
int lhs_prec = getPrecedence(v->lhs()->expr_type());
117+
int rhs_prec = getPrecedence(v->rhs()->expr_type());
118+
119+
if (lhs_prec >= self_prec) {
120+
os() << "(";
121+
}
99122
v->lhs()->accept(this);
123+
if (lhs_prec >= self_prec) {
124+
os() << ")";
125+
}
100126
switch (cmp_op) {
101127
case CompareSelectOperation::kEQ:
102128
os() << "==";
@@ -119,8 +145,14 @@ void IRPrinter::visit(const CompareSelect* v) {
119145
default:
120146
throw std::runtime_error("invalid compare select operator");
121147
}
148+
149+
if (rhs_prec >= self_prec) {
150+
os() << "(";
151+
}
122152
v->rhs()->accept(this);
123-
os() << ")";
153+
if (rhs_prec >= self_prec) {
154+
os() << ")";
155+
}
124156
}
125157

126158
static void formatFPSuffix(std::ostream& os, double v) {
@@ -170,13 +202,30 @@ void IRPrinter::visit(const Var* v) {
170202
}
171203

172204
void IRPrinter::visit(const Let* v) {
173-
os() << "(let ";
205+
int self_prec = getPrecedence(v->expr_type());
206+
int value_prec = getPrecedence(v->value()->expr_type());
207+
int body_prec = getPrecedence(v->body()->expr_type());
208+
os() << "let ";
174209
v->var()->accept(this);
175210
os() << " = ";
211+
212+
if (value_prec >= self_prec) {
213+
os() << "(";
214+
}
176215
v->value()->accept(this);
216+
if (value_prec >= self_prec) {
217+
os() << ")";
218+
}
219+
177220
os() << " in ";
221+
222+
if(body_prec >= self_prec) {
223+
os() << "(";
224+
}
178225
v->body()->accept(this);
179-
os() << ")";
226+
if (body_prec >= self_prec) {
227+
os() << ")";
228+
}
180229
}
181230

182231
void IRPrinter::visit(const LetStmt* v) {

0 commit comments

Comments
 (0)