Add rfactor patterns for NaN-propagating min/max #8587

Open · wants to merge 2 commits into base: main
27 changes: 27 additions & 0 deletions src/AssociativeOpsTable.cpp
@@ -249,6 +249,29 @@ void populate_ops_table_single_uint32_select(const vector<Type> &types, vector<AssociativePattern> &table) {
table.emplace_back(select(x0 < -y0, y0, tmax_0), zero_0, true); // Saturating add
}

Expr is_nan_not_strict(Expr x) {
Review comment (Member); a related illustrative sketch follows this function:
is_nan without strict_float can legally simplify to false, so this code threw me at first. But I realize what you're writing here is not an expression with a value - it's a pattern you're trying to pattern-match an expression against, and you don't want the strict_float to be part of the pattern because you don't want to strip it. If that's correct, please add a clarifying comment.

Type t = Bool(x.type().lanes());
if (x.type().element_of() == Float(64)) {
return Call::make(t, "is_nan_f64", {std::move(x)}, Call::PureExtern);
}
if (x.type().element_of() == Float(16)) {
return Call::make(t, "is_nan_f16", {std::move(x)}, Call::PureExtern);
}
internal_assert(x.type().element_of() == Float(32));
return Call::make(t, "is_nan_f32", {std::move(x)}, Call::PureExtern);
}
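
For context on the review comment above, here is a minimal user-level sketch, not part of this diff, of the value-context case the reviewer contrasts with: there, is_nan is actually evaluated and needs strict float semantics to be guaranteed to survive simplification, whereas the Call constructed by is_nan_not_strict is only a pattern to match expressions against. The Func and Var names below are made up for the example.

```cpp
#include "Halide.h"
using namespace Halide;

int main() {
    Var x("x");
    Func f("f");

    // 0.0f / 0.0f is NaN at x == 0; replace NaNs with -1.0f.
    Expr v = cast<float>(x) / 0.0f;
    f(x) = select(is_nan(v), -1.0f, v);

    // Evaluate under strict float semantics so the simplifier cannot
    // fold the NaN check away under fast-math assumptions.
    f.realize({4}, get_jit_target_from_environment().with_feature(Target::StrictFloat));
    return 0;
}
```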

void populate_ops_table_single_float_select(const vector<Type> &types, vector<AssociativePattern> &table) {
declare_vars_single(types);
// Propagating max operators
table.emplace_back(select(is_nan_not_strict(x0) || x0 > y0, x0, y0), tmin_0, true);
table.emplace_back(select(is_nan_not_strict(x0) || x0 >= y0, x0, y0), tmin_0, true);

// Propagating min operators
table.emplace_back(select(is_nan_not_strict(x0) || x0 < y0, x0, y0), tmax_0, true);
table.emplace_back(select(is_nan_not_strict(x0) || x0 <= y0, x0, y0), tmax_0, true);
}
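
For illustration, a hedged sketch, not part of this diff, of the kind of NaN-propagating max reduction these table entries are meant to let rfactor split; it mirrors the test added in test/correctness/rfactor.cpp below, and the names in, r, ri, and u are illustrative.

```cpp
#include "Halide.h"
using namespace Halide;

int main() {
    ImageParam in(Float(32), 2);
    RDom r(0, 16);
    RVar ri("ri");
    Var y("y"), u("u");

    // NaN-propagating max over the reduction domain: once a NaN is
    // seen it is kept, matching the first pattern added above.
    Func reduce("reduce");
    reduce(y) = Float(32).min();  // identity for max
    reduce(y) = select(is_nan(reduce(y)) || reduce(y) > in(r, y),
                       reduce(y), in(r, y));

    // With the new associative patterns, this update can be factored.
    reduce.update(0).split(r, r, ri, 8).rfactor(ri, u);
    return 0;
}
```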

const map<TableKey, void (*)(const vector<Type> &types, vector<AssociativePattern> &)> val_type_to_populate_luts_fn = {
{TableKey(ValType::All, IRNodeType::Add, 1), &populate_ops_table_single_general_add},
{TableKey(ValType::All, IRNodeType::Mul, 1), &populate_ops_table_single_general_mul},
@@ -275,6 +298,10 @@ const map<TableKey, void (*)(const vector<Type> &types, vector<AssociativePattern> &)> val_type_to_populate_luts_fn = {

{TableKey(ValType::UInt32, IRNodeType::Cast, 1), &populate_ops_table_single_uint32_cast},
{TableKey(ValType::UInt32, IRNodeType::Select, 1), &populate_ops_table_single_uint32_select},

{TableKey(ValType::Float16, IRNodeType::Select, 1), &populate_ops_table_single_float_select},
{TableKey(ValType::Float32, IRNodeType::Select, 1), &populate_ops_table_single_float_select},
{TableKey(ValType::Float64, IRNodeType::Select, 1), &populate_ops_table_single_float_select},
};

const vector<AssociativePattern> &get_ops_table_helper(const vector<Type> &types, IRNodeType root, size_t dim) {
2 changes: 1 addition & 1 deletion src/Pipeline.h
@@ -32,7 +32,7 @@ struct PipelineContents;
*
* The 'name' field specifies the type of Autoscheduler
* to be used (e.g. Adams2019, Mullapudi2016). If this is an empty string,
- * no autoscheduling will be done; if not, it mustbe the name of a known Autoscheduler.
+ * no autoscheduling will be done; if not, it must be the name of a known Autoscheduler.
*
* At this time, well-known autoschedulers include:
* "Mullapudi2016" -- heuristics-based; the first working autoscheduler; currently built in to libHalide
46 changes: 46 additions & 0 deletions test/correctness/rfactor.cpp
@@ -1031,6 +1031,51 @@ int inlined_rfactor_with_disappearing_rvar_test() {
return 0;
}

int isnan_max_rfactor_test() {
RDom r(0, 16);
RVar ri("ri");
Var x("x"), y("y"), u("u");

ImageParam in(Float(32), 2);

const auto make_reduce = [&](const char *name) {
Func reduce(name);
reduce(y) = Float(32).min();
reduce(y) = select(reduce(y) > cast(reduce.type(), in(r, y)) || is_nan(reduce(y)), reduce(y), cast(reduce.type(), in(r, y)));
return reduce;
};

Func reference = make_reduce("reference");

Func reduce = make_reduce("reduce");
reduce.update(0).split(r, r, ri, 8).rfactor(ri, u);

float tests[][16] = {
{NAN, 0.29f, 0.19f, 0.68f, 0.44f, 0.40f, 0.39f, 0.53f, 0.23f, 0.21f, 0.85f, 0.19f, 0.37f, 0.03f, 0.14f, 0.64f},
{0.98f, 0.65f, 0.86f, 0.16f, 0.14f, 0.91f, 0.74f, 0.99f, 0.91f, 0.01f, 0.11f, 0.59f, 0.05f, 0.90f, 0.93f, NAN},
{0.84f, 0.14f, 0.99f, 0.19f, 0.63f, 0.12f, 0.51f, 0.67f, NAN, 0.34f, 0.89f, 0.93f, 0.72f, 0.69f, 0.58f, 0.63f},
{0.44f, 0.12f, 0.00f, 0.30f, 0.80f, 0.88f, 0.95f, 0.12f, 0.90f, 0.99f, 0.67f, 0.71f, 0.35f, 0.67f, 0.18f, 0.93f},
};

Buffer<float, 2> buf{tests};
in.set(buf);

Buffer<float, 1> ref_vals = reference.realize({4}, get_jit_target_from_environment().with_feature(Target::StrictFloat));
Buffer<float, 1> fac_vals = reduce.realize({4}, get_jit_target_from_environment().with_feature(Target::StrictFloat));

for (int i = 0; i < 4; i++) {
if (std::isnan(fac_vals(i)) && std::isnan(ref_vals(i))) {
continue;
}
if (fac_vals(i) != ref_vals(i)) {
std::cerr << "At index " << i << ", expected: " << ref_vals(i) << ", got: " << fac_vals(i) << "\n";
return 1;
}
}

return 0;
}

} // namespace

int main(int argc, char **argv) {
@@ -1072,6 +1117,7 @@ int main(int argc, char **argv) {
{"complex multiply rfactor test", complex_multiply_rfactor_test},
{"argmin rfactor test", argmin_rfactor_test},
{"inlined rfactor with disappearing rvar test", inlined_rfactor_with_disappearing_rvar_test},
{"isnan max rfactor test", isnan_max_rfactor_test},
};

using Sharder = Halide::Internal::Test::Sharder;