Skip to content

Commit e41fd47

Browse files
authored
Improve argmax patterns for rfactor (#8863)
The `argmax` and `argmin` patterns in AssociativeOpsTable assumed that comparisons would be normalized to `y < x` (or `x < y`), but the Solve module could flip them. It was also possible for an earlier pass to negate the comparison, so that `y >= x` (or `x >= y`) with the branches reversed would be encountered instead. To solve this issue, we extend Solve to move the variable of interest to the `true` branch of a `select`, and we extend the table to cover all four comparison cases. We also substantially refactor Solve.cpp to use a helper that tracks and updates the ambient tree-traversal state (`uses_var` and `failed`).
1 parent d40a56e commit e41fd47

File tree

5 files changed

+193
-187
lines changed

5 files changed

+193
-187
lines changed

src/AssociativeOpsTable.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,20 @@ void populate_ops_table_double_general_mul(const vector<Type> &types, vector<Ass
177177

178178
void populate_ops_table_double_general_max(const vector<Type> &types, vector<AssociativePattern> &table) {
179179
declare_vars_double(types);
180+
// Argmax
181+
table.push_back({{max(x0, y0), select(x0 > y0, x1, y1)}, {tmin_0, zero_1}, true});
182+
table.push_back({{max(x0, y0), select(x0 >= y0, x1, y1)}, {tmin_0, zero_1}, true});
180183
table.push_back({{max(x0, y0), select(y0 < x0, x1, y1)}, {tmin_0, zero_1}, true});
184+
table.push_back({{max(x0, y0), select(y0 <= x0, x1, y1)}, {tmin_0, zero_1}, true});
181185
}
182186

183187
void populate_ops_table_double_general_min(const vector<Type> &types, vector<AssociativePattern> &table) {
184188
declare_vars_double(types);
189+
// Argmin
185190
table.push_back({{min(x0, y0), select(x0 < y0, x1, y1)}, {tmax_0, zero_1}, true});
191+
table.push_back({{min(x0, y0), select(x0 <= y0, x1, y1)}, {tmax_0, zero_1}, true});
192+
table.push_back({{min(x0, y0), select(y0 > x0, x1, y1)}, {tmax_0, zero_1}, true});
193+
table.push_back({{min(x0, y0), select(y0 >= x0, x1, y1)}, {tmax_0, zero_1}, true});
186194
}
187195

188196
void populate_ops_table_double_general_sub(const vector<Type> &types, vector<AssociativePattern> &table) {
@@ -348,21 +356,21 @@ const vector<AssociativePattern> &get_ops_table(const vector<Expr> &exprs) {
348356
types[i] = exprs[i].type();
349357
}
350358

351-
{
359+
const vector<AssociativePattern> &table = [&]() -> decltype(auto) {
352360
// get_ops_table_helper() lazily initializes the table, so ensure
353361
// that multiple threads can't try to do so at the same time.
354362
static std::mutex ops_table_lock;
355363
std::scoped_lock lock_guard(ops_table_lock);
356364

357-
const vector<AssociativePattern> &table = get_ops_table_helper(types, exprs[0].node_type(), exprs.size());
358-
debug(7) << "Table size: " << table.size() << "\n";
359-
for (const auto &p : table) {
360-
debug(7) << p;
361-
}
362-
return table;
365+
return get_ops_table_helper(types, exprs[0].node_type(), exprs.size());
366+
}();
367+
368+
debug(5) << "Table size: " << table.size() << "\n";
369+
for (const auto &p : table) {
370+
debug(5) << p;
363371
}
364-
debug(5) << "Returning empty table\n";
365-
return empty;
372+
373+
return table;
366374
}
367375

368376
} // namespace Internal

src/CSE.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ void cse_test() {
526526
check(e, correct);
527527
}
528528

529-
debug(0) << "common_subexpression_elimination test passed\n";
529+
std::cout << "common_subexpression_elimination test passed\n";
530530
}
531531

532532
} // namespace Internal

src/IREquality.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ void ir_equality_test() {
608608
e2 = e2 * e2 + e2;
609609
check_not_equal(e1, e2);
610610

611-
debug(0) << "ir_equality_test passed\n";
611+
std::cout << "ir_equality_test passed\n";
612612
}
613613

614614
} // namespace Internal

0 commit comments

Comments
 (0)