Skip to content

Bounds and alignment analysis through bitwise ops #8574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/ConstantBounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ ConstantInterval bounds_helper(const Expr &e,
ConstantInterval cq = recurse(op->args[2]);
ConstantInterval rounding_term = 1 << (cq - 1);
return (ca * cb + rounding_term) >> cq;
} else if (op->is_intrinsic(Call::bitwise_not)) {
// We can't do much with the other bitwise ops, but we can treat
// bitwise_not as an all-ones bit pattern minus the argument.
return recurse(make_const(e.type(), -1) - op->args[0]);
}
// If you add a new intrinsic here, also add it to the expression
// generator in test/correctness/lossless_cast.cpp
Expand Down
109 changes: 109 additions & 0 deletions src/Simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,5 +496,114 @@ bool can_prove(Expr e, const Scope<Interval> &bounds) {
return is_const_one(e);
}

Simplify::ExprInfo::BitsKnown Simplify::ExprInfo::to_bits_known(const Type &type) const {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A lot of this bitwise math is a little tricky to follow. Have you thrown this in an SMT solver? I think this should be verified

BitsKnown result = {0, 0};

if (!(type.is_int() || type.is_uint())) {
// Let's not claim we know anything about the bit patterns of
// non-integer types for now.
return result;
}

// Identify the largest power of two in the modulus to get some low bits
if (alignment.modulus) {
result.mask = largest_power_of_two_factor(alignment.modulus) - 1;
result.value = result.mask & alignment.remainder;
} else {
// This value is just a constant
result.mask = (uint64_t)(-1);
result.value = alignment.remainder;
return result;
}

// The bounds and the type tell us a bunch of high bits are zero or one
if (bounds >= 0) {
if (type.is_int()) {
// The sign bit and above are zero.
result.mask |= (uint64_t)(-1) << (type.bits() - 1);
} else if (type.bits() < 64) {
// Narrow uints are zero-extended.
result.mask |= (uint64_t)(-1) << type.bits();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this true regardless of the bounds?

}
if (bounds.max_defined) {
// It's positive and the max is representable as an int64, so at least
// one high bit is zero, but bounds.max isn't zero or we would have
// returned above.
result.mask |= (uint64_t)(-1) << (64 - clz64(bounds.max));
}
} else if (bounds < 0) {
// At least one high bit is one, but bounds.min isn't -1 or we would
// have returned above.
uint64_t high_bits = (uint64_t)(-1) << (type.bits() - 1);
if (bounds.min_defined) {
high_bits |= (uint64_t)(-1) << (64 - clz64(~bounds.min));
}
result.mask |= high_bits;
result.value |= high_bits;
}

return result;
}

void Simplify::ExprInfo::from_bits_known(Simplify::ExprInfo::BitsKnown known, const Type &type) {
// Normalize everything to 64-bits by sign- or zero-extending known bits for
// the type.
uint64_t missing_bits = 0;
if (type.bits() < 64) {
missing_bits = (uint64_t)(-1) << type.bits();
}
if (missing_bits) {
if (type.is_uint()) {
// For a uint the high bits are zero
known.mask |= missing_bits;
known.value &= ~missing_bits;
} else if (type.is_int()) {
// For an int we need to know the sign to know the high bits
bool sign_bit_known = (known.mask >> (type.bits() - 1)) & 1;
bool negative = (known.value >> (type.bits() - 1)) & 1;
if (!sign_bit_known) {
known.mask &= ~missing_bits;
known.value &= ~missing_bits;
} else if (negative) {
known.mask |= missing_bits;
known.value |= missing_bits;
} else if (!negative) {
known.mask |= missing_bits;
known.value &= ~missing_bits;
}
}
}

// We can get the trailing one bits by adding one and taking the largest
// power of two factor. Note that this works out correctly when we know all
// the bits - the modulus comes out as zero, and the remainder is the entire
// number, which is how we represent constants in ModulusRemainder.
alignment.modulus = largest_power_of_two_factor(known.mask + 1);
alignment.remainder = known.value & (alignment.modulus - 1);

if ((int64_t)known.mask < 0) {
// We know some leading bits

// Set all unknown bits to zero
uint64_t min_val = known.value & known.mask;
// Set all unknown bits to one
uint64_t max_val = known.value | ~known.mask;

if (type.is_uint() && (int64_t)known.value < 0) {
// We know it's out of range at the top end for our ConstantInterval
// class. At the time of writing, to_bits_known can't produce this
// directly, and bits_known is never propagated through other
// operations, so this code is unreachable. Nonetheless we'll do the
// best job we can at representing this case in case this code
// becomes reachable in future.
bounds = ConstantInterval::bounded_below((1ULL << 63) - 1);
} else {
// In all other cases, the bounds are representable as an int64
// and don't span zero (because we know the high bit).
bounds = ConstantInterval{(int64_t)min_val, (int64_t)max_val};
}
}
}

} // namespace Internal
} // namespace Halide
49 changes: 42 additions & 7 deletions src/Simplify_Call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,23 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
return Call::make(op->type, result_op, {a, b}, Call::PureIntrinsic);
}
} else if (op->is_intrinsic(Call::bitwise_and)) {
Expr a = mutate(op->args[0], nullptr);
Expr b = mutate(op->args[1], nullptr);
ExprInfo a_info, b_info;
Expr a = mutate(op->args[0], &a_info);
Expr b = mutate(op->args[1], &b_info);

Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type);
if (unbroadcast.defined()) {
return mutate(unbroadcast, info);
}

