diff --git a/doc/RunGen.md b/doc/RunGen.md
index 6dde6a42d498..968d4a345c4f 100644
--- a/doc/RunGen.md
+++ b/doc/RunGen.md
@@ -125,7 +125,7 @@ Generator, and inits every element to zero:
 
 ```
 # Input is a 3-dimensional image with extent 123, 456, and 3
-# (bluring an image of all zeroes isn't very interesting, of course)
+# (blurring an image of all zeros isn't very interesting, of course)
 $ ./bin/local_laplacian.rungen --output_extents=[100,200,3] input=zero:[123,456,3] levels=8 alpha=1 beta=1 output=/tmp/out.png
 ```
diff --git a/src/ConstantBounds.cpp b/src/ConstantBounds.cpp
index 164678e2554c..bf228f2c86a4 100644
--- a/src/ConstantBounds.cpp
+++ b/src/ConstantBounds.cpp
@@ -125,6 +125,10 @@ ConstantInterval bounds_helper(const Expr &e,
             ConstantInterval cq = recurse(op->args[2]);
             ConstantInterval rounding_term = 1 << (cq - 1);
             return (ca * cb + rounding_term) >> cq;
+        } else if (op->is_intrinsic(Call::bitwise_not)) {
+            // We can't do much with the other bitwise ops, but we can treat
+            // bitwise_not as an all-ones bit pattern minus the argument.
+            return recurse(make_const(e.type(), -1) - op->args[0]);
         }
         // If you add a new intrinsic here, also add it to the expression
         // generator in test/correctness/lossless_cast.cpp
diff --git a/src/HexagonOffload.cpp b/src/HexagonOffload.cpp
index 2d7d5e74acf7..e3d3664a54c1 100644
--- a/src/HexagonOffload.cpp
+++ b/src/HexagonOffload.cpp
@@ -1049,7 +1049,7 @@ Buffer<uint8_t> compile_module_to_hexagon_shared_object(const Module &device_cod
         // This will cause a difference in MemSize and FileSize like so:
         // FileSize = (MemSize - size_of_bss)
         // When the Hexagon loader is used on 8998 and later targets,
-        // the difference is filled with zeroes thereby initializing the .bss
+        // the difference is filled with zeros thereby initializing the .bss
         // section.
         bss->set_type(Elf::Section::SHT_PROGBITS);
         std::fill(bss->contents_begin(), bss->contents_end(), 0);
diff --git a/src/Simplify.cpp b/src/Simplify.cpp
index 494ca5665f91..b597b8b5286d 100644
--- a/src/Simplify.cpp
+++ b/src/Simplify.cpp
@@ -91,7 +91,7 @@ void Simplify::ScopedFact::learn_false(const Expr &fact) {
     Simplify::VarInfo info;
     info.old_uses = info.new_uses = 0;
     if (const Variable *v = fact.as<Variable>()) {
-        info.replacement = const_false(fact.type().lanes());
+        info.replacement = Halide::Internal::const_false(fact.type().lanes());
         simplify->var_info.push(v->name, info);
         pop_list.push_back(v);
     } else if (const NE *ne = fact.as<NE>()) {
@@ -178,7 +178,7 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) {
     Simplify::VarInfo info;
     info.old_uses = info.new_uses = 0;
     if (const Variable *v = fact.as<Variable>()) {
-        info.replacement = const_true(fact.type().lanes());
+        info.replacement = Halide::Internal::const_true(fact.type().lanes());
         simplify->var_info.push(v->name, info);
         pop_list.push_back(v);
     } else if (const EQ *eq = fact.as<EQ>()) {
@@ -496,5 +496,158 @@ bool can_prove(Expr e, const Scope<Interval> &bounds) {
     return is_const_one(e);
 }
 
+Simplify::ExprInfo::BitsKnown Simplify::ExprInfo::to_bits_known(const Type &type) const {
+    BitsKnown result = {0, 0};
+
+    if (!(type.is_int() || type.is_uint())) {
+        // Let's not claim we know anything about the bit patterns of
+        // non-integer types for now.
+        return result;
+    }
+
+    // Identify the largest power of two in the modulus to get some low bits
+    if (alignment.modulus) {
+        result.mask = largest_power_of_two_factor(alignment.modulus) - 1;
+        result.value = result.mask & alignment.remainder;
+    } else {
+        // This value is just a constant
+        result.mask = (uint64_t)(-1);
+        result.value = alignment.remainder;
+        return result;
+    }
+
+    // Compute a mask which is 1 for all the leading zeros of a uint64
+    auto leading_zeros_mask = [](uint64_t x) {
+        if (x == 0) {
+            // They're all leading zeros, but clz64 is UB on zero. Really we
+            // should have returned early above, but it's hard to guarantee that
+            // the alignment analysis catches constants at the same time as
+            // bounds analysis does.
+            return (uint64_t)-1;
+        }
+        return (uint64_t)(-1) << (64 - clz64(x));
+    };
+
+    auto leading_ones_mask = [=](uint64_t x) {
+        return leading_zeros_mask(~x);
+    };
+
+    // The bounds and the type tell us a bunch of high bits are zero or one
+    if (type.is_uint()) {
+        // Narrow uints are always zero-extended.
+        if (type.bits() < 64) {
+            result.mask |= (uint64_t)(-1) << type.bits();
+        }
+        // A lower bound might tell us that there are some leading ones, and an
+        // upper bound might tell us that there are some leading
+        // zeros. Unfortunately we'll never learn about leading ones, because
+        // uint64_ts that start with leading ones can't have a min represented
+        // as an int64_t, which is what ConstantInterval uses, so
+        // bounds.min_defined will never be true.
+        if (bounds.max_defined) {
+            result.mask |= leading_zeros_mask(bounds.max);
+        }
+    } else {
+        internal_assert(type.is_int());
+        // A mask which is 1 for the sign bit and above.
+        uint64_t sign_bit_and_above = (uint64_t)(-1) << (type.bits() - 1);
+        if (bounds >= 0) {
+            // We know this int is positive, so the sign bit and above are zero.
+            result.mask |= sign_bit_and_above;
+            if (bounds.max_defined) {
+                // We also have an upper bound, so there may be more zero bits,
+                // depending on how many leading zeros there are in the upper
+                // bound.
+                result.mask |= leading_zeros_mask(bounds.max);
+            }
+        } else if (bounds < 0) {
+            // This int is negative, so the sign bit and above are one.
+            result.mask |= sign_bit_and_above;
+            result.value |= sign_bit_and_above;
+            if (bounds.min_defined) {
+                // We have a lower bound, so there may be more leading one bits,
+                // depending on how many leading ones there are in the lower
+                // bound.
+                uint64_t leading_ones = leading_ones_mask(bounds.min);
+                result.mask |= leading_ones;
+                result.value |= leading_ones;
+            }
+        }
+    }
+
+    return result;
+}
+
+void Simplify::ExprInfo::from_bits_known(Simplify::ExprInfo::BitsKnown known, const Type &type) {
+    // Normalize everything to 64-bits by sign- or zero-extending known bits for
+    // the type.
+
+    // A mask which is one for all the new bits resulting from sign or zero
+    // extension.
+    uint64_t missing_bits = 0;
+    if (type.bits() < 64) {
+        missing_bits = (uint64_t)(-1) << type.bits();
+    }
+
+    if (missing_bits) {
+        if (type.is_uint()) {
+            // For a uint the high bits are known to be zero
+            known.mask |= missing_bits;
+            known.value &= ~missing_bits;
+        } else if (type.is_int()) {
+            // For an int we need to know the sign to know the high bits
+            bool sign_bit_known = (known.mask >> (type.bits() - 1)) & 1;
+            bool negative = (known.value >> (type.bits() - 1)) & 1;
+            if (!sign_bit_known) {
+                // We don't know the sign bit, so we don't know any of the
+                // extended bits.
+                // Mark them as unknown in the mask and zero them out in the
+                // value too just for ease of debugging.
+                known.mask &= ~missing_bits;
+                known.value &= ~missing_bits;
+            } else if (negative) {
+                // We know the sign bit is 1, so all of the extended bits are 1
+                // too.
+                known.mask |= missing_bits;
+                known.value |= missing_bits;
+            } else if (!negative) {
+                // We know the sign bit is zero, so all of the extended bits are
+                // zero too.
+                known.mask |= missing_bits;
+                known.value &= ~missing_bits;
+            }
+        }
+    }
+
+    // We can get the trailing one bits by adding one and taking the largest
+    // power of two factor. Note that this works out correctly when we know all
+    // the bits - the modulus comes out as zero, and the remainder is the entire
+    // number, which is how we represent constants in ModulusRemainder.
+    alignment.modulus = largest_power_of_two_factor(known.mask + 1);
+    alignment.remainder = known.value & (alignment.modulus - 1);
+
+    if ((int64_t)known.mask < 0) {
+        // We know some leading bits
+
+        // Set all unknown bits to zero
+        uint64_t min_val = known.value & known.mask;
+        // Set all unknown bits to one
+        uint64_t max_val = known.value | ~known.mask;
+
+        if (type.is_uint() && (int64_t)known.value < 0) {
+            // We know it's out of range at the top end for our ConstantInterval
+            // class. At the time of writing, to_bits_known can't produce this
+            // directly, and bits_known is never propagated through other
+            // operations, so this code is unreachable. Nonetheless we'll do the
+            // best job we can at representing this case in case this code
+            // becomes reachable in future.
+            bounds = ConstantInterval::bounded_below((1ULL << 63) - 1);
+        } else {
+            // In all other cases, the bounds are representable as an int64
+            // and don't span zero (because we know the high bit).
+            bounds = ConstantInterval{(int64_t)min_val, (int64_t)max_val};
+        }
+    }
+}
+
 }  // namespace Internal
 }  // namespace Halide
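The combination rules used by the new BitsKnown representation (defined in
Simplify_Internal.h further down in this patch) are the subtle part of the
change above, so they are worth checking in isolation. The following is a
minimal standalone sketch of the mask/value representation; the operator
bodies mirror the ones in the patch, while the free-function names and the
tiny checks in main are illustrative only:

```cpp
#include <cassert>
#include <cstdint>

// mask has a 1 wherever the bit's value is known; value holds those bits.
struct BitsKnown {
    uint64_t mask, value;
    uint64_t known_zeros() const { return mask & ~value; }
    uint64_t known_ones() const { return mask & value; }
};

// x & y: a known zero in either input forces a zero; a one must be known in both.
BitsKnown known_and(BitsKnown a, BitsKnown b) {
    uint64_t zeros = a.known_zeros() | b.known_zeros();
    uint64_t ones = a.known_ones() & b.known_ones();
    return {zeros | ones, ones};
}

// x | y: a known one in either input forces a one; a zero must be known in both.
BitsKnown known_or(BitsKnown a, BitsKnown b) {
    uint64_t ones = a.known_ones() | b.known_ones();
    uint64_t zeros = a.known_zeros() & b.known_zeros();
    return {zeros | ones, ones};
}

// x ^ y: a result bit is known only where both input bits are known.
BitsKnown known_xor(BitsKnown a, BitsKnown b) {
    uint64_t m = a.mask & b.mask;
    return {m, (a.value ^ b.value) & m};
}

int main() {
    BitsKnown low_bit_set = {0x1, 0x1};   // bit 0 known to be 1
    BitsKnown low_bit_clear = {0x1, 0x0}; // bit 0 known to be 0
    assert(known_and(low_bit_set, low_bit_clear).known_zeros() & 1); // 1 & 0 == 0
    assert(known_or(low_bit_set, low_bit_clear).known_ones() & 1);   // 1 | 0 == 1
    assert(known_xor(low_bit_set, low_bit_clear).known_ones() & 1);  // 1 ^ 0 == 1
    return 0;
}
```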
diff --git a/src/Simplify_And.cpp b/src/Simplify_And.cpp
index a6f7e82c9095..3a78467c296d 100644
--- a/src/Simplify_And.cpp
+++ b/src/Simplify_And.cpp
@@ -5,7 +5,7 @@ namespace Internal {
 
 Expr Simplify::visit(const And *op, ExprInfo *info) {
     if (falsehoods.count(op)) {
-        return const_false(op->type.lanes());
+        return const_false(op->type.lanes(), info);
     }
 
     Expr a = mutate(op->a, nullptr);
@@ -16,33 +16,17 @@ Expr Simplify::visit(const And *op, ExprInfo *info) {
         std::swap(a, b);
     }
 
+    if (info) {
+        info->cast_to(op->type);
+    }
+
     auto rewrite = IRMatcher::rewriter(IRMatcher::and_op(a, b), op->type);
 
     // clang-format off
-    if (EVAL_IN_LAMBDA
-        (rewrite(x && true, a) ||
-         rewrite(x && false, b) ||
-         rewrite(x && x, a) ||
-
-         rewrite((x && y) && x, a) ||
-         rewrite(x && (x && y), b) ||
-         rewrite((x && y) && y, a) ||
-         rewrite(y && (x && y), b) ||
-
-         rewrite(((x && y) && z) && x, a) ||
-         rewrite(x && ((x && y) && z), b) ||
-         rewrite((z && (x && y)) && x, a) ||
-         rewrite(x && (z && (x && y)), b) ||
-         rewrite(((x && y) && z) && y, a) ||
-         rewrite(y && ((x && y) && z), b) ||
-         rewrite((z && (x && y)) && y, a) ||
-         rewrite(y && (z && (x && y)), b) ||
-
-         rewrite((x || y) && x, b) ||
-         rewrite(x && (x || y), a) ||
-         rewrite((x || y) && y, b) ||
-         rewrite(y && (x || y), a) ||
+    // Cases that fold to a constant
+    if (EVAL_IN_LAMBDA
+        (rewrite(x && false, false) ||
          rewrite(x != y && x == y, false) ||
          rewrite(x != y && y == x, false) ||
          rewrite((z && x != y) && x == y, false) ||
          rewrite(!x && x, false) ||
          rewrite(y <= x && x < y, false) ||
          rewrite(y < x && x < y, false) ||
-         rewrite(x != c0 && x == c1, b, c0 != c1) ||
          rewrite(x == c0 && x == c1, false, c0 != c1) ||
          // Note: In the predicate below, if undefined overflow
          // occurs, the predicate counts as false. If well-defined
@@ -69,7 +52,36 @@ Expr Simplify::visit(const And *op, ExprInfo *info) {
          rewrite(x <= c1 && c0 < x, false, c1 <= c0) ||
          rewrite(c0 <= x && x < c1, false, c1 <= c0) ||
          rewrite(c0 <= x && x <= c1, false, c1 < c0) ||
-         rewrite(x <= c1 && c0 <= x, false, c1 < c0) ||
+         rewrite(x <= c1 && c0 <= x, false, c1 < c0))) {
+        set_expr_info_to_constant(info, false);
+        return rewrite.result;
+    }
+
+    // Cases that fold to one of the args
+    if (EVAL_IN_LAMBDA
+        (rewrite(x && true, a) ||
+         rewrite(x && x, a) ||
+
+         rewrite((x && y) && x, a) ||
+         rewrite(x && (x && y), b) ||
+         rewrite((x && y) && y, a) ||
+         rewrite(y && (x && y), b) ||
+
+         rewrite(((x && y) && z) && x, a) ||
+         rewrite(x && ((x && y) && z), b) ||
+         rewrite((z && (x && y)) && x, a) ||
+         rewrite(x && (z && (x && y)), b) ||
+         rewrite(((x && y) && z) && y, a) ||
+         rewrite(y && ((x && y) && z), b) ||
+         rewrite((z && (x && y)) && y, a) ||
+         rewrite(y && (z && (x && y)), b) ||
+
+         rewrite((x || y) && x, b) ||
+         rewrite(x && (x || y), a) ||
+         rewrite((x || y) && y, b) ||
+         rewrite(y && (x || y), a) ||
+
+         rewrite(x != c0 && x == c1, b, c0 != c1) ||
          rewrite(c0 < x && c1 < x, fold(max(c0, c1)) < x) ||
          rewrite(c0 <= x && c1 <= x, fold(max(c0, c1)) <= x) ||
          rewrite(x < c0 && x < c1, x < fold(min(c0, c1))) ||
diff --git a/src/Simplify_Call.cpp b/src/Simplify_Call.cpp
index 0d5d9a5ffdda..d00ebf029843 100644
--- a/src/Simplify_Call.cpp
+++ b/src/Simplify_Call.cpp
@@ -55,6 +55,10 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
         found_buffer_reference(op->name, op->args.size());
     }
 
+    if (info) {
+        info->cast_to(op->type);
+    }
+
     if (op->is_intrinsic(Call::unreachable)) {
         in_unreachable = true;
         return op;
@@ -90,7 +94,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
         const uint64_t mask = std::numeric_limits<uint64_t>::max() >> (64 - bits);
         u &= mask;
         static_assert(sizeof(unsigned long long) >= sizeof(uint64_t), "");
-        int r = 0;
+        int64_t r = 0;
         if (op->is_intrinsic(Call::popcount)) {
             // popcount *is* well-defined for ua = 0
             r = popcount64(u);
@@ -101,7 +105,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
             // ctz64() is undefined for 0, but Halide's count_trailing_zeros defines ctz(0) = bits
             r = u == 0 ? bits : (ctz64(u));
         }
-        return make_const(op->type, r);
+        return make_const(op->type, r, info);
     }
 
     if (a.same_as(op->args[0])) {
@@ -150,7 +154,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
                 return make_signed_integer_overflow(t);
             }
             if (a.type().is_uint() || *ub < ((uint64_t)t.bits() - 1)) {
-                b = make_const(t, ((int64_t)1LL) << *ub);
+                b = make_const(t, ((int64_t)1) << *ub, nullptr);
                 if (result_op == Call::get_intrinsic_name(Call::shift_left)) {
                     return mutate(Mul::make(a, b), info);
                 } else {
@@ -161,9 +165,9 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
                 // (-32768 >> (t.bits() - 1)) propagates the sign bit, making decomposition
                 // into mul or div problematic, so just special-case them here.
                 if (result_op == Call::get_intrinsic_name(Call::shift_left)) {
-                    return mutate(select((a & 1) != 0, make_const(t, ((int64_t)1LL) << *ub), make_zero(t)), info);
+                    return mutate(select((a & 1) != 0, make_const(t, ((int64_t)1) << *ub, nullptr), make_zero(t)), info);
                 } else {
-                    return mutate(select(a < 0, make_const(t, -1), make_zero(t)), info);
+                    return mutate(select(a < 0, make_const(t, (int64_t)(-1), nullptr), make_zero(t)), info);
                 }
             }
         }
@@ -186,89 +190,130 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
             return Call::make(op->type, result_op, {a, b}, Call::PureIntrinsic);
         }
     } else if (op->is_intrinsic(Call::bitwise_and)) {
-        Expr a = mutate(op->args[0], nullptr);
-        Expr b = mutate(op->args[1], nullptr);
+        ExprInfo a_info, b_info;
+        Expr a = mutate(op->args[0], &a_info);
+        Expr b = mutate(op->args[1], &b_info);
 
         Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type);
         if (unbroadcast.defined()) {
             return mutate(unbroadcast, info);
         }
 
+        if (info && (op->type.is_int() || op->type.is_uint())) {
+            auto bits_known = a_info.to_bits_known(op->type) & b_info.to_bits_known(op->type);
+            info->from_bits_known(bits_known, op->type);
+            if (bits_known.all_bits_known()) {
+                // All bits are known, so this must be a constant
+                return make_const(op->type, bits_known.value, nullptr);
+            }
+        }
+
         auto ia = as_const_int(a), ib = as_const_int(b);
         auto ua = as_const_uint(a), ub = as_const_uint(b);
         if (ia && ib) {
-            return make_const(op->type, *ia & *ib);
+            return make_const(op->type, *ia & *ib, info);
         } else if (ua && ub) {
-            return make_const(op->type, *ua & *ub);
+            return make_const(op->type, *ua & *ub, info);
         } else if (ib && !b.type().is_max(*ib) && is_const_power_of_two_integer(*ib + 1)) {
-            return Mod::make(a, make_const(a.type(), *ib + 1));
+            return Mod::make(a, make_const(a.type(), *ib + 1, nullptr));
        } else if (ub && b.type().is_max(*ub)) {
             return a;
         } else if (ib && *ib == -1) {
             return a;
         } else if (ub && is_const_power_of_two_integer(*ub + 1)) {
-            return Mod::make(a, make_const(a.type(), *ub + 1));
+            return Mod::make(a, make_const(a.type(), *ub + 1, nullptr));
         } else if (a.same_as(op->args[0]) && b.same_as(op->args[1])) {
             return op;
         } else {
             return a & b;
         }
     } else if (op->is_intrinsic(Call::bitwise_or)) {
-        Expr a = mutate(op->args[0], nullptr);
-        Expr b = mutate(op->args[1], nullptr);
+        ExprInfo a_info, b_info;
+        Expr a = mutate(op->args[0], &a_info);
+        Expr b = mutate(op->args[1], &b_info);
 
         Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type);
         if (unbroadcast.defined()) {
             return mutate(unbroadcast, info);
         }
 
+        if (info && (op->type.is_int() || op->type.is_uint())) {
+            auto bits_known = a_info.to_bits_known(op->type) | b_info.to_bits_known(op->type);
+            info->from_bits_known(bits_known, op->type);
+            if (bits_known.all_bits_known()) {
+                return make_const(op->type, bits_known.value, nullptr);
+            }
+        }
+
         auto ia = as_const_int(a), ib = as_const_int(b);
         auto ua = as_const_uint(a), ub = as_const_uint(b);
         if (ia && ib) {
-            return make_const(op->type, *ia | *ib);
+            return make_const(op->type, *ia | *ib, info);
         } else if (ua && ub) {
-            return make_const(op->type, *ua | *ub);
+            return make_const(op->type, *ua | *ub, info);
         } else if (a.same_as(op->args[0]) && b.same_as(op->args[1])) {
             return op;
         } else {
             return a | b;
         }
     } else if (op->is_intrinsic(Call::bitwise_not)) {
-        Expr a = mutate(op->args[0], nullptr);
+        ExprInfo a_info;
+        Expr a = mutate(op->args[0], &a_info);
 
         Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a}, op->call_type);
         if (unbroadcast.defined()) {
             return mutate(unbroadcast, info);
         }
 
+        if (info && (op->type.is_int() || op->type.is_uint())) {
+            // We could compute bits known here, but for the purpose of bounds
+            // and alignment, it's more precise to treat ~x as an all-ones bit
+            // pattern minus x. We get more information that way than just
+            // counting the leading zeros or ones.
+            Expr e = mutate(make_const(op->type, (int64_t)(-1), nullptr) - op->args[0], info);
+            // If the result of this happens to be a constant, we may as well
+            // return it. This is redundant with the constant folding below, but
+            // the constant folding below still needs to happen when info is
+            // nullptr.
+            if (info->bounds.is_single_point()) {
+                return e;
+            }
+        }
+
         if (auto ia = as_const_int(a)) {
-            return make_const(op->type, ~(*ia));
+            return make_const(op->type, ~(*ia), info);
         } else if (auto ua = as_const_uint(a)) {
-            return make_const(op->type, ~(*ua));
+            return make_const(op->type, ~(*ua), info);
         } else if (a.same_as(op->args[0])) {
             return op;
         } else {
             return ~a;
         }
     } else if (op->is_intrinsic(Call::bitwise_xor)) {
-        Expr a = mutate(op->args[0], nullptr);
-        Expr b = mutate(op->args[1], nullptr);
+        ExprInfo a_info, b_info;
+        Expr a = mutate(op->args[0], &a_info);
+        Expr b = mutate(op->args[1], &b_info);
 
         Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type);
         if (unbroadcast.defined()) {
             return mutate(unbroadcast, info);
         }
 
+        if (info && (op->type.is_int() || op->type.is_uint())) {
+            auto bits_known = a_info.to_bits_known(op->type) ^ b_info.to_bits_known(op->type);
+            info->from_bits_known(bits_known, op->type);
+        }
+
         auto ia = as_const_int(a), ib = as_const_int(b);
         auto ua = as_const_uint(a), ub = as_const_uint(b);
         if (ia && ib) {
-            return make_const(op->type, *ia ^ *ib);
+            return make_const(op->type, *ia ^ *ib, info);
         } else if (ua && ub) {
-            return make_const(op->type, *ua ^ *ub);
+            return make_const(op->type, *ua ^ *ub, info);
         } else if (a.same_as(op->args[0]) && b.same_as(op->args[1])) {
             return op;
         } else {
@@ -294,7 +339,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
             if (*ia < 0 && !(Int(64).is_min(*ia))) {
                 *ia = -(*ia);
             }
-            return make_const(op->type, *ia);
+            return make_const(op->type, *ia, info);
         } else if (ta.is_uint()) {
             // abs(uint) is a no-op.
             return a;
@@ -302,7 +347,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
             if (*fa < 0) {
                 *fa = -(*fa);
             }
-            return make_const(a.type(), *fa);
+            return make_const(a.type(), *fa, info);
         } else if (a.type().is_int() && a_info.bounds >= 0) {
             return cast(op->type, a);
         } else if (a.type().is_int() && a_info.bounds <= 0) {
@@ -334,13 +379,13 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
             // Note that absd(int, int) always produces a uint result
             internal_assert(op->type.is_uint());
             const uint64_t d = *ia > *ib ? (uint64_t)(*ia - *ib) : (uint64_t)(*ib - *ia);
-            return make_const(op->type, d);
+            return make_const(op->type, d, info);
         } else if (ta.is_uint() && ua && ub) {
             const uint64_t d = *ua > *ub ? *ua - *ub : *ub - *ua;
-            return make_const(op->type, d);
+            return make_const(op->type, d, info);
         } else if (fa && fb) {
             const double d = *fa > *fb ? *fa - *fb : *fb - *fa;
-            return make_const(op->type, d);
+            return make_const(op->type, d, info);
         } else if (a.same_as(op->args[0]) && b.same_as(op->args[1])) {
             return op;
         } else {
@@ -692,7 +737,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
         Expr arg = mutate(op->args[0], nullptr);
         if (auto f = as_const_float(arg)) {
             auto fn = it->second;
-            return make_const(arg.type(), fn(*f));
+            return make_const(arg.type(), fn(*f), info);
         } else if (arg.same_as(op->args[0])) {
             return op;
         } else {
@@ -724,7 +769,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
         const Call *call = arg.as<Call>();
         if (auto f = as_const_float(arg)) {
             auto fn = it->second;
-            return make_const(arg.type(), fn(*f));
+            return make_const(arg.type(), fn(*f), info);
         } else if (call &&
                    (call->call_type == Call::PureExtern || call->call_type == Call::PureIntrinsic) &&
                    (it = pure_externs_truncation.find(call->name)) != pure_externs_truncation.end()) {
            // For any combination of these integer-valued functions, we can
@@ -756,7 +801,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
         auto f1 = as_const_float(arg1);
         if (f0 && f1) {
             auto fn = it->second;
-            return make_const(arg0.type(), fn(*f0, *f1));
+            return make_const(arg0.type(), fn(*f0, *f1), info);
         } else if (!arg0.same_as(op->args[0]) || !arg1.same_as(op->args[1])) {
             return Call::make(op->type, op->name, {arg0, arg1}, op->call_type);
         } else {
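The bitwise_not handling above leans on the identity ~x == -1 - x, which is
exact in two's complement and turns bit flipping into ordinary interval
arithmetic: for a uint8 in [3, 5], ~x lands in [250, 252], which is exactly the
fact the new bits_known test checks via clamp(u8, 3, 5). A small standalone
sanity check of that identity (illustrative only, not part of the patch):

```cpp
#include <cassert>
#include <cstdint>

int main() {
    // ~x == -1 - x holds bit-for-bit in two's complement.
    for (int x = -300; x <= 300; x++) {
        assert(~x == -1 - x);
    }
    // For a uint8 x, ~x is 255 - x, so bounds map through exactly:
    // x in [3, 5] implies ~x in [255 - 5, 255 - 3] = [250, 252].
    for (uint8_t x = 3; x <= 5; x++) {
        uint8_t not_x = ~x; // the promoted result truncates back to 8 bits
        assert(not_x >= 250 && not_x <= 252);
    }
    return 0;
}
```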
diff --git a/src/Simplify_Cast.cpp b/src/Simplify_Cast.cpp
index ae08ea3944fd..668b85425d63 100644
--- a/src/Simplify_Cast.cpp
+++ b/src/Simplify_Cast.cpp
@@ -39,62 +39,40 @@ Expr Simplify::visit(const Cast *op, ExprInfo *info) {
         (f = as_const_float(value)) &&
         std::isfinite(*f)) {
         // float -> int
-        // Recursively call mutate just to set the bounds
-        return mutate(make_const(op->type, safe_numeric_cast<int64_t>(*f)), info);
+        return make_const(op->type, safe_numeric_cast<int64_t>(*f), info);
     } else if (op->type.is_uint() &&
                (f = as_const_float(value)) &&
                std::isfinite(*f)) {
         // float -> uint
-        // Recursively call mutate just to set the bounds
-        return mutate(make_const(op->type, safe_numeric_cast<uint64_t>(*f)), info);
+        return make_const(op->type, safe_numeric_cast<uint64_t>(*f), info);
     } else if (op->type.is_float() &&
                (f = as_const_float(value))) {
         // float -> float
-        return make_const(op->type, *f);
+        return make_const(op->type, *f, info);
     } else if (op->type.is_int() &&
                (i = as_const_int(value))) {
         // int -> int
-        // Recursively call mutate just to set the bounds
-        return mutate(make_const(op->type, *i), info);
+        return make_const(op->type, *i, info);
     } else if (op->type.is_uint() &&
                (i = as_const_int(value))) {
         // int -> uint
-        return make_const(op->type, safe_numeric_cast<uint64_t>(*i));
+        return make_const(op->type, safe_numeric_cast<uint64_t>(*i), info);
     } else if (op->type.is_float() &&
               (i = as_const_int(value))) {
         // int -> float
-        return mutate(make_const(op->type, safe_numeric_cast<double>(*i)), info);
+        return make_const(op->type, safe_numeric_cast<double>(*i), info);
     } else if (op->type.is_int() &&
-               (u = as_const_uint(value)) &&
-               op->type.bits() < value.type().bits()) {
-        // uint -> int narrowing
-        // Recursively call mutate just to set the bounds
-        return mutate(make_const(op->type, safe_numeric_cast<int64_t>(*u)), info);
-    } else if (op->type.is_int() &&
-               (u = as_const_uint(value)) &&
-               op->type.bits() == value.type().bits()) {
-        // uint -> int reinterpret
-        // Recursively call mutate just to set the bounds
-        return mutate(make_const(op->type, safe_numeric_cast<int64_t>(*u)), info);
-    } else if (op->type.is_int() &&
-               (u = as_const_uint(value)) &&
-               op->type.bits() > value.type().bits()) {
-        // uint -> int widening
-        if (op->type.can_represent(*u) || op->type.bits() < 32) {
-            // If the type can represent the value or overflow is well-defined.
-            // Recursively call mutate just to set the bounds
-            return mutate(make_const(op->type, safe_numeric_cast<int64_t>(*u)), info);
-        } else {
-            return make_signed_integer_overflow(op->type);
-        }
+               (u = as_const_uint(value))) {
+        // uint -> int.
+        return make_const(op->type, safe_numeric_cast<int64_t>(*u), info);
     } else if (op->type.is_uint() &&
                (u = as_const_uint(value))) {
         // uint -> uint
-        return mutate(make_const(op->type, *u), info);
+        return make_const(op->type, *u, info);
     } else if (op->type.is_float() &&
                (u = as_const_uint(value))) {
         // uint -> float
-        return make_const(op->type, safe_numeric_cast<double>(*u));
+        return make_const(op->type, safe_numeric_cast<double>(*u), info);
     } else if (cast &&
                op->type.code() == cast->type.code() &&
               op->type.bits() < cast->type.bits()) {
diff --git a/src/Simplify_Div.cpp b/src/Simplify_Div.cpp
index 92487eddecc2..7127e32a5183 100644
--- a/src/Simplify_Div.cpp
+++ b/src/Simplify_Div.cpp
@@ -24,7 +24,7 @@ Expr Simplify::visit(const Div *op, ExprInfo *info) {
     // also cases with a bounded denominator (e.g. [5, 7]/[4, 5] = 1).
     if (info->bounds.is_single_point()) {
         if (op->type.can_represent(info->bounds.min)) {
-            return make_const(op->type, info->bounds.min);
+            return make_const(op->type, info->bounds.min, nullptr);
         } else {
             // Even though this is 'no-overflow-int', if the result
             // we calculate can't fit into the destination type,
diff --git a/src/Simplify_EQ.cpp b/src/Simplify_EQ.cpp
index 97c32814e03d..212b01e0a1ba 100644
--- a/src/Simplify_EQ.cpp
+++ b/src/Simplify_EQ.cpp
@@ -4,10 +4,14 @@ namespace Halide {
 namespace Internal {
 
 Expr Simplify::visit(const EQ *op, ExprInfo *info) {
+    if (info) {
+        info->cast_to(op->type);
+    }
+
     if (truths.count(op)) {
-        return const_true(op->type.lanes());
+        return const_true(op->type.lanes(), info);
     } else if (falsehoods.count(op)) {
-        return const_false(op->type.lanes());
+        return const_false(op->type.lanes(), info);
     }
 
     if (!may_simplify(op->a.type())) {
@@ -26,14 +30,13 @@ Expr Simplify::visit(const EQ *op, ExprInfo *info) {
         if (should_commute(a, b)) {
             std::swap(a, b);
         }
-        const int lanes = op->type.lanes();
         auto rewrite = IRMatcher::rewriter(IRMatcher::eq(a, b), op->type);
         if (rewrite(x == 1, x)) {
             return rewrite.result;
         } else if (rewrite(x == 0, !x)) {
             return mutate(rewrite.result, info);
-        } else if (rewrite(x == x, const_true(lanes))) {
-            return rewrite.result;
+        } else if (rewrite(x == x, true)) {
+            return const_true(op->type.lanes(), info);
         } else if (a.same_as(op->a) && b.same_as(op->b)) {
             return op;
         } else {
@@ -47,17 +50,17 @@ Expr Simplify::visit(const EQ *op, ExprInfo *info) {
 
     // If the delta is 0, then it's just x == x
     if (is_const_zero(delta)) {
-        return const_true(lanes);
+        return const_true(lanes, info);
     }
 
     // Attempt to disprove using bounds analysis
     if (!delta_info.bounds.contains(0)) {
-        return const_false(lanes);
+        return const_false(lanes, info);
     }
 
     // Attempt to disprove using modulus remainder analysis
     if (delta_info.alignment.remainder != 0) {
-        return const_false(lanes);
+        return const_false(lanes, info);
     }
 
     auto rewrite = IRMatcher::rewriter(IRMatcher::eq(delta, 0), op->type, delta.type());
diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp
index 02f19ae13a6a..30c960e4313e 100644
--- a/src/Simplify_Exprs.cpp
+++ b/src/Simplify_Exprs.cpp
@@ -47,6 +47,7 @@ Expr Simplify::visit(const Broadcast *op, ExprInfo *info) {
     auto rewrite = IRMatcher::rewriter(IRMatcher::broadcast(value, lanes), op->type);
     if (rewrite(broadcast(broadcast(x, c0), lanes), broadcast(x, c0 * lanes)) ||
+        rewrite(broadcast(IRMatcher::Overflow(), lanes), IRMatcher::Overflow()) ||
         false) {
         return mutate(rewrite.result, info);
     }
@@ -215,7 +216,7 @@ Expr Simplify::visit(const Variable *op, ExprInfo *info) {
             *info = *b;
         }
         if (b->bounds.is_single_point()) {
-            return make_const(op->type, b->bounds.min);
+            return make_const(op->type, b->bounds.min, nullptr);
         }
     } else if (info && !no_overflow_int(op->type)) {
         info->bounds = ConstantInterval::bounds_of_type(op->type);
@@ -340,7 +341,7 @@ Expr Simplify::visit(const Load *op, ExprInfo *info) {
             Expr new_index = b_index->value;
             int new_lanes = new_index.type().lanes();
             Expr load = Load::make(op->type.with_lanes(new_lanes), op->name, b_index->value,
-                                   op->image, op->param, const_true(new_lanes), align);
+                                   op->image, op->param, const_true(new_lanes, nullptr), align);
             return Broadcast::make(load, b_index->lanes);
         } else if (s_index &&
                    is_const_one(predicate) &&
@@ -351,7 +352,7 @@ Expr Simplify::visit(const Load *op, ExprInfo *info) {
             for (const Expr &new_index : s_index->vectors) {
                 int new_lanes = new_index.type().lanes();
                 Expr load = Load::make(op->type.with_lanes(new_lanes), op->name, new_index,
-                                       op->image, op->param, const_true(new_lanes), ModulusRemainder{});
+                                       op->image, op->param, const_true(new_lanes, nullptr), ModulusRemainder{});
                 loaded_vecs.emplace_back(std::move(load));
             }
             return Shuffle::make(loaded_vecs, s_index->indices);
diff --git a/src/Simplify_Internal.h b/src/Simplify_Internal.h
index 851b5d05c810..2966c06ef3d5 100644
--- a/src/Simplify_Internal.h
+++ b/src/Simplify_Internal.h
@@ -80,6 +80,24 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
             }
         }
 
+        uint64_t largest_power_of_two_factor(uint64_t x) const {
+            // Consider the bits of x from MSB to LSB. Say there are three
+            // trailing zeros, and the four high bits are unknown:
+            // a b c d 1 0 0 0
+            // The largest power of two factor of a number is the trailing bits
+            // up to and including the first 1. In this example that's 1000
+            // (i.e. 8).
+            // Negating is flipping the bits and adding one. First we flip:
+            // ~a ~b ~c ~d 0 1 1 1
+            // Then we add one:
+            // ~a ~b ~c ~d 1 0 0 0
+            // If we bitwise and this with the original, the unknown bits cancel
+            // out, and we get left with just the largest power of two
+            // factor. If we want a mask of the trailing zeros instead, we can
+            // just subtract one.
+            return x & -x;
+        }
+
         void cast_to(Type t) {
             if ((!t.is_int() && !t.is_uint()) || (t.is_int() && t.bits() >= 32)) {
                 return;
@@ -96,10 +114,8 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
             // representable as any 64-bit integer type, so there's no
             // wraparound.
             if (alignment.modulus > 0) {
-                // This masks off all bits except for the lowest set one,
-                // giving the largest power-of-two factor of a number.
-                alignment.modulus &= -alignment.modulus;
-                alignment.remainder = mod_imp(alignment.remainder, alignment.modulus);
+                alignment.modulus = largest_power_of_two_factor(alignment.modulus);
+                alignment.remainder &= alignment.modulus - 1;
             }
         } else {
             // A narrowing integer cast that could possibly overflow adds
@@ -125,15 +141,144 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
             alignment = ModulusRemainder::intersect(alignment, other.alignment);
             trim_bounds_using_alignment();
         }
+
+        // An alternative representation for information about integers is that
+        // certain bits have known values in the 2s complement representation.
+        // This is a useful form for analyzing bitwise ops, so we provide
+        // conversions to and from that representation. For narrow types, this
+        // represents what the bits would be if they were sign- or
+        // zero-extended to 64 bits, so for uints the high bits are known to be
+        // zero, and for ints it depends on whether or not we knew the high bit
+        // to begin with.
+        struct BitsKnown {
+            // A mask which is 1 where we know the value of that bit
+            uint64_t mask;
+            // The actual value of the known bits
+            uint64_t value;
+
+            uint64_t known_zeros() const {
+                return mask & ~value;
+            }
+
+            uint64_t known_ones() const {
+                return mask & value;
+            }
+
+            bool all_bits_known() const {
+                return mask == (uint64_t)(-1);
+            }
+
+            BitsKnown operator&(const BitsKnown &other) const {
+                // Where either has known zeros, we have known zeros in the result
+                uint64_t zeros = known_zeros() | other.known_zeros();
+                // Where both have a known one, we have a known one in the result
+                uint64_t ones = known_ones() & other.known_ones();
+                return {zeros | ones, ones};
+            }
+
+            BitsKnown operator|(const BitsKnown &other) const {
+                // Where either has known ones, we have known ones in the result
+                uint64_t ones = known_ones() | other.known_ones();
+                // Where both have a known zero, we have a known zero in the result.
+                uint64_t zeros = known_zeros() & other.known_zeros();
+                return {zeros | ones, ones};
+            }
+
+            BitsKnown operator^(const BitsKnown &other) const {
+                // Unlike & and |, we need to know both bits to know anything.
+                uint64_t new_mask = mask & other.mask;
+                return {new_mask, (value ^ other.value) & new_mask};
+            }
+        };
+
+        BitsKnown to_bits_known(const Type &type) const;
+        void from_bits_known(BitsKnown known, const Type &type);
     };
 
     HALIDE_ALWAYS_INLINE
-    void clear_expr_info(ExprInfo *b) {
-        if (b) {
-            *b = ExprInfo{};
+    void clear_expr_info(ExprInfo *info) {
+        if (info) {
+            *info = ExprInfo{};
+        }
+    }
+
+    void set_expr_info_to_constant(ExprInfo *info, int64_t c) const {
+        if (info) {
+            info->bounds = ConstantInterval::single_point(c);
+            info->alignment = ModulusRemainder{0, c};
         }
     }
 
+    int64_t normalize_constant(const Type &t, int64_t c) {
+        // If this is supposed to be an int32, but the constant is not
+        // representable as an int32, we have a problem, because the Halide
+        // simplifier is unsound with respect to int32 overflow (see
+        // https://github.com/halide/Halide/issues/3245).
+
+        // It's tempting to just say we return a signed_integer_overflow, for
+        // which we know nothing, but if we're in this function we're making a
+        // constant, so we clearly decided not to do that in the caller. Is this
+        // a bug in the caller? No, this intentionally happens when
+        // constant-folding narrowing casts, and changing that behavior to
+        // return signed_integer_overflow breaks a bunch of real code, because
+        // unfortunately that's how people express wrapping casts to int32. We
+        // could return an ExprInfo that says "I know nothing", but we're also
+        // returning a constant Expr, so the next mutation is just going to
+        // infer everything there is to infer about a constant. The best we can
+        // do at this point is just wrap to the right number of bits.
+        if (t.is_int()) {
+            c <<= (64 - t.bits());
+            c >>= (64 - t.bits());
+        } else if (t.is_uint()) {
+            // For uints, normalization is considerably less problematic
+            c <<= (64 - t.bits());
+            c = (int64_t)(((uint64_t)c >> (64 - t.bits())));
+        }
+        return c;
+    }
+
+    // We never want to return make_const anything in the simplifier without
+    // also setting the ExprInfo, so shadow the global make_const.
+    Expr make_const(const Type &t, int64_t c, ExprInfo *info) {
+        c = normalize_constant(t, c);
+        set_expr_info_to_constant(info, c);
+        return Halide::Internal::make_const(t, c);
+    }
+
+    Expr make_const(const Type &t, uint64_t c, ExprInfo *info) {
+        c = normalize_constant(t, c);
+
+        if ((int64_t)c >= 0) {
+            set_expr_info_to_constant(info, (int64_t)c);
+        } else if (info) {
+            // If it's not representable as an int64, we can't express
+            // everything we know about it in ExprInfo.
+            // We can say that it's big:
+            info->bounds = ConstantInterval::bounded_below(INT64_MAX);
+            // And we can say what we know about the bottom 62 bits (2^62 is the
+            // largest power of two we can represent as an int64_t):
+            int64_t modulus = (int64_t)1 << 62;
+            info->alignment = {modulus, (int64_t)c & (modulus - 1)};
+        }
+        return Halide::Internal::make_const(t, c);
+    }
+
+    HALIDE_ALWAYS_INLINE
+    Expr make_const(const Type &t, double c, ExprInfo *info) {
+        // We don't currently track information about floats
+        return Halide::Internal::make_const(t, c);
+    }
+
+    HALIDE_ALWAYS_INLINE
+    Expr const_false(int lanes, ExprInfo *info) {
+        return make_const(UInt(1, lanes), (int64_t)0, info);
+    }
+
+    HALIDE_ALWAYS_INLINE
+    Expr const_true(int lanes, ExprInfo *info) {
+        return make_const(UInt(1, lanes), (int64_t)1, info);
+    }
+
 #if (LOG_EXPR_MUTATIONS || LOG_STMT_MUTATIONS)
     int debug_indent = 0;
 #endif
@@ -154,13 +299,23 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
             debug(1) << spaces << "Bounds: " << b->bounds << " " << b->alignment << "\n";
 
             if (auto i = as_const_int(new_e)) {
-                internal_assert(b->bounds.contains(*i)) << e << "\n"
-                                                        << new_e << "\n"
-                                                        << b->bounds;
+                internal_assert(b->bounds.contains(*i))
+                    << e << "\n"
+                    << new_e << "\n"
+                    << b->bounds;
             } else if (auto i = as_const_uint(new_e)) {
-                internal_assert(b->bounds.contains(*i)) << e << "\n"
-                                                        << new_e << "\n"
-                                                        << b->bounds;
+                internal_assert(b->bounds.contains(*i))
+                    << e << "\n"
+                    << new_e << "\n"
+                    << b->bounds;
+            }
+            if (new_e.type().is_uint() &&
+                new_e.type().bits() < 64 &&
+                !is_signed_integer_overflow(new_e)) {
+                internal_assert(b->bounds.min_defined && b->bounds.min >= 0)
+                    << e << "\n"
+                    << new_e << "\n"
+                    << b->bounds;
             }
         }
     }
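Two pieces of the header above are easy to verify in isolation: the x & -x
trick in largest_power_of_two_factor, and the mask + 1 rule that
from_bits_known uses to turn a run of known trailing bits into a
ModulusRemainder. A standalone sketch (the helper is re-declared here for
illustration, and the alignment pair is written out by hand; only those two
rules are taken from the patch):

```cpp
#include <cassert>
#include <cstdint>

// All bits of -x are the complement of x except the trailing run up to and
// including the lowest set bit, so the AND isolates exactly that bit.
uint64_t largest_power_of_two_factor(uint64_t x) {
    return x & -x;
}

int main() {
    assert(largest_power_of_two_factor(24) == 8); // 24 = 0b11000
    assert(largest_power_of_two_factor(7) == 1);
    assert(largest_power_of_two_factor(64) == 64);

    // from_bits_known's alignment rule: if the low bits are known, adding 1
    // to the mask and taking the largest power-of-two factor yields the
    // modulus. E.g. "the low three bits are known to be 0b101":
    uint64_t mask = 0b111, value = 0b101;
    uint64_t modulus = largest_power_of_two_factor(mask + 1);
    uint64_t remainder = value & (modulus - 1);
    assert(modulus == 8 && remainder == 5); // i.e. x % 8 == 5
    return 0;
}
```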
diff --git a/src/Simplify_LT.cpp b/src/Simplify_LT.cpp
index c9ac45c349d7..9e50412f325c 100644
--- a/src/Simplify_LT.cpp
+++ b/src/Simplify_LT.cpp
@@ -8,22 +8,26 @@ Expr Simplify::visit(const LT *op, ExprInfo *info) {
     Expr a = mutate(op->a, &a_info);
     Expr b = mutate(op->b, &b_info);
 
+    if (info) {
+        info->cast_to(op->type);
+    }
+
     const int lanes = op->type.lanes();
     Type ty = a.type();
 
     if (truths.count(op)) {
-        return const_true(lanes);
+        return const_true(lanes, info);
     } else if (falsehoods.count(op)) {
-        return const_false(lanes);
+        return const_false(lanes, info);
     }
 
     if (may_simplify(ty)) {
 
         // Prove or disprove using bounds analysis
         if (a_info.bounds < b_info.bounds) {
-            return const_true(lanes);
+            return const_true(lanes, info);
         } else if (a_info.bounds >= b_info.bounds) {
-            return const_false(lanes);
+            return const_false(lanes, info);
         }
 
         int lanes = op->type.lanes();
diff --git a/src/Simplify_Mod.cpp b/src/Simplify_Mod.cpp
index cc2fe2109a27..5f8ec085776b 100644
--- a/src/Simplify_Mod.cpp
+++ b/src/Simplify_Mod.cpp
@@ -29,7 +29,7 @@ Expr Simplify::visit(const Mod *op, ExprInfo *info) {
     }
 
     if (mod_info.bounds.is_single_point()) {
-        return make_const(op->type, mod_info.bounds.min);
+        return make_const(op->type, mod_info.bounds.min, nullptr);
     }
 
     int lanes = op->type.lanes();
diff --git a/src/Simplify_Not.cpp b/src/Simplify_Not.cpp
index 1c855ae82dd9..a5203c2b7300 100644
--- a/src/Simplify_Not.cpp
+++ b/src/Simplify_Not.cpp
@@ -4,7 +4,14 @@ namespace Halide {
 namespace Internal {
 
 Expr Simplify::visit(const Not *op, ExprInfo *info) {
-    Expr a = mutate(op->a, nullptr);
+    ExprInfo a_info;
+    Expr a = mutate(op->a, &a_info);
+
+    if (info) {
+        info->bounds = ConstantInterval::single_point(1) - a_info.bounds;
+        info->alignment = ModulusRemainder{0, 1} - a_info.alignment;
+        info->cast_to(op->type);
+    }
 
     auto rewrite = IRMatcher::rewriter(IRMatcher::not_op(a), op->type);
diff --git a/src/Simplify_Or.cpp b/src/Simplify_Or.cpp
index 083af6d5bc88..a45615ddf210 100644
--- a/src/Simplify_Or.cpp
+++ b/src/Simplify_Or.cpp
@@ -5,7 +5,7 @@ namespace Internal {
 
 Expr Simplify::visit(const Or *op, ExprInfo *info) {
     if (truths.count(op)) {
-        return const_true(op->type.lanes());
+        return const_true(op->type.lanes(), info);
     }
 
     Expr a = mutate(op->a, nullptr);
@@ -15,12 +15,43 @@ Expr Simplify::visit(const Or *op, ExprInfo *info) {
         std::swap(a, b);
     }
 
+    if (info) {
+        info->cast_to(op->type);
+    }
+
     auto rewrite = IRMatcher::rewriter(IRMatcher::or_op(a, b), op->type);
 
     // clang-format off
+
+    // Cases that fold to a constant
     if (EVAL_IN_LAMBDA
-        (rewrite(x || true, b) ||
-         rewrite(x || false, a) ||
+        (rewrite(x || true, true) ||
+         rewrite(x != y || x == y, true) ||
+         rewrite(x != y || y == x, true) ||
+         rewrite((z || x != y) || x == y, true) ||
+         rewrite((z || x != y) || y == x, true) ||
+         rewrite((x != y || z) || x == y, true) ||
+         rewrite((x != y || z) || y == x, true) ||
+         rewrite((z || x == y) || x != y, true) ||
+         rewrite((z || x == y) || y != x, true) ||
+         rewrite((x == y || z) || x != y, true) ||
+         rewrite((x == y || z) || y != x, true) ||
+         rewrite(x || !x, true) ||
+         rewrite(!x || x, true) ||
+         rewrite(y <= x || x < y, true) ||
+         rewrite(x <= c0 || c1 <= x, true, !is_float(x) && c1 <= c0 + 1) ||
+         rewrite(c1 <= x || x <= c0, true, !is_float(x) && c1 <= c0 + 1) ||
+         rewrite(x <= c0 || c1 < x, true, c1 <= c0) ||
+         rewrite(c1 <= x || x < c0, true, c1 <= c0) ||
+         rewrite(x < c0 || c1 < x, true, c1 < c0) ||
+         rewrite(c1 < x || x < c0, true, c1 < c0))) {
+        set_expr_info_to_constant(info, true);
+        return rewrite.result;
+    }
+
+    // Cases that fold to one of the args
+    if (EVAL_IN_LAMBDA
+        (rewrite(x || false, a) ||
          rewrite(x || x, a) ||
 
          rewrite((x || y) || x, a) ||
@@ -42,26 +73,7 @@ Expr Simplify::visit(const Or *op, ExprInfo *info) {
          rewrite((x && y) || y, b) ||
          rewrite(y || (x && y), a) ||
 
-         rewrite(x != y || x == y, true) ||
-         rewrite(x != y || y == x, true) ||
-         rewrite((z || x != y) || x == y, true) ||
-         rewrite((z || x != y) || y == x, true) ||
-         rewrite((x != y || z) || x == y, true) ||
-         rewrite((x != y || z) || y == x, true) ||
-         rewrite((z || x == y) || x != y, true) ||
-         rewrite((z || x == y) || y != x, true) ||
-         rewrite((x == y || z) || x != y, true) ||
-         rewrite((x == y || z) || y != x, true) ||
-         rewrite(x || !x, true) ||
-         rewrite(!x || x, true) ||
-         rewrite(y <= x || x < y, true) ||
          rewrite(x != c0 || x == c1, a, c0 != c1) ||
-         rewrite(x <= c0 || c1 <= x, true, !is_float(x) && c1 <= c0 + 1) ||
-         rewrite(c1 <= x || x <= c0, true, !is_float(x) && c1 <= c0 + 1) ||
-         rewrite(x <= c0 || c1 < x, true, c1 <= c0) ||
-         rewrite(c1 <= x || x < c0, true, c1 <= c0) ||
-         rewrite(x < c0 || c1 < x, true, c1 < c0) ||
-         rewrite(c1 < x || x < c0, true, c1 < c0) ||
          rewrite(c0 < x || c1 < x, fold(min(c0, c1)) < x) ||
          rewrite(c0 <= x || c1 <= x, fold(min(c0, c1)) <= x) ||
          rewrite(x < c0 || x < c1, x < fold(max(c0, c1))) ||
@@ -70,6 +82,7 @@ Expr Simplify::visit(const Or *op, ExprInfo *info) {
     }
     // clang-format on
 
+    // Cases that need remutation
     if (rewrite(broadcast(x, c0) || broadcast(y, c0), broadcast(x || y, c0)) ||
         rewrite((x && (y || z)) || y, (x && z) || y) ||
         rewrite((x && (z || y)) || y, (x && z) || y) ||
diff --git a/src/Simplify_Reinterpret.cpp b/src/Simplify_Reinterpret.cpp
index 259b7fb4f486..30aff3f96919 100644
--- a/src/Simplify_Reinterpret.cpp
+++ b/src/Simplify_Reinterpret.cpp
@@ -6,15 +6,22 @@ namespace Internal {
 Expr Simplify::visit(const Reinterpret *op, ExprInfo *info) {
     Expr a = mutate(op->value, nullptr);
 
+    if (info) {
+        // We don't track bounds and such through reinterprets, but we do know
+        // things about the result, just based on its type, e.g. if we're
+        // reinterpreting to a uint8, it's <= 255.
+        info->cast_to(op->type);
+    }
+
     bool vector = op->type.is_vector() || a.type().is_vector();
 
     if (op->type == a.type()) {
         return a;
     } else if (auto ia = as_const_int(a); ia && op->type.is_uint() && !vector) {
         // int -> uint
-        return make_const(op->type, reinterpret_bits<uint64_t>(*ia));
+        return make_const(op->type, reinterpret_bits<uint64_t>(*ia), info);
     } else if (auto ua = as_const_uint(a); ua && op->type.is_int() && !vector) {
         // uint -> int
-        return make_const(op->type, reinterpret_bits<int64_t>(*ua));
+        return make_const(op->type, reinterpret_bits<int64_t>(*ua), info);
     } else if (const Reinterpret *as_r = a.as<Reinterpret>()) {
         // Fold double-reinterprets.
         return mutate(reinterpret(op->type, as_r->value), info);
diff --git a/src/Simplify_Select.cpp b/src/Simplify_Select.cpp
index 1e778c229b95..6dada3cf0dab 100644
--- a/src/Simplify_Select.cpp
+++ b/src/Simplify_Select.cpp
@@ -21,18 +21,25 @@ Expr Simplify::visit(const Select *op, ExprInfo *info) {
 
     // clang-format off
     if (EVAL_IN_LAMBDA
-        (rewrite(select(IRMatcher::likely(true), x, y), x) ||
-         rewrite(select(IRMatcher::likely(false), x, y), y) ||
-         rewrite(select(IRMatcher::likely_if_innermost(true), x, y), x) ||
-         rewrite(select(IRMatcher::likely_if_innermost(false), x, y), y) ||
-         rewrite(select(1, x, y), x) ||
-         rewrite(select(0, x, y), y) ||
-         rewrite(select(x, y, y), y) ||
+        (rewrite(select(IRMatcher::likely(true), x, y), true_value) ||
+         rewrite(select(IRMatcher::likely(false), x, y), false_value) ||
+         rewrite(select(IRMatcher::likely_if_innermost(true), x, y), true_value) ||
+         rewrite(select(IRMatcher::likely_if_innermost(false), x, y), false_value) ||
+         rewrite(select(1, x, y), true_value) ||
+         rewrite(select(0, x, y), false_value) ||
+         rewrite(select(x, y, y), false_value) ||
          rewrite(select(x, likely(y), y), false_value) ||
          rewrite(select(x, y, likely(y)), true_value) ||
          rewrite(select(x, likely_if_innermost(y), y), false_value) ||
          rewrite(select(x, y, likely_if_innermost(y)), true_value) ||
          false)) {
+        if (info) {
+            if (rewrite.result.same_as(true_value)) {
+                *info = t_info;
+            } else if (rewrite.result.same_as(false_value)) {
+                *info = f_info;
+            }
+        }
         return rewrite.result;
     }
     // clang-format on
diff --git a/src/Simplify_Shuffle.cpp b/src/Simplify_Shuffle.cpp
index 348289ab0c83..aecb4c6fc99a 100644
--- a/src/Simplify_Shuffle.cpp
+++ b/src/Simplify_Shuffle.cpp
@@ -78,7 +78,7 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) {
 
         Expr shuffled_predicate;
         if (unpredicated) {
-            shuffled_predicate = const_true(t.lanes());
+            shuffled_predicate = const_true(t.lanes(), nullptr);
         } else {
             shuffled_predicate = Shuffle::make(load_predicates, op->indices);
             shuffled_predicate = mutate(shuffled_predicate, nullptr);
diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp
index 9f0a5ace1158..cd2c440de6ba 100644
--- a/src/Simplify_Stmts.cpp
+++ b/src/Simplify_Stmts.cpp
@@ -352,7 +352,7 @@ Stmt Simplify::visit(const Store *op) {
         return Evaluate::make(0);
     } else if (scalar_pred && !is_const_one(scalar_pred->value)) {
         return IfThenElse::make(scalar_pred->value,
-                                Store::make(op->name, value, index, op->param, const_true(value.type().lanes()), align));
+                                Store::make(op->name, value, index, op->param, const_true(value.type().lanes(), nullptr), align));
     } else if (is_undef(value) || (load && load->name == op->name && equal(load->index, index))) {
         // foo[x] = foo[x] or foo[x] = undef is a no-op
         return Evaluate::make(0);
diff --git a/src/runtime/HalideBuffer.h b/src/runtime/HalideBuffer.h
index 9a97b79c3591..b5e613690702 100644
--- a/src/runtime/HalideBuffer.h
+++ b/src/runtime/HalideBuffer.h
@@ -971,7 +971,7 @@ class Buffer {
     /** Allocate a new image of the given size with a runtime
      * type. Only used when you do know what size you want but you
-     * don't know statically what type the elements are. Pass zeroes
+     * don't know statically what type the elements are. Pass zeros
      * to make a buffer suitable for bounds query calls. */
     template<typename... Args,
              typename = typename std::enable_if<AllInts<Args...>::value>::type>
@@ -990,7 +990,7 @@ class Buffer {
         }
     }
 
-    /** Allocate a new image of the given size. Pass zeroes to make a
+    /** Allocate a new image of the given size. Pass zeros to make a
      * buffer suitable for bounds query calls. */
     // @{
diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt
index 291af444cfd3..d4bd5278bfc9 100644
--- a/test/correctness/CMakeLists.txt
+++ b/test/correctness/CMakeLists.txt
@@ -12,6 +12,7 @@ tests(GROUPS correctness
       autodiff.cpp
       bad_likely.cpp
       bit_counting.cpp
+      bits_known.cpp
       bitwise_ops.cpp
       bool_compute_root_vectorize.cpp
       bool_predicate_cast.cpp
diff --git a/test/correctness/bits_known.cpp b/test/correctness/bits_known.cpp
new file mode 100644
index 000000000000..7be9a9d843bc
--- /dev/null
+++ b/test/correctness/bits_known.cpp
@@ -0,0 +1,116 @@
+#include "Halide.h"
+
+using namespace Halide;
+using namespace Halide::Internal;
+
+int main(int argc, char **argv) {
+
+    Param<int64_t> i64("i64");
+    Param<int32_t> i32("i32");
+    Param<int16_t> i16("i16");
+    Param<uint64_t> u64("u64");
+    Param<uint32_t> u32("u32");
+    Param<uint16_t> u16("u16");
+    Param<uint8_t> u8("u8");
+
+    // A list of Exprs we should be able to prove true by analyzing the bitwise op we do
+    Expr exprs[] = {
+        // Manipulate or isolate the low bits
+        (i64 & 1) < 2,
+        (i64 & 1) >= 0,
+        (i64 | 1) % 2 == 1,
+        (i64 & 2) <= 2,
+        (i64 & 2) >= 0,
+
+        (min(i32, -1) ^ (i32 & 255)) < 0,
+
+        // The next one is currently beyond us, because we'd have to carry expr
+        // information in the bits_known format through the modulus
+        // op. Currently, knowing just that the second-lowest bit is set, but
+        // nothing else, doesn't give us an alignment or bounds.
+        // (i64 | 2) % 4 >= 2,
+
+        (u64 & 1) < 2,
+        (u64 & 1) >= 0,
+        (u64 | 1) % 2 == 1,
+        (u64 & 2) <= 2,
+        (u64 & 2) >= 0,
+        // Beyond us for the same reason as above
+        // (u64 | 2) % 4 >= 2,
+
+        // Manipulate or isolate the high bits, in various types, starting with
+        // two common idioms for aligning a value to a multiple of 16.
+        ((i32 & ~15) & 15) == 0,
+        ((i32 & ~15) % 16) == 0,
+        ((i32 & cast<int32_t>(u16 << 2)) | 5) % 8 == 5,
+        (i32 | 0x80000000) < 0,
+        cast<int32_t>(u32 & ~0x80000000) >= 0,
+        (cast<int32_t>(u16) & (cast<int32_t>(u16) << 16)) == 0,
+
+        // Setting or unsetting bits makes a number larger or smaller, respectively
+        (i32 & cast<int32_t>(u16)) >= 0,
+        (i32 & cast<int32_t>(u16)) < 0x10000,
+
+        // What happens when the known bits say a uint is too big to represent
+        // in our bounds? Not currently reachable, because the (intentional)
+        // overflow on the cast to uint causes ConstantInterval to just drop all
+        // information.
+        // (cast<uint32_t>(i64 | -2)) > u32
+
+        // Flipping the bits of an int flips it without overflow. I.e. for a
+        // uint8, ~x is 255 - x. This gives us bounds information.
+        (~clamp(u8, 3, 5)) >= 255 - 5,
+        (~clamp(u8, 3, 5)) <= 255 - 3,
+
+        // If we knew the trailing bits before, we still know them after
+        (~(i32 * 16)) % 16 == 15,
+
+    };
+
+    // Check we're not inferring *too* much, with variants of the above that
+    // shouldn't be provable one way or the other.
+    Expr negative_exprs[] = {
+        (i64 & 3) < 2,
+        (i64 & 3) >= 1,
+        (i64 | 1) % 4 == 1,
+        (i64 & 3) <= 2,
+        (i64 & 3) >= 1,
+
+        (max(u32, 1000) ^ (u64 & 255)) >= 1000,
+
+        (u64 & 3) < 2,
+        (u64 & 3) >= 1,
+        (u64 | 1) % 4 == 1,
+        (u64 & 3) <= 2,
+        (u64 & 2) >= 1,
+
+        ((i32 & ~15) & 31) == 0,
+        ((i32 & ~15) % 32) == 0,
+        ((i32 & cast<int32_t>(u16 << 1)) | 5) % 8 == 5,
+        (i32 | 0x80000000) < -1,
+        cast<int32_t>(u32 & ~0x80000000) >= 0,
+        (cast<int32_t>(u16) & (cast<int32_t>(u16) << 15)) == 0,
+
+        (i32 & cast<int32_t>(u16)) >= 1,
+        (i32 & cast<int32_t>(u16)) < 0xffff,
+
+        (~clamp(u8, 3, 5)) >= 255 - 4,
+    };
+
+    for (auto e : exprs) {
+        if (!can_prove(e)) {
+            std::cerr << "Failed to prove: " << e << "\n";
+            return -1;
+        }
+    }
+
+    for (auto e : negative_exprs) {
+        if (is_const(simplify(e))) {
+            std::cerr << "Should not have been able to prove or disprove: " << e << "\n";
+            return -1;
+        }
+    }
+
+    printf("Success!\n");
+    return 0;
+}
diff --git a/test/correctness/extern_producer.cpp b/test/correctness/extern_producer.cpp
index be6a47fd8fe4..a1b461493bdf 100644
--- a/test/correctness/extern_producer.cpp
+++ b/test/correctness/extern_producer.cpp
@@ -114,7 +114,7 @@ int main(int argc, char **argv) {
 
     Buffer<float> output = sink.realize({100, 100});
 
-    // Should be all zeroes.
+    // Should be all zeros.
     RDom r(output);
     unsigned int error = evaluate_may_gpu<float>(sum(abs(output(r.x, r.y))));
     if (error != 0) {
@@ -142,7 +142,7 @@ int main(int argc, char **argv) {
 
     Buffer<float> output_multi = sink_multi.realize({100, 100});
 
-    // Should be all zeroes.
+    // Should be all zeros.
     RDom r(output_multi);
     unsigned int error_multi = evaluate<float>(sum(abs(output_multi(r.x, r.y))));
     if (error_multi != 0) {
diff --git a/test/correctness/extern_stage_on_device.cpp b/test/correctness/extern_stage_on_device.cpp
index b4c4a7ff64ae..aeff41d64c2d 100644
--- a/test/correctness/extern_stage_on_device.cpp
+++ b/test/correctness/extern_stage_on_device.cpp
@@ -76,7 +76,7 @@ int main(int argc, char **argv) {
 
     Buffer<int32_t> output = sink.realize({100, 100});
 
-    // Should be all zeroes.
+    // Should be all zeros.
     RDom r(output);
     uint32_t error = evaluate_may_gpu<uint32_t>(sum(abs(output(r.x, r.y))));
     if (error != 0) {
diff --git a/test/correctness/fuzz_simplify.cpp b/test/correctness/fuzz_simplify.cpp
index 23987cb7cbd9..a9ed27aba7ea 100644
--- a/test/correctness/fuzz_simplify.cpp
+++ b/test/correctness/fuzz_simplify.cpp
@@ -119,6 +119,23 @@ Expr make_absd(Expr a, Expr b) {
     return cast(a.type(), absd(a, b));
 }
 
+Expr make_bitwise_or(Expr a, Expr b) {
+    return a | b;
+}
+
+Expr make_bitwise_and(Expr a, Expr b) {
+    return a & b;
+}
+
+Expr make_bitwise_xor(Expr a, Expr b) {
+    return a ^ b;
+}
+
+// This just exists to make sure bitwise not gets used somewhere
+Expr make_bitwise_nor(Expr a, Expr b) {
+    return ~a | ~b;
+}
+
 Expr random_expr(std::mt19937 &rng, Type t, int depth, bool overflow_undef) {
     if (t.is_int() && t.bits() == 32) {
         overflow_undef = true;
@@ -193,6 +210,10 @@ Expr random_expr(std::mt19937 &rng, Type t, int depth, bool overflow_undef) {
         Div::make,
         Mod::make,
         make_absd,
+        make_bitwise_or,
+        make_bitwise_and,
+        make_bitwise_xor,
+        make_bitwise_nor,
     };
 
     Expr a = random_expr(rng, t, depth, overflow_undef);
diff --git a/test/correctness/lossless_cast.cpp b/test/correctness/lossless_cast.cpp
index 22d3506d7859..d22489b16c6c 100644
--- a/test/correctness/lossless_cast.cpp
+++ b/test/correctness/lossless_cast.cpp
@@ -140,7 +140,7 @@ Expr random_expr(std::mt19937 &rng) {
             e = common_subexpression_elimination(e1);
             break;
         case 7:
-            switch (rng() % 19) {
+            switch (rng() % 20) {
             case 0:
                 if (may_widen) {
                     e = widening_add(e1, e2);
@@ -214,6 +214,9 @@ Expr random_expr(std::mt19937 &rng) {
             case 18:
                 e = rounding_shift_left(e1, e2);
                 break;
+            case 19:
+                e = ~e1;
+                break;
             }
         }
diff --git a/test/correctness/mul_div_mod.cpp b/test/correctness/mul_div_mod.cpp
index 8eca8141bba2..0c51c61fbf21 100644
--- a/test/correctness/mul_div_mod.cpp
+++ b/test/correctness/mul_div_mod.cpp
@@ -357,7 +357,7 @@ bool div_mod(int vector_width, ScheduleVariant scheduling, const Target &target)
     Buffer<T> b = init<T, BIG>(t, 2, WIDTH, HEIGHT);
 
     // Filter the input values for the operation to be tested.
-    // Cannot divide by zero, so remove zeroes from b.
+    // Cannot divide by zero, so remove zeros from b.
     // Also, cannot divide the most negative number by -1.
     for (i = 0; i < WIDTH; i++) {
         for (j = 0; j < HEIGHT; j++) {
@@ -462,7 +462,7 @@ bool f_mod() {
     Buffer<T> out(WIDTH, HEIGHT);
 
     // Filter the input values for the operation to be tested.
-    // Cannot divide by zero, so remove zeroes from b.
+    // Cannot divide by zero, so remove zeros from b.
     for (i = 0; i < WIDTH; i++) {
         for (j = 0; j < HEIGHT; j++) {
             if (b(i, j) == 0.0) {
diff --git a/test/generator/autograd_aottest.cpp b/test/generator/autograd_aottest.cpp
index b90616964dc8..2ad98ebe74ff 100644
--- a/test/generator/autograd_aottest.cpp
+++ b/test/generator/autograd_aottest.cpp
@@ -77,7 +77,7 @@ int main(int argc, char **argv) {
         _grad_loss_output_lut_wrt_lut_indices
 
         Note that the outputs with "_dummy" prefixes are placeholder
-        outputs that are always filled with zeroes; in those cases,
+        outputs that are always filled with zeros; in those cases,
         there is no derivative for the output/input pairing, but we
         emit an output nevertheless so that the function signature is
         always mechanically predictable from the list of inputs and
         outputs.
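The end-to-end effect of the patch is visible from ordinary user code:
bit-level facts become provable by the simplifier. A minimal usage sketch,
assuming a Halide build; can_prove (from Halide::Internal), Param, and cast
are the real APIs, and the expressions are borrowed from the positive list in
the new bits_known.cpp test:

```cpp
#include "Halide.h"
#include <cassert>
#include <cstdio>

using namespace Halide;
using namespace Halide::Internal;

int main() {
    Param<int32_t> x("x");

    // Masking off all but the low bit bounds the result: (x & 1) is 0 or 1.
    assert(can_prove((x & 1) >= 0) && can_prove((x & 1) < 2));

    // Clearing the low four bits aligns the value to a multiple of 16.
    assert(can_prove(((x & ~15) % 16) == 0));

    // Setting the sign bit forces the int32 negative.
    assert(can_prove((x | 0x80000000) < 0));

    printf("All proven.\n");
    return 0;
}
```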