Skip to content

Commit d15325e

Browse files
authored
Cherry-pick some recent bug-fixes into 17.0.1 (#8107)
* Fix rfactor adding too many pure loops (#8086) When you rfactor an update definition, the new update definition must use all the pure vars of the Func, even though the one you're rfactoring may not have used them all. We also want to preserve any scheduling already done to the pure vars, so we want to preserve the dims list and splits list from the original definition. The code accounted for this by checking the dims list for any missing pure vars and adding them at the end (just before Var::outermost()), but this didn't account for the fact that they may no longer exist in the dims list due to splits that didn't reuse the outer name. In these circumstances we could end up with too many pure loops. E.g. if x has been split into xo and xi, then the code was adding a loop for x even though there were already loops for xo and xi, which of course produces garbage output. This PR instead just checks which pure vars are actually used in the update definition up front, and then uses that to tell which ones should be added. Fixes #7890 * Forward the partition methods from generator outputs (#8090) * Fix reduce_expr_modulo of vector in Solve.cpp (#8089) * Fix reduce_expr_modulo of vector in Solve.cpp * Fix test
1 parent 8f424e5 commit d15325e

File tree

4 files changed

+57
-4
lines changed

4 files changed

+57
-4
lines changed

src/Func.cpp

+23-3
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,17 @@ Func Stage::rfactor(vector<pair<RVar, Var>> preserved) {
788788
vector<Expr> &args = definition.args();
789789
vector<Expr> &values = definition.values();
790790

791+
// Figure out which pure vars were used in this update definition.
792+
std::set<string> pure_vars_used;
793+
internal_assert(args.size() == dim_vars.size());
794+
for (size_t i = 0; i < args.size(); i++) {
795+
if (const Internal::Variable *var = args[i].as<Variable>()) {
796+
if (var->name == dim_vars[i].name()) {
797+
pure_vars_used.insert(var->name);
798+
}
799+
}
800+
}
801+
791802
// Check whether the operator is associative and determine the operator and
792803
// its identity for each value in the definition if it is a Tuple
793804
const auto &prover_result = prove_associativity(func_name, args, values);
@@ -1012,16 +1023,20 @@ Func Stage::rfactor(vector<pair<RVar, Var>> preserved) {
10121023

10131024
// Determine the dims of the new update definition
10141025

1026+
// The new update definition needs all the pure vars of the Func, but the
1027+
// one we're rfactoring may not have used them all. Add any missing ones to
1028+
// the dims list.
1029+
10151030
// Add pure Vars from the original init definition to the dims list
10161031
// if they are not already in the list
10171032
for (const Var &v : dim_vars) {
1018-
const auto &iter = std::find_if(dims.begin(), dims.end(),
1019-
[&v](const Dim &dim) { return var_name_match(dim.var, v.name()); });
1020-
if (iter == dims.end()) {
1033+
if (!pure_vars_used.count(v.name())) {
10211034
Dim d = {v.name(), ForType::Serial, DeviceAPI::None, DimType::PureVar, Partition::Auto};
1035+
// Insert it just before Var::outermost
10221036
dims.insert(dims.end() - 1, d);
10231037
}
10241038
}
1039+
10251040
// Then, we need to remove lifted RVars from the dims list
10261041
for (const string &rv : rvars_removed) {
10271042
remove(rv);
@@ -1888,6 +1903,11 @@ Stage &Stage::reorder(const std::vector<VarOrRVar> &vars) {
18881903

18891904
dims_old.swap(dims);
18901905

1906+
// We're not allowed to reorder Var::outermost inwards (rfactor assumes it's
1907+
// the last one).
1908+
user_assert(dims.back().var == Var::outermost().name())
1909+
<< "Var::outermost() may not be reordered inside any other var.\n";
1910+
18911911
return *this;
18921912
}
18931913

src/Generator.h

+5
Original file line numberDiff line numberDiff line change
@@ -2280,6 +2280,8 @@ class GeneratorOutputBase : public GIOBase {
22802280
HALIDE_FORWARD_METHOD(Func, align_bounds)
22812281
HALIDE_FORWARD_METHOD(Func, align_extent)
22822282
HALIDE_FORWARD_METHOD(Func, align_storage)
2283+
HALIDE_FORWARD_METHOD(Func, always_partition)
2284+
HALIDE_FORWARD_METHOD(Func, always_partition_all)
22832285
HALIDE_FORWARD_METHOD_CONST(Func, args)
22842286
HALIDE_FORWARD_METHOD(Func, bound)
22852287
HALIDE_FORWARD_METHOD(Func, bound_extent)
@@ -2303,9 +2305,12 @@ class GeneratorOutputBase : public GIOBase {
23032305
HALIDE_FORWARD_METHOD(Func, hexagon)
23042306
HALIDE_FORWARD_METHOD(Func, in)
23052307
HALIDE_FORWARD_METHOD(Func, memoize)
2308+
HALIDE_FORWARD_METHOD(Func, never_partition)
2309+
HALIDE_FORWARD_METHOD(Func, never_partition_all)
23062310
HALIDE_FORWARD_METHOD_CONST(Func, num_update_definitions)
23072311
HALIDE_FORWARD_METHOD_CONST(Func, outputs)
23082312
HALIDE_FORWARD_METHOD(Func, parallel)
2313+
HALIDE_FORWARD_METHOD(Func, partition)
23092314
HALIDE_FORWARD_METHOD(Func, prefetch)
23102315
HALIDE_FORWARD_METHOD(Func, print_loop_nest)
23112316
HALIDE_FORWARD_METHOD(Func, rename)

src/Solve.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ class SolveExpression : public IRMutator {
394394
if (a_uses_var && !b_uses_var) {
395395
const int64_t *ib = as_const_int(b);
396396
auto is_multiple_of_b = [&](const Expr &e) {
397-
if (ib) {
397+
if (ib && op->type.is_scalar()) {
398398
int64_t r = 0;
399399
return reduce_expr_modulo(e, *ib, &r) && r == 0;
400400
} else {
@@ -1478,6 +1478,9 @@ void solve_test() {
14781478
check_solve(min(x + y, x - z), x + min(y, 0 - z));
14791479
check_solve(max(x + y, x - z), x + max(y, 0 - z));
14801480

1481+
check_solve((5 * Broadcast::make(x, 4) + y) / 5,
1482+
Broadcast::make(x, 4) + (Broadcast::make(y, 4) / 5));
1483+
14811484
debug(0) << "Solve test passed\n";
14821485
}
14831486

test/correctness/fuzz_schedule.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,31 @@ int main(int argc, char **argv) {
202202
check_blur_output(buf, correct);
203203
}
204204

205+
// https://github.com/halide/Halide/issues/7890
206+
{
207+
Func input("input");
208+
Func local_sum("local_sum");
209+
Func blurry("blurry");
210+
Var x("x"), y("y");
211+
RVar yryf;
212+
input(x, y) = 2 * x + 5 * y;
213+
RDom r(-2, 5, -2, 5, "rdom_r");
214+
local_sum(x, y) = 0;
215+
local_sum(x, y) += input(x + r.x, y + r.y);
216+
blurry(x, y) = cast<int32_t>(local_sum(x, y) / 25);
217+
218+
Var yo, yi, xo, xi, u;
219+
blurry.split(y, yo, yi, 2, TailStrategy::Auto);
220+
local_sum.split(x, xo, xi, 4, TailStrategy::Auto);
221+
local_sum.update(0).split(x, xo, xi, 1, TailStrategy::Auto);
222+
local_sum.update(0).rfactor(r.x, u);
223+
blurry.store_root();
224+
local_sum.compute_root();
225+
Pipeline p({blurry});
226+
auto buf = p.realize({32, 32});
227+
check_blur_output(buf, correct);
228+
}
229+
205230
// https://github.com/halide/Halide/issues/8054
206231
{
207232
ImageParam input(Float(32), 2, "input");

0 commit comments

Comments
 (0)