if (info && (op->type.is_int() || op->type.is_uint())) {
auto bits_known = a_info.to_bits_known(op->type) & b_info.to_bits_known(op->type);
info->from_bits_known(bits_known, op->type);
if (bits_known.mask == (uint64_t)-1) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add comment saying something like "This is a constant". Or make ExprInfo::BitsKnown have a helper function like "std::optional<uint64_t> as_constant()`.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and same below with bitwise_or

return make_const(op->type, bits_known.value);
}
}

auto ia = as_const_int(a), ib = as_const_int(b);
auto ua = as_const_uint(a), ub = as_const_uint(b);

Expand All @@ -218,14 +227,23 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
return a & b;
}
} else if (op->is_intrinsic(Call::bitwise_or)) {
Expr a = mutate(op->args[0], nullptr);
Expr b = mutate(op->args[1], nullptr);
ExprInfo a_info, b_info;
Expr a = mutate(op->args[0], &a_info);
Expr b = mutate(op->args[1], &b_info);

Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type);
if (unbroadcast.defined()) {
return mutate(unbroadcast, info);
}

if (info && (op->type.is_int() || op->type.is_uint())) {
auto bits_known = a_info.to_bits_known(op->type) | b_info.to_bits_known(op->type);
info->from_bits_known(bits_known, op->type);
if (bits_known.mask == (uint64_t)-1) {
return make_const(op->type, bits_known.value);
}
}

auto ia = as_const_int(a), ib = as_const_int(b);
auto ua = as_const_uint(a), ub = as_const_uint(b);
if (ia && ib) {
Expand All @@ -238,13 +256,24 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
return a | b;
}
} else if (op->is_intrinsic(Call::bitwise_not)) {
Expr a = mutate(op->args[0], nullptr);
ExprInfo a_info;
Expr a = mutate(op->args[0], &a_info);

Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a}, op->call_type);
if (unbroadcast.defined()) {
return mutate(unbroadcast, info);
}

if (info && (op->type.is_int() || op->type.is_uint())) {
// For the purpose of bounds and alignment, ~x can be treated as an
// all-ones bit pattern minus x.
Expr e = mutate(make_const(op->type, -1) - op->args[0], info);
// If the result of this happens to be a constant, we can also just return it
if (info->bounds.is_single_point()) {
return e;
}
}

if (auto ia = as_const_int(a)) {
return make_const(op->type, ~(*ia));
} else if (auto ua = as_const_uint(a)) {
Expand All @@ -255,14 +284,20 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
return ~a;
}
} else if (op->is_intrinsic(Call::bitwise_xor)) {
Expr a = mutate(op->args[0], nullptr);
Expr b = mutate(op->args[1], nullptr);
ExprInfo a_info, b_info;
Expr a = mutate(op->args[0], &a_info);
Expr b = mutate(op->args[1], &b_info);

Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type);
if (unbroadcast.defined()) {
return mutate(unbroadcast, info);
}

if (info && (op->type.is_int() || op->type.is_uint())) {
auto bits_known = a_info.to_bits_known(op->type) ^ b_info.to_bits_known(op->type);
info->from_bits_known(bits_known, op->type);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, check for a constant?

}

auto ia = as_const_int(a), ib = as_const_int(b);
auto ua = as_const_uint(a), ub = as_const_uint(b);
if (ia && ib) {
Expand Down
70 changes: 66 additions & 4 deletions src/Simplify_Internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,24 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
}
}

uint64_t largest_power_of_two_factor(uint64_t x) const {
// Consider the bits of x from MSB to LSB. Say there are three
// trailing zeros, and the four high bits are unknown:
// a b c d 1 0 0 0
// The largest power of two factor of a number is the trailing bits
// up to and including the first 1. In this example that's 1000
// (i.e. 8).
// Negating is flipping the bits and adding one. First we flip:
// ~a ~b ~c ~d 0 1 1 1
// Then we add one:
// ~a ~b ~c ~d 1 0 0 0
// If we bitwise and this with the original, the unknown bits cancel
// out, and we get left with just the largest power of two
// factor. If we want a mask of the trailing zeros instead, we can
// just subtract one.
return x & -x;
}

