Skip to content

Commit 097aee9

Browse files
authored
Use a consistent idiom for visit_let (#8540)
visit_let in the codebase uses a wide variety of template names, argument names, and ways of getting the body type. This just picks one and uses it consistently. No functional changes.
1 parent c2d5ea3 commit 097aee9

14 files changed

+120
-120
lines changed

src/BoundSmallAllocations.cpp

+9-9
Original file line numberDiff line numberDiff line change
@@ -17,40 +17,40 @@ class BoundSmallAllocations : public IRMutator {
1717
// Track constant bounds
1818
Scope<Interval> scope;
1919

20-
template<typename T, typename Body>
21-
Body visit_let(const T *op) {
20+
template<typename LetOrLetStmt>
21+
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
2222
// Visit an entire chain of lets in a single method to conserve stack space.
2323
struct Frame {
24-
const T *op;
24+
const LetOrLetStmt *op;
2525
ScopedBinding<Interval> binding;
26-
Frame(const T *op, Scope<Interval> &scope)
26+
Frame(const LetOrLetStmt *op, Scope<Interval> &scope)
2727
: op(op),
2828
binding(scope, op->name, find_constant_bounds(op->value, scope)) {
2929
}
3030
};
3131
std::vector<Frame> frames;
32-
Body result;
32+
decltype(op->body) result;
3333

3434
do {
3535
result = op->body;
3636
frames.emplace_back(op, scope);
37-
} while ((op = result.template as<T>()));
37+
} while ((op = result.template as<LetOrLetStmt>()));
3838

3939
result = mutate(result);
4040

4141
for (const auto &frame : reverse_view(frames)) {
42-
result = T::make(frame.op->name, frame.op->value, result);
42+
result = LetOrLetStmt::make(frame.op->name, frame.op->value, result);
4343
}
4444

4545
return result;
4646
}
4747

4848
Stmt visit(const LetStmt *op) override {
49-
return visit_let<LetStmt, Stmt>(op);
49+
return visit_let(op);
5050
}
5151

5252
Expr visit(const Let *op) override {
53-
return visit_let<Let, Expr>(op);
53+
return visit_let(op);
5454
}
5555

5656
bool in_thread_loop = false;

src/ClampUnsafeAccesses.cpp

+9-9
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ struct ClampUnsafeAccesses : IRMutator {
4242
}
4343

4444
Expr visit(const Let *let) override {
45-
return visit_let<Let, Expr>(let);
45+
return visit_let(let);
4646
}
4747

4848
Stmt visit(const LetStmt *let) override {
49-
return visit_let<LetStmt, Stmt>(let);
49+
return visit_let(let);
5050
}
5151

5252
Expr visit(const Variable *var) override {
@@ -80,15 +80,15 @@ struct ClampUnsafeAccesses : IRMutator {
8080
}
8181

8282
private:
83-
template<typename L, typename Body>
84-
Body visit_let(const L *let) {
85-
ScopedBinding<bool> binding(let_var_inside_indexing, let->name, false);
86-
Body body = mutate(let->body);
83+
template<typename LetOrLetStmt>
84+
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
85+
ScopedBinding<bool> binding(let_var_inside_indexing, op->name, false);
86+
auto body = mutate(op->body);
8787

88-
ScopedValue<bool> s(is_inside_indexing, is_inside_indexing || let_var_inside_indexing.get(let->name));
89-
Expr value = mutate(let->value);
88+
ScopedValue<bool> s(is_inside_indexing, is_inside_indexing || let_var_inside_indexing.get(op->name));
89+
Expr value = mutate(op->value);
9090

91-
return L::make(let->name, std::move(value), std::move(body));
91+
return LetOrLetStmt::make(op->name, std::move(value), std::move(body));
9292
}
9393

9494
bool bounds_smaller_than_type(const Interval &bounds, Type type) {

src/Deinterleave.cpp

+14-14
Original file line numberDiff line numberDiff line change
@@ -465,44 +465,44 @@ class Interleaver : public IRMutator {
465465
return Shuffle::make_interleave(exprs);
466466
}
467467

468-
template<typename T, typename Body>
469-
Body visit_lets(const T *op) {
468+
template<typename LetOrLetStmt>
469+
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
470470
// Visit an entire chain of lets in a single method to conserve stack space.
471471
struct Frame {
472-
const T *op;
472+
const LetOrLetStmt *op;
473473
Expr new_value;
474474
ScopedBinding<> binding;
475-
Frame(const T *op, Expr v, Scope<void> &scope)
475+
Frame(const LetOrLetStmt *op, Expr v, Scope<void> &scope)
476476
: op(op),
477477
new_value(std::move(v)),
478478
binding(new_value.type().is_vector(), scope, op->name) {
479479
}
480480
};
481481
std::vector<Frame> frames;
482-
Body result;
482+
decltype(op->body) result;
483483

484484
do {
485485
result = op->body;
486486
frames.emplace_back(op, mutate(op->value), vector_lets);
487-
} while ((op = result.template as<T>()));
487+
} while ((op = result.template as<LetOrLetStmt>()));
488488

489489
result = mutate(result);
490490

491491
for (const auto &frame : reverse_view(frames)) {
492492
Expr value = std::move(frame.new_value);
493493

494-
result = T::make(frame.op->name, value, result);
494+
result = LetOrLetStmt::make(frame.op->name, value, result);
495495

496496
// For vector lets, we may additionally need a let defining the even and odd lanes only
497497
if (value.type().is_vector()) {
498498
if (value.type().lanes() % 2 == 0) {
499-
result = T::make(frame.op->name + ".even_lanes", extract_even_lanes(value, vector_lets), result);
500-
result = T::make(frame.op->name + ".odd_lanes", extract_odd_lanes(value, vector_lets), result);
499+
result = LetOrLetStmt::make(frame.op->name + ".even_lanes", extract_even_lanes(value, vector_lets), result);
500+
result = LetOrLetStmt::make(frame.op->name + ".odd_lanes", extract_odd_lanes(value, vector_lets), result);
501501
}
502502
if (value.type().lanes() % 3 == 0) {
503-
result = T::make(frame.op->name + ".lanes_0_of_3", extract_mod3_lanes(value, 0, vector_lets), result);
504-
result = T::make(frame.op->name + ".lanes_1_of_3", extract_mod3_lanes(value, 1, vector_lets), result);
505-
result = T::make(frame.op->name + ".lanes_2_of_3", extract_mod3_lanes(value, 2, vector_lets), result);
503+
result = LetOrLetStmt::make(frame.op->name + ".lanes_0_of_3", extract_mod3_lanes(value, 0, vector_lets), result);
504+
result = LetOrLetStmt::make(frame.op->name + ".lanes_1_of_3", extract_mod3_lanes(value, 1, vector_lets), result);
505+
result = LetOrLetStmt::make(frame.op->name + ".lanes_2_of_3", extract_mod3_lanes(value, 2, vector_lets), result);
506506
}
507507
}
508508
}
@@ -511,11 +511,11 @@ class Interleaver : public IRMutator {
511511
}
512512

513513
Expr visit(const Let *op) override {
514-
return visit_lets<Let, Expr>(op);
514+
return visit_let(op);
515515
}
516516

517517
Stmt visit(const LetStmt *op) override {
518-
return visit_lets<LetStmt, Stmt>(op);
518+
return visit_let(op);
519519
}
520520

521521
Expr visit(const Ramp *op) override {

src/EliminateBoolVectors.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,8 @@ class EliminateBoolVectors : public IRMutator {
287287
return expr;
288288
}
289289

290-
template<typename NodeType, typename LetType>
291-
NodeType visit_let(const LetType *op) {
290+
template<typename LetOrLetStmt>
291+
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
292292
Expr value = mutate(op->value);
293293

294294
// We changed the type of the let, we need to replace the
@@ -305,17 +305,17 @@ class EliminateBoolVectors : public IRMutator {
305305
}
306306

307307
if (!value.same_as(op->value) || !body.same_as(op->body)) {
308-
return LetType::make(op->name, value, body);
308+
return LetOrLetStmt::make(op->name, value, body);
309309
} else {
310310
return op;
311311
}
312312
}
313313

314314
Expr visit(const Let *op) override {
315-
return visit_let<Expr>(op);
315+
return visit_let(op);
316316
}
317317
Stmt visit(const LetStmt *op) override {
318-
return visit_let<Stmt>(op);
318+
return visit_let(op);
319319
}
320320
};
321321

src/FuseGPUThreadLoops.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -1157,9 +1157,9 @@ class ExtractRegisterAllocations : public IRMutator {
11571157
op->param, mutate(op->predicate), op->alignment);
11581158
}
11591159

1160-
template<typename ExprOrStmt, typename LetOrLetStmt>
1161-
ExprOrStmt visit_let(const LetOrLetStmt *op) {
1162-
ExprOrStmt body = op->body;
1160+
template<typename LetOrLetStmt>
1161+
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
1162+
auto body = op->body;
11631163

11641164
body = mutate(op->body);
11651165
Expr value = mutate(op->value);
@@ -1178,11 +1178,11 @@ class ExtractRegisterAllocations : public IRMutator {
11781178
}
11791179

11801180
Expr visit(const Let *op) override {
1181-
return visit_let<Expr>(op);
1181+
return visit_let(op);
11821182
}
11831183

11841184
Stmt visit(const LetStmt *op) override {
1185-
return visit_let<Stmt>(op);
1185+
return visit_let(op);
11861186
}
11871187

11881188
Scope<int> register_allocations;

src/HexagonOptimize.cpp

+21-21
Original file line numberDiff line numberDiff line change
@@ -1088,20 +1088,20 @@ class OptimizePatterns : public IRMutator {
10881088
}
10891089
}
10901090

1091-
template<typename NodeType, typename T>
1092-
NodeType visit_let(const T *op) {
1091+
template<typename LetOrLetStmt>
1092+
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
10931093
bounds.push(op->name, bounds_of_expr_in_scope(op->value, bounds));
1094-
NodeType node = IRMutator::visit(op);
1094+
auto node = IRMutator::visit(op);
10951095
bounds.pop(op->name);
10961096
return node;
10971097
}
10981098

10991099
Expr visit(const Let *op) override {
1100-
return visit_let<Expr>(op);
1100+
return visit_let(op);
11011101
}
11021102

11031103
Stmt visit(const LetStmt *op) override {
1104-
return visit_let<Stmt>(op);
1104+
return visit_let(op);
11051105
}
11061106

11071107
Expr visit(const Div *op) override {
@@ -1599,12 +1599,12 @@ class EliminateInterleaves : public IRMutator {
15991599
}
16001600
}
16011601

1602-
template<typename NodeType, typename LetType>
1603-
NodeType visit_let(const LetType *op) {
1602+
template<typename LetOrLetStmt>
1603+
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
16041604

16051605
Expr value = mutate(op->value);
16061606
string deinterleaved_name;
1607-
NodeType body;
1607+
decltype(op->body) body;
16081608
// Other code in this mutator needs to be able to tell the
16091609
// difference between a Let that yields a deinterleave, and a
16101610
// let that has a removable deinterleave. Lets that can
@@ -1632,10 +1632,10 @@ class EliminateInterleaves : public IRMutator {
16321632
return op;
16331633
} else if (body.same_as(op->body)) {
16341634
// If the body didn't change, we must not have used the deinterleaved value.
1635-
return LetType::make(op->name, value, body);
1635+
return LetOrLetStmt::make(op->name, value, body);
16361636
} else {
16371637
// We need to rewrap the body with new lets.
1638-
NodeType result = body;
1638+
auto result = body;
16391639
bool deinterleaved_used = stmt_or_expr_uses_var(result, deinterleaved_name);
16401640
bool interleaved_used = stmt_or_expr_uses_var(result, op->name);
16411641
if (deinterleaved_used && interleaved_used) {
@@ -1653,14 +1653,14 @@ class EliminateInterleaves : public IRMutator {
16531653
interleaved = native_interleave(interleaved);
16541654
}
16551655

1656-
result = LetType::make(op->name, interleaved, result);
1657-
return LetType::make(deinterleaved_name, deinterleaved, result);
1656+
result = LetOrLetStmt::make(op->name, interleaved, result);
1657+
return LetOrLetStmt::make(deinterleaved_name, deinterleaved, result);
16581658
} else if (deinterleaved_used) {
16591659
// Only the deinterleaved value is used, we can eliminate the interleave.
1660-
return LetType::make(deinterleaved_name, remove_interleave(value), result);
1660+
return LetOrLetStmt::make(deinterleaved_name, remove_interleave(value), result);
16611661
} else if (interleaved_used) {
16621662
// Only the original value is used, regenerate the let.
1663-
return LetType::make(op->name, value, result);
1663+
return LetOrLetStmt::make(op->name, value, result);
16641664
} else {
16651665
// The let must have been dead.
16661666
internal_assert(!stmt_or_expr_uses_var(op->body, op->name))
@@ -1671,7 +1671,7 @@ class EliminateInterleaves : public IRMutator {
16711671
}
16721672

16731673
Expr visit(const Let *op) override {
1674-
Expr expr = visit_let<Expr>(op);
1674+
Expr expr = visit_let(op);
16751675

16761676
// Lift interleaves out of Let expression bodies.
16771677
const Let *let = expr.as<Let>();
@@ -1682,7 +1682,7 @@ class EliminateInterleaves : public IRMutator {
16821682
}
16831683

16841684
Stmt visit(const LetStmt *op) override {
1685-
return visit_let<Stmt>(op);
1685+
return visit_let(op);
16861686
}
16871687

16881688
Expr visit(const Cast *op) override {
@@ -2047,25 +2047,25 @@ class ScatterGatherGenerator : public IRMutator {
20472047
return IRMutator::visit(op);
20482048
}
20492049

2050-
template<typename NodeType, typename T>
2051-
NodeType visit_let(const T *op) {
2050+
template<typename LetOrLetStmt>
2051+
auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
20522052
// We only care about vector lets.
20532053
if (op->value.type().is_vector()) {
20542054
bounds.push(op->name, bounds_of_expr_in_scope(op->value, bounds));
20552055
}
2056-
NodeType node = IRMutator::visit(op);
2056+
auto node = IRMutator::visit(op);
20572057
if (op->value.type().is_vector()) {
20582058
bounds.pop(op->name);
20592059
}
20602060
return node;
20612061
}
20622062

20632063
Expr visit(const Let *op) override {
2064-
return visit_let<Expr>(op);
2064+
return visit_let(op);
20652065
}
20662066

20672067
Stmt visit(const LetStmt *op) override {
2068-
return visit_let<Stmt>(op);
2068+
return visit_let(op);
20692069
}
20702070

20712071
Stmt visit(const Allocate *op) override {

0 commit comments

Comments
 (0)