Bounds and alignment analysis through bitwise ops #8574

Open
wants to merge 10 commits into base: main
4 changes: 4 additions & 0 deletions src/ConstantBounds.cpp
@@ -125,6 +125,10 @@ ConstantInterval bounds_helper(const Expr &e,
ConstantInterval cq = recurse(op->args[2]);
ConstantInterval rounding_term = 1 << (cq - 1);
return (ca * cb + rounding_term) >> cq;
} else if (op->is_intrinsic(Call::bitwise_not)) {
// We can't do much with the other bitwise ops, but we can treat
// bitwise_not as an all-ones bit pattern minus the argument.
return recurse(make_const(e.type(), -1) - op->args[0]);
}
// If you add a new intrinsic here, also add it to the expression
// generator in test/correctness/lossless_cast.cpp
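
The bitwise_not case in this hunk relies on the two's complement identity ~x == -1 - x. Below is a minimal standalone sketch of that identity and of the bounds it implies. `Range` and `bounds_of_not` are hypothetical names used only for illustration; they are not Halide's ConstantInterval API.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a closed constant interval [min, max].
struct Range {
    int64_t min, max;
};

// Since ~x == -1 - x, the bounds of ~x are the negated bounds of x shifted
// by -1. Note that min and max swap roles.
constexpr Range bounds_of_not(Range x) {
    return {-1 - x.max, -1 - x.min};
}

// Check the identity itself for every 8-bit signed value.
inline bool not_identity_holds() {
    for (int v = -128; v <= 127; v++) {
        int8_t x = static_cast<int8_t>(v);
        if (static_cast<int8_t>(~x) != static_cast<int8_t>(-1 - x)) {
            return false;
        }
    }
    return true;
}

inline void check_bitwise_not() {
    assert(not_identity_holds());
    Range r = bounds_of_not({0, 100});
    assert(r.min == -101 && r.max == -1);
}
```

The sketch ignores overflow at the extremes of int64, which a real interval class would have to handle.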
109 changes: 109 additions & 0 deletions src/Simplify.cpp
@@ -496,5 +496,114 @@ bool can_prove(Expr e, const Scope<Interval> &bounds) {
return is_const_one(e);
}

Simplify::ExprInfo::BitsKnown Simplify::ExprInfo::to_bits_known(const Type &type) const {
Member

A lot of this bitwise math is a little tricky to follow. Have you run this through an SMT solver? I think this should be verified.

BitsKnown result = {0, 0};

if (!(type.is_int() || type.is_uint())) {
// Let's not claim we know anything about the bit patterns of
// non-integer types for now.
return result;
}

// Identify the largest power of two in the modulus to get some low bits
if (alignment.modulus) {
result.mask = largest_power_of_two_factor(alignment.modulus) - 1;
result.value = result.mask & alignment.remainder;
} else {
// This value is just a constant
result.mask = (uint64_t)(-1);
result.value = alignment.remainder;
return result;
}

// The bounds and the type tell us a bunch of high bits are zero or one
if (bounds >= 0) {
if (type.is_int()) {
// The sign bit and above are zero.
result.mask |= (uint64_t)(-1) << (type.bits() - 1);
} else if (type.bits() < 64) {
// Narrow uints are zero-extended.
result.mask |= (uint64_t)(-1) << type.bits();
Member

Isn't this true regardless of the bounds?

}
if (bounds.max_defined) {
// It's positive and the max is representable as an int64, so at least
// one high bit is zero, but bounds.max isn't zero or we would have
// returned above.
result.mask |= (uint64_t)(-1) << (64 - clz64(bounds.max));
}
} else if (bounds < 0) {
// At least one high bit is one, but bounds.min isn't -1 or we would
// have returned above.
uint64_t high_bits = (uint64_t)(-1) << (type.bits() - 1);
if (bounds.min_defined) {
high_bits |= (uint64_t)(-1) << (64 - clz64(~bounds.min));
}
result.mask |= high_bits;
result.value |= high_bits;
}

return result;
}
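
As a worked instance of the high-bit reasoning in to_bits_known: a value in [0, 100] needs 7 bits, so bits 7 and above are known zeros, and a value in [-100, -1] has bits 7 and above known to be ones. The sketch below assumes a portable `clz64_sketch` in place of the `clz64` helper used in the diff; all function names here are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Portable count-leading-zeros, a stand-in for clz64. Returns 64 for x == 0.
inline int clz64_sketch(uint64_t x) {
    int n = 0;
    for (uint64_t bit = uint64_t(1) << 63; bit != 0 && (x & bit) == 0; bit >>= 1) {
        n++;
    }
    return n;
}

// Mask of high bits known to be zero for any value in [0, max], max > 0.
inline uint64_t known_zero_high_bits(int64_t max) {
    int used = 64 - clz64_sketch(static_cast<uint64_t>(max));
    return ~uint64_t(0) << used;
}

// Mask of high bits known to be one for any value in [min, -1], min < -1.
inline uint64_t known_one_high_bits(int64_t min) {
    int used = 64 - clz64_sketch(~static_cast<uint64_t>(min));
    return ~uint64_t(0) << used;
}

inline void check_high_bits() {
    const uint64_t zeros = known_zero_high_bits(100);
    const uint64_t ones = known_one_high_bits(-100);
    assert(zeros == 0xFFFFFFFFFFFFFF80ULL);
    assert(ones == 0xFFFFFFFFFFFFFF80ULL);
    // Every value in range agrees with the claimed bits.
    for (int64_t v = 0; v <= 100; v++) {
        assert((static_cast<uint64_t>(v) & zeros) == 0);
    }
    for (int64_t v = -100; v <= -1; v++) {
        assert((static_cast<uint64_t>(v) & ones) == ones);
    }
}
```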

void Simplify::ExprInfo::from_bits_known(Simplify::ExprInfo::BitsKnown known, const Type &type) {
// Normalize everything to 64-bits by sign- or zero-extending known bits for
// the type.
uint64_t missing_bits = 0;
if (type.bits() < 64) {
missing_bits = (uint64_t)(-1) << type.bits();
}
if (missing_bits) {
if (type.is_uint()) {
// For a uint the high bits are zero
known.mask |= missing_bits;
known.value &= ~missing_bits;
} else if (type.is_int()) {
// For an int we need to know the sign to know the high bits
bool sign_bit_known = (known.mask >> (type.bits() - 1)) & 1;
bool negative = (known.value >> (type.bits() - 1)) & 1;
if (!sign_bit_known) {
Member

These cases need comments explaining each of them.

known.mask &= ~missing_bits;
known.value &= ~missing_bits;
} else if (negative) {
known.mask |= missing_bits;
known.value |= missing_bits;
} else if (!negative) {
known.mask |= missing_bits;
known.value &= ~missing_bits;
}
}
}

// We can get the trailing one bits by adding one and taking the largest
// power of two factor. Note that this works out correctly when we know all
// the bits - the modulus comes out as zero, and the remainder is the entire
// number, which is how we represent constants in ModulusRemainder.
alignment.modulus = largest_power_of_two_factor(known.mask + 1);
alignment.remainder = known.value & (alignment.modulus - 1);

if ((int64_t)known.mask < 0) {
// We know some leading bits

// Set all unknown bits to zero
uint64_t min_val = known.value & known.mask;
// Set all unknown bits to one
uint64_t max_val = known.value | ~known.mask;

if (type.is_uint() && (int64_t)known.value < 0) {
// We know it's out of range at the top end for our ConstantInterval
// class. At the time of writing, to_bits_known can't produce this
// directly, and bits_known is never propagated through other
// operations, so this code is unreachable. Nonetheless we'll do the
// best job we can at representing this case in case this code
// becomes reachable in future.
bounds = ConstantInterval::bounded_below((1ULL << 63) - 1);
} else {
// In all other cases, the bounds are representable as an int64
// and don't span zero (because we know the high bit).
bounds = ConstantInterval{(int64_t)min_val, (int64_t)max_val};
}
}
}
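
The two conversions at the end of from_bits_known can be illustrated in isolation. This is a sketch under assumptions: `Known` is a hypothetical mirror of BitsKnown, and the helper names are invented for the example. With the low three bits known to be 101 and bits 8 and above known to be zero, the value is congruent to 5 modulo 8 and lies in [5, 253].

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical mirror of BitsKnown: mask marks known bits, value holds them.
struct Known {
    uint64_t mask, value;
};

constexpr uint64_t lowest_set_bit(uint64_t x) {
    return x & (~x + 1);
}

// Trailing known bits give a power-of-two modulus. When every bit is known,
// mask + 1 wraps to zero and the modulus comes out as zero (a constant).
constexpr uint64_t modulus_of(Known k) {
    return lowest_set_bit(k.mask + 1);
}
constexpr uint64_t remainder_of(Known k) {
    return k.value & (modulus_of(k) - 1);
}
// Unknown bits all zero gives the smallest value, all one the largest.
constexpr uint64_t min_of(Known k) {
    return k.value & k.mask;
}
constexpr uint64_t max_of(Known k) {
    return k.value | ~k.mask;
}

inline void check_from_bits() {
    const Known k{0xFFFFFFFFFFFFFF07ULL, 0x5};
    assert(modulus_of(k) == 8 && remainder_of(k) == 5);
    assert(min_of(k) == 5 && max_of(k) == 253);
    const Known constant{~uint64_t(0), 42};
    assert(modulus_of(constant) == 0 && remainder_of(constant) == 42);
}
```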

} // namespace Internal
} // namespace Halide
49 changes: 42 additions & 7 deletions src/Simplify_Call.cpp
@@ -186,14 +186,23 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
return Call::make(op->type, result_op, {a, b}, Call::PureIntrinsic);
}
} else if (op->is_intrinsic(Call::bitwise_and)) {
- Expr a = mutate(op->args[0], nullptr);
- Expr b = mutate(op->args[1], nullptr);
+ ExprInfo a_info, b_info;
+ Expr a = mutate(op->args[0], &a_info);
+ Expr b = mutate(op->args[1], &b_info);

Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type);
if (unbroadcast.defined()) {
return mutate(unbroadcast, info);
}

if (info && (op->type.is_int() || op->type.is_uint())) {
auto bits_known = a_info.to_bits_known(op->type) & b_info.to_bits_known(op->type);
info->from_bits_known(bits_known, op->type);
if (bits_known.mask == (uint64_t)-1) {
Member

Add a comment saying something like "This is a constant". Or give ExprInfo::BitsKnown a helper function like `std::optional<uint64_t> as_constant()`.

Member

And the same below with bitwise_or.

return make_const(op->type, bits_known.value);
}
}
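
One possible shape for the helper suggested in the review comments above, offered as a sketch only. `BitsKnownSketch` is a hypothetical stand-in for ExprInfo::BitsKnown, and `as_constant()` is not part of the diff as shown. It assumes C++17 for std::optional.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

struct BitsKnownSketch {
    // 1 where the value of that bit is known.
    uint64_t mask;
    // The values of the known bits.
    uint64_t value;

    // If every bit is known, the expression is a constant: return it.
    std::optional<uint64_t> as_constant() const {
        if (mask == ~uint64_t(0)) {
            return value;
        }
        return std::nullopt;
    }
};

inline void check_as_constant() {
    const BitsKnownSketch all_known{~uint64_t(0), 42};
    const BitsKnownSketch low_byte_only{0xFF, 42};
    assert(all_known.as_constant().has_value());
    assert(*all_known.as_constant() == 42);
    assert(!low_byte_only.as_constant().has_value());
}
```

With a helper like this, the `bits_known.mask == (uint64_t)-1` tests in the bitwise_and and bitwise_or cases could become `if (auto c = bits_known.as_constant())`, which documents the intent at the call site.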

auto ia = as_const_int(a), ib = as_const_int(b);
auto ua = as_const_uint(a), ub = as_const_uint(b);

@@ -218,14 +227,23 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
return a & b;
}
} else if (op->is_intrinsic(Call::bitwise_or)) {
- Expr a = mutate(op->args[0], nullptr);
- Expr b = mutate(op->args[1], nullptr);
+ ExprInfo a_info, b_info;
+ Expr a = mutate(op->args[0], &a_info);
+ Expr b = mutate(op->args[1], &b_info);

Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type);
if (unbroadcast.defined()) {
return mutate(unbroadcast, info);
}

if (info && (op->type.is_int() || op->type.is_uint())) {
auto bits_known = a_info.to_bits_known(op->type) | b_info.to_bits_known(op->type);
info->from_bits_known(bits_known, op->type);
if (bits_known.mask == (uint64_t)-1) {
return make_const(op->type, bits_known.value);
}
}

auto ia = as_const_int(a), ib = as_const_int(b);
auto ua = as_const_uint(a), ub = as_const_uint(b);
if (ia && ib) {
@@ -238,13 +256,24 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
return a | b;
}
} else if (op->is_intrinsic(Call::bitwise_not)) {
- Expr a = mutate(op->args[0], nullptr);
+ ExprInfo a_info;
+ Expr a = mutate(op->args[0], &a_info);
Member

Why not use the bits_known mask here in the same way that you do for and/or? Negated constants are constants.

Member Author

I was, but then I changed it to not actually compute BitsKnown here.


Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a}, op->call_type);
if (unbroadcast.defined()) {
return mutate(unbroadcast, info);
}

if (info && (op->type.is_int() || op->type.is_uint())) {
// For the purpose of bounds and alignment, ~x can be treated as an
// all-ones bit pattern minus x.
Expr e = mutate(make_const(op->type, -1) - op->args[0], info);
// If the result of this happens to be a constant, we can also just return it
if (info->bounds.is_single_point()) {
return e;
}
}

if (auto ia = as_const_int(a)) {
return make_const(op->type, ~(*ia));
} else if (auto ua = as_const_uint(a)) {
@@ -255,14 +284,20 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
return ~a;
}
} else if (op->is_intrinsic(Call::bitwise_xor)) {
- Expr a = mutate(op->args[0], nullptr);
- Expr b = mutate(op->args[1], nullptr);
+ ExprInfo a_info, b_info;
+ Expr a = mutate(op->args[0], &a_info);
+ Expr b = mutate(op->args[1], &b_info);

Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type);
if (unbroadcast.defined()) {
return mutate(unbroadcast, info);
}

if (info && (op->type.is_int() || op->type.is_uint())) {
auto bits_known = a_info.to_bits_known(op->type) ^ b_info.to_bits_known(op->type);
info->from_bits_known(bits_known, op->type);
Member

Same here: check for a constant?

}

auto ia = as_const_int(a), ib = as_const_int(b);
auto ua = as_const_uint(a), ub = as_const_uint(b);
if (ia && ib) {
70 changes: 66 additions & 4 deletions src/Simplify_Internal.h
@@ -80,6 +80,24 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
}
}

uint64_t largest_power_of_two_factor(uint64_t x) const {
// Consider the bits of x from MSB to LSB. Say there are three
// trailing zeros, and the four high bits are unknown:
// a b c d 1 0 0 0
// The largest power of two factor of a number is the trailing bits
// up to and including the first 1. In this example that's 1000
// (i.e. 8).
// Negating is flipping the bits and adding one. First we flip:
// ~a ~b ~c ~d 0 1 1 1
// Then we add one:
// ~a ~b ~c ~d 1 0 0 0
// If we bitwise and this with the original, the unknown bits cancel
// out, and we get left with just the largest power of two
// factor. If we want a mask of the trailing zeros instead, we can
// just subtract one.
return x & -x;
}
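
The comment's argument can be checked mechanically. The sketch below writes the negation out as "flip, then add one" to match the explanation, and compares it against a slow reference that divides out factors of two. The function names are hypothetical and exist only for this example.

```cpp
#include <cassert>
#include <cstdint>

// x & -x, with the negation spelled out as flip-then-add-one.
constexpr uint64_t lowest_set_bit(uint64_t x) {
    return x & (~x + 1);
}

// Slow reference: repeatedly divide out factors of two.
inline uint64_t largest_pow2_factor_slow(uint64_t x) {
    if (x == 0) {
        return 0;
    }
    uint64_t f = 1;
    while ((x & 1) == 0) {
        x >>= 1;
        f <<= 1;
    }
    return f;
}

inline void check_lowest_set_bit() {
    for (uint64_t x = 0; x < 4096; x++) {
        assert(lowest_set_bit(x) == largest_pow2_factor_slow(x));
    }
    // 24 is 11000 in binary: the factor is 8, and the trailing-zero mask is 7.
    assert(lowest_set_bit(24) == 8);
    assert(lowest_set_bit(24) - 1 == 7);
    assert(lowest_set_bit(uint64_t(1) << 63) == uint64_t(1) << 63);
}
```

Note that an input of zero yields zero, which is what lets from_bits_known represent a fully known value as modulus zero.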

void cast_to(Type t) {
if ((!t.is_int() && !t.is_uint()) || (t.is_int() && t.bits() >= 32)) {
return;
@@ -96,10 +114,8 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
// representable as any 64-bit integer type, so there's no
// wraparound.
if (alignment.modulus > 0) {
- // This masks off all bits except for the lowest set one,
- // giving the largest power-of-two factor of a number.
- alignment.modulus &= -alignment.modulus;
- alignment.remainder = mod_imp(alignment.remainder, alignment.modulus);
+ alignment.modulus = largest_power_of_two_factor(alignment.modulus);
+ alignment.remainder &= alignment.modulus - 1;
}
} else {
// A narrowing integer cast that could possibly overflow adds
@@ -125,6 +141,52 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
alignment = ModulusRemainder::intersect(alignment, other.alignment);
trim_bounds_using_alignment();
}

// An alternative representation for information about integers is that
// certain bits have known values in the two's complement
// representation. This is a useful form for analyzing bitwise ops, so
// we provide conversions to and from that representation. For narrow
// types, this represents what the bits would be if they were sign- or
// zero-extended to 64 bits, so for uints the high bits are known to be
// zero, and for ints it depends on whether or not we knew the high bit
// to begin with.
struct BitsKnown {
// A mask which is 1 where we know the value of that bit
uint64_t mask;
// The actual value of the known bits
uint64_t value;

uint64_t known_zeros() const {
return mask & ~value;
}

uint64_t known_ones() const {
return mask & value;
}

BitsKnown operator&(const BitsKnown &other) const {
// Where either has known zeros, we have known zeros in the result
// Where both have a known one, we have a known one in the result
uint64_t zeros = known_zeros() | other.known_zeros();
uint64_t ones = known_ones() & other.known_ones();
return {zeros | ones, ones};
}

BitsKnown operator|(const BitsKnown &other) const {
uint64_t zeros = known_zeros() & other.known_zeros();
Member

Add a comment like the one in the & case.

uint64_t ones = known_ones() | other.known_ones();
return {zeros | ones, ones};
}

BitsKnown operator^(const BitsKnown &other) const {
// Unlike & and |, we need to know both bits to know anything.
uint64_t new_mask = mask & other.mask;
return {new_mask, (value ^ other.value) & new_mask};
}
};

BitsKnown to_bits_known(const Type &type) const;
void from_bits_known(BitsKnown known, const Type &type);
};
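
Short of the SMT check requested in review, the three operators can be tested exhaustively at a small bit width. This is a sketch, not a proof: `Bits` is a hypothetical copy of the operators in the diff, and `admits` is an invented helper that says whether a concrete value agrees with every known bit. Because the operators act on each bit independently, a narrow exhaustive check is reasonably persuasive, but it does not cover the to_bits_known and from_bits_known conversions.

```cpp
#include <cassert>
#include <cstdint>

struct Bits {
    uint64_t mask, value;
    uint64_t known_zeros() const { return mask & ~value; }
    uint64_t known_ones() const { return mask & value; }
    Bits operator&(const Bits &o) const {
        uint64_t zeros = known_zeros() | o.known_zeros();
        uint64_t ones = known_ones() & o.known_ones();
        return {zeros | ones, ones};
    }
    Bits operator|(const Bits &o) const {
        uint64_t zeros = known_zeros() & o.known_zeros();
        uint64_t ones = known_ones() | o.known_ones();
        return {zeros | ones, ones};
    }
    Bits operator^(const Bits &o) const {
        uint64_t m = mask & o.mask;
        return {m, (value ^ o.value) & m};
    }
    // True if the concrete value x agrees with every known bit.
    bool admits(uint64_t x) const { return (x & mask) == (value & mask); }
};

// For every abstract pair (a, b) and every concrete pair (x, y) they admit,
// the abstract result must admit the concrete result.
inline bool bitwise_ops_are_sound(int bits) {
    const uint64_t n = uint64_t(1) << bits;
    for (uint64_t am = 0; am < n; am++) {
        for (uint64_t av = 0; av < n; av++) {
            for (uint64_t bm = 0; bm < n; bm++) {
                for (uint64_t bv = 0; bv < n; bv++) {
                    const Bits a{am, av}, b{bm, bv};
                    for (uint64_t x = 0; x < n; x++) {
                        if (!a.admits(x)) continue;
                        for (uint64_t y = 0; y < n; y++) {
                            if (!b.admits(y)) continue;
                            if (!(a & b).admits(x & y) ||
                                !(a | b).admits(x | y) ||
                                !(a ^ b).admits(x ^ y)) {
                                return false;
                            }
                        }
                    }
                }
            }
        }
    }
    return true;
}
```

Calling `bitwise_ops_are_sound(4)` enumerates every 4-bit combination. The same struct also shows why the and/or cases can fold to constants: an entirely unknown operand and-ed with a known zero yields a fully known result.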

HALIDE_ALWAYS_INLINE
21 changes: 14 additions & 7 deletions src/Simplify_Select.cpp
@@ -21,18 +21,25 @@ Expr Simplify::visit(const Select *op, ExprInfo *info) {

// clang-format off
if (EVAL_IN_LAMBDA
- (rewrite(select(IRMatcher::likely(true), x, y), x) ||
- rewrite(select(IRMatcher::likely(false), x, y), y) ||
- rewrite(select(IRMatcher::likely_if_innermost(true), x, y), x) ||
- rewrite(select(IRMatcher::likely_if_innermost(false), x, y), y) ||
- rewrite(select(1, x, y), x) ||
- rewrite(select(0, x, y), y) ||
- rewrite(select(x, y, y), y) ||
+ (rewrite(select(IRMatcher::likely(true), x, y), true_value) ||
Member

Are these changes meant to be part of this PR? What's going on here?

Member Author

Running the fuzzer with debugging turned on identified an issue here where information wasn't being propagated through a constant-folded select as aggressively as it could have been. It happened to have a bitwise op in it, which is why I was looking at the case and puzzling over the too-loose bounds.

+ rewrite(select(IRMatcher::likely(false), x, y), false_value) ||
+ rewrite(select(IRMatcher::likely_if_innermost(true), x, y), true_value) ||
+ rewrite(select(IRMatcher::likely_if_innermost(false), x, y), false_value) ||
+ rewrite(select(1, x, y), true_value) ||
+ rewrite(select(0, x, y), false_value) ||
+ rewrite(select(x, y, y), false_value) ||
rewrite(select(x, likely(y), y), false_value) ||
rewrite(select(x, y, likely(y)), true_value) ||
rewrite(select(x, likely_if_innermost(y), y), false_value) ||
rewrite(select(x, y, likely_if_innermost(y)), true_value) ||
false)) {
if (info) {
if (rewrite.result.same_as(true_value)) {
*info = t_info;
} else if (rewrite.result.same_as(false_value)) {
*info = f_info;
}
}
return rewrite.result;
}
// clang-format on
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
@@ -12,6 +12,7 @@ tests(GROUPS correctness
autodiff.cpp
bad_likely.cpp
bit_counting.cpp
bits_known.cpp
bitwise_ops.cpp
bool_compute_root_vectorize.cpp
bool_predicate_cast.cpp