void cast_to(Type t) {
if ((!t.is_int() && !t.is_uint()) || (t.is_int() && t.bits() >= 32)) {
return;
Expand All @@ -96,10 +114,8 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
// representable as any 64-bit integer type, so there's no
// wraparound.
if (alignment.modulus > 0) {
// This masks off all bits except for the lowest set one,
// giving the largest power-of-two factor of a number.
alignment.modulus &= -alignment.modulus;
alignment.remainder = mod_imp(alignment.remainder, alignment.modulus);
alignment.modulus = largest_power_of_two_factor(alignment.modulus);
alignment.remainder &= alignment.modulus - 1;
}
} else {
// A narrowing integer cast that could possibly overflow adds
Expand All @@ -125,6 +141,52 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
alignment = ModulusRemainder::intersect(alignment, other.alignment);
trim_bounds_using_alignment();
}

// An alternative representation for information about integers is that
// certain bits have known values in the 2s complement
// representation. This is a useful form for analyzing bitwise ops, so
// we provide conversions to and from that representation. For narrow
// types, this represent what the bits would be if they were sign or
// zero-extended to 64 bits, so for uints the high bits are known to be
// zero, and for ints it depends on whether or not we knew the high bit
// to begin with.
struct BitsKnown {
// A mask which is 1 where we know the value of that bit
uint64_t mask;
// The actual value of the known bits
uint64_t value;

uint64_t known_zeros() const {
return mask & ~value;
}

uint64_t known_ones() const {
return mask & value;
}

BitsKnown operator&(const BitsKnown &other) const {
// Where either has known zeros, we have known zeros in the result
// Where both have a known one, we have a known one in the result
uint64_t zeros = known_zeros() | other.known_zeros();
uint64_t ones = known_ones() & other.known_ones();
return {zeros | ones, ones};
}

BitsKnown operator|(const BitsKnown &other) const {
uint64_t zeros = known_zeros() & other.known_zeros();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment like in the & case

uint64_t ones = known_ones() | other.known_ones();
return {zeros | ones, ones};
}

BitsKnown operator^(const BitsKnown &other) const {
// Unlike & and |, we need to know both bits to know anything.
uint64_t new_mask = mask & other.mask;
return {new_mask, (value ^ other.value) & new_mask};
}
};

BitsKnown to_bits_known(const Type &type) const;
void from_bits_known(BitsKnown known, const Type &type);
};

HALIDE_ALWAYS_INLINE
Expand Down
21 changes: 14 additions & 7 deletions src/Simplify_Select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,25 @@ Expr Simplify::visit(const Select *op, ExprInfo *info) {

// clang-format off
if (EVAL_IN_LAMBDA
(rewrite(select(IRMatcher::likely(true), x, y), x) ||
rewrite(select(IRMatcher::likely(false), x, y), y) ||
rewrite(select(IRMatcher::likely_if_innermost(true), x, y), x) ||
rewrite(select(IRMatcher::likely_if_innermost(false), x, y), y) ||
rewrite(select(1, x, y), x) ||
rewrite(select(0, x, y), y) ||
rewrite(select(x, y, y), y) ||
(rewrite(select(IRMatcher::likely(true), x, y), true_value) ||
rewrite(select(IRMatcher::likely(false), x, y), false_value) ||
rewrite(select(IRMatcher::likely_if_innermost(true), x, y), true_value) ||
rewrite(select(IRMatcher::likely_if_innermost(false), x, y), false_value) ||
rewrite(select(1, x, y), true_value) ||
rewrite(select(0, x, y), false_value) ||
rewrite(select(x, y, y), false_value) ||
rewrite(select(x, likely(y), y), false_value) ||
rewrite(select(x, y, likely(y)), true_value) ||
rewrite(select(x, likely_if_innermost(y), y), false_value) ||
rewrite(select(x, y, likely_if_innermost(y)), true_value) ||
false)) {
if (info) {
if (rewrite.result.same_as(true_value)) {
*info = t_info;
} else if (rewrite.result.same_as(false_value)) {
*info = f_info;
}
}
return rewrite.result;
}
// clang-format on
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ tests(GROUPS correctness
autodiff.cpp
bad_likely.cpp
bit_counting.cpp
bits_known.cpp
bitwise_ops.cpp
bool_compute_root_vectorize.cpp
bool_predicate_cast.cpp
Expand Down
Loading
Loading