@@ -1088,20 +1088,20 @@ class OptimizePatterns : public IRMutator {
1088
1088
}
1089
1089
}
1090
1090
1091
- template <typename NodeType, typename T >
1092
- NodeType visit_let (const T *op) {
1091
+ template <typename LetOrLetStmt >
1092
+ auto visit_let (const LetOrLetStmt *op) -> decltype(op->body ) {
1093
1093
bounds.push (op->name , bounds_of_expr_in_scope (op->value , bounds));
1094
- NodeType node = IRMutator::visit (op);
1094
+ auto node = IRMutator::visit (op);
1095
1095
bounds.pop (op->name );
1096
1096
return node;
1097
1097
}
1098
1098
1099
1099
Expr visit (const Let *op) override {
1100
- return visit_let<Expr> (op);
1100
+ return visit_let (op);
1101
1101
}
1102
1102
1103
1103
Stmt visit (const LetStmt *op) override {
1104
- return visit_let<Stmt> (op);
1104
+ return visit_let (op);
1105
1105
}
1106
1106
1107
1107
Expr visit (const Div *op) override {
@@ -1599,12 +1599,12 @@ class EliminateInterleaves : public IRMutator {
1599
1599
}
1600
1600
}
1601
1601
1602
- template <typename NodeType, typename LetType >
1603
- NodeType visit_let (const LetType *op) {
1602
+ template <typename LetOrLetStmt >
1603
+ auto visit_let (const LetOrLetStmt *op) -> decltype(op->body ) {
1604
1604
1605
1605
Expr value = mutate (op->value );
1606
1606
string deinterleaved_name;
1607
- NodeType body;
1607
+ decltype (op-> body ) body;
1608
1608
// Other code in this mutator needs to be able to tell the
1609
1609
// difference between a Let that yields a deinterleave, and a
1610
1610
// let that has a removable deinterleave. Lets that can
@@ -1632,10 +1632,10 @@ class EliminateInterleaves : public IRMutator {
1632
1632
return op;
1633
1633
} else if (body.same_as (op->body )) {
1634
1634
// If the body didn't change, we must not have used the deinterleaved value.
1635
- return LetType ::make (op->name , value, body);
1635
+ return LetOrLetStmt ::make (op->name , value, body);
1636
1636
} else {
1637
1637
// We need to rewrap the body with new lets.
1638
- NodeType result = body;
1638
+ auto result = body;
1639
1639
bool deinterleaved_used = stmt_or_expr_uses_var (result, deinterleaved_name);
1640
1640
bool interleaved_used = stmt_or_expr_uses_var (result, op->name );
1641
1641
if (deinterleaved_used && interleaved_used) {
@@ -1653,14 +1653,14 @@ class EliminateInterleaves : public IRMutator {
1653
1653
interleaved = native_interleave (interleaved);
1654
1654
}
1655
1655
1656
- result = LetType ::make (op->name , interleaved, result);
1657
- return LetType ::make (deinterleaved_name, deinterleaved, result);
1656
+ result = LetOrLetStmt ::make (op->name , interleaved, result);
1657
+ return LetOrLetStmt ::make (deinterleaved_name, deinterleaved, result);
1658
1658
} else if (deinterleaved_used) {
1659
1659
// Only the deinterleaved value is used, we can eliminate the interleave.
1660
- return LetType ::make (deinterleaved_name, remove_interleave (value), result);
1660
+ return LetOrLetStmt ::make (deinterleaved_name, remove_interleave (value), result);
1661
1661
} else if (interleaved_used) {
1662
1662
// Only the original value is used, regenerate the let.
1663
- return LetType ::make (op->name , value, result);
1663
+ return LetOrLetStmt ::make (op->name , value, result);
1664
1664
} else {
1665
1665
// The let must have been dead.
1666
1666
internal_assert (!stmt_or_expr_uses_var (op->body , op->name ))
@@ -1671,7 +1671,7 @@ class EliminateInterleaves : public IRMutator {
1671
1671
}
1672
1672
1673
1673
Expr visit (const Let *op) override {
1674
- Expr expr = visit_let<Expr> (op);
1674
+ Expr expr = visit_let (op);
1675
1675
1676
1676
// Lift interleaves out of Let expression bodies.
1677
1677
const Let *let = expr.as <Let>();
@@ -1682,7 +1682,7 @@ class EliminateInterleaves : public IRMutator {
1682
1682
}
1683
1683
1684
1684
Stmt visit (const LetStmt *op) override {
1685
- return visit_let<Stmt> (op);
1685
+ return visit_let (op);
1686
1686
}
1687
1687
1688
1688
Expr visit (const Cast *op) override {
@@ -2047,25 +2047,25 @@ class ScatterGatherGenerator : public IRMutator {
2047
2047
return IRMutator::visit (op);
2048
2048
}
2049
2049
2050
- template <typename NodeType, typename T >
2051
- NodeType visit_let (const T *op) {
2050
+ template <typename LetOrLetStmt >
2051
+ auto visit_let (const LetOrLetStmt *op) -> decltype(op->body ) {
2052
2052
// We only care about vector lets.
2053
2053
if (op->value .type ().is_vector ()) {
2054
2054
bounds.push (op->name , bounds_of_expr_in_scope (op->value , bounds));
2055
2055
}
2056
- NodeType node = IRMutator::visit (op);
2056
+ auto node = IRMutator::visit (op);
2057
2057
if (op->value .type ().is_vector ()) {
2058
2058
bounds.pop (op->name );
2059
2059
}
2060
2060
return node;
2061
2061
}
2062
2062
2063
2063
Expr visit (const Let *op) override {
2064
- return visit_let<Expr> (op);
2064
+ return visit_let (op);
2065
2065
}
2066
2066
2067
2067
Stmt visit (const LetStmt *op) override {
2068
- return visit_let<Stmt> (op);
2068
+ return visit_let (op);
2069
2069
}
2070
2070
2071
2071
Stmt visit (const Allocate *op) override {
0 commit comments