-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bounds and alignment analysis through bitwise ops #8574
base: main
Are you sure you want to change the base?
Changes from 5 commits
a8978b1
a377bc1
81c6d89
ead954a
ec7f7d7
0eb3162
09d95f0
dedd1c3
eea0042
19c3062
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -496,5 +496,114 @@ bool can_prove(Expr e, const Scope<Interval> &bounds) { | |
return is_const_one(e); | ||
} | ||
|
||
// Convert the alignment and bounds held in this ExprInfo into a set of known
// bits for a value of the given type, expressed as if the value had been sign-
// or zero-extended to 64 bits. Returns an empty BitsKnown (nothing known) for
// non-integer types.
Simplify::ExprInfo::BitsKnown Simplify::ExprInfo::to_bits_known(const Type &type) const {
    BitsKnown result = {0, 0};

    if (!(type.is_int() || type.is_uint())) {
        // Let's not claim we know anything about the bit patterns of
        // non-integer types for now.
        return result;
    }

    // Identify the largest power of two in the modulus to get some low bits
    if (alignment.modulus) {
        result.mask = largest_power_of_two_factor(alignment.modulus) - 1;
        result.value = result.mask & alignment.remainder;
    } else {
        // This value is just a constant
        result.mask = (uint64_t)(-1);
        result.value = alignment.remainder;
        return result;
    }

    const uint64_t all_ones = (uint64_t)(-1);

    // A mask which is 1 for every leading zero bit of x. Zero is handled
    // explicitly rather than being passed to clz64, and a value with the top
    // bit set has no leading zeros (which would otherwise need a shift by 64).
    auto leading_zeros_mask = [all_ones](uint64_t x) -> uint64_t {
        if (x == 0) {
            return all_ones;
        }
        if (x >> 63) {
            return 0;
        }
        return all_ones << (64 - clz64(x));
    };

    // The bounds and the type tell us a bunch of high bits are zero or one
    if (type.is_uint()) {
        // Narrow uints are zero-extended, regardless of what the bounds say.
        if (type.bits() < 64) {
            result.mask |= all_ones << type.bits();
        }
        // A non-negative upper bound tells us about additional leading zeros.
        if (bounds.max_defined && bounds.max >= 0) {
            result.mask |= leading_zeros_mask((uint64_t)bounds.max);
        }
    } else {
        // A mask which is 1 for the sign bit and everything above it.
        const uint64_t sign_bit_and_above = all_ones << (type.bits() - 1);
        if (bounds >= 0) {
            // The sign bit and above are zero.
            result.mask |= sign_bit_and_above;
            if (bounds.max_defined) {
                // There's at least one leading zero from the sign bit, but
                // the max may give us more.
                result.mask |= leading_zeros_mask((uint64_t)bounds.max);
            }
        } else if (bounds < 0) {
            // The sign bit and above are one.
            uint64_t high_bits = sign_bit_and_above;
            if (bounds.min_defined) {
                // The leading ones of the min are the leading zeros of its
                // complement. A min of -1 complements to zero, i.e. every
                // bit is a leading one.
                high_bits |= leading_zeros_mask(~(uint64_t)bounds.min);
            }
            result.mask |= high_bits;
            result.value |= high_bits;
        }
    }

    return result;
}
|
||
// Overwrite this ExprInfo's alignment (and, when the top bit is known, its
// bounds) from a set of known bits for a value of the given type.
void Simplify::ExprInfo::from_bits_known(Simplify::ExprInfo::BitsKnown known, const Type &type) {
    // Normalize to 64 bits first: work out which bits lie above the type's
    // width, i.e. the bits that sign- or zero-extension would fill in.
    const int width = type.bits();
    const uint64_t extension_bits = (width < 64) ? ((uint64_t)(-1) << width) : 0;

    if (extension_bits != 0) {
        if (type.is_uint()) {
            // Zero-extension: every bit above the type's width is a known zero.
            known.mask |= extension_bits;
            known.value &= ~extension_bits;
        } else if (type.is_int()) {
            // Sign-extension: the extended bits copy the sign bit, so what we
            // know about them depends on what we know about the sign bit.
            const bool sign_is_known = (known.mask >> (width - 1)) & 1;
            const bool sign_is_set = (known.value >> (width - 1)) & 1;
            if (sign_is_known && sign_is_set) {
                // Known negative: the extended bits are all known ones.
                known.mask |= extension_bits;
                known.value |= extension_bits;
            } else if (sign_is_known) {
                // Known non-negative: the extended bits are all known zeros.
                known.mask |= extension_bits;
                known.value &= ~extension_bits;
            } else {
                // Sign unknown: none of the extended bits are known. Clear
                // them in both the mask and the value.
                known.mask &= ~extension_bits;
                known.value &= ~extension_bits;
            }
        }
    }

    // The run of trailing known bits gives the alignment: adding one to the
    // mask carries through that run, so the largest power of two factor of
    // the sum is the modulus. When every bit is known the sum wraps to zero,
    // giving a modulus of zero and a remainder equal to the whole value.
    alignment.modulus = largest_power_of_two_factor(known.mask + 1);
    alignment.remainder = known.value & (alignment.modulus - 1);

    // Bounds can only be derived when the top bit is known.
    const bool top_bit_known = (known.mask >> 63) != 0;
    if (!top_bit_known) {
        return;
    }

    // Extremes of the value: unknown bits all zero, and unknown bits all one.
    const uint64_t lowest = known.value & known.mask;
    const uint64_t highest = known.value | ~known.mask;

    const bool top_bit_set = (known.value >> 63) != 0;
    if (type.is_uint() && top_bit_set) {
        // An unsigned value with the top bit set is above the int64 range, so
        // only a lower bound is recorded. The original author noted that this
        // branch was not reachable when written and is kept as a best effort.
        bounds = ConstantInterval::bounded_below((1ULL << 63) - 1);
    } else {
        // Otherwise both extremes are converted to int64 and used as the
        // bounds; the known top bit means the interval does not span zero.
        bounds = ConstantInterval{(int64_t)lowest, (int64_t)highest};
    }
}
|
||
} // namespace Internal | ||
} // namespace Halide |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -186,14 +186,23 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { | |
return Call::make(op->type, result_op, {a, b}, Call::PureIntrinsic); | ||
} | ||
} else if (op->is_intrinsic(Call::bitwise_and)) { | ||
Expr a = mutate(op->args[0], nullptr); | ||
Expr b = mutate(op->args[1], nullptr); | ||
ExprInfo a_info, b_info; | ||
Expr a = mutate(op->args[0], &a_info); | ||
Expr b = mutate(op->args[1], &b_info); | ||
|
||
Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type); | ||
if (unbroadcast.defined()) { | ||
return mutate(unbroadcast, info); | ||
} | ||
|
||
if (info && (op->type.is_int() || op->type.is_uint())) { | ||
auto bits_known = a_info.to_bits_known(op->type) & b_info.to_bits_known(op->type); | ||
info->from_bits_known(bits_known, op->type); | ||
if (bits_known.mask == (uint64_t)-1) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add comment saying something like "This is a constant". Or make There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and same below with |
||
return make_const(op->type, bits_known.value); | ||
} | ||
} | ||
|
||
auto ia = as_const_int(a), ib = as_const_int(b); | ||
auto ua = as_const_uint(a), ub = as_const_uint(b); | ||
|
||
|
@@ -218,14 +227,23 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { | |
return a & b; | ||
} | ||
} else if (op->is_intrinsic(Call::bitwise_or)) { | ||
Expr a = mutate(op->args[0], nullptr); | ||
Expr b = mutate(op->args[1], nullptr); | ||
ExprInfo a_info, b_info; | ||
Expr a = mutate(op->args[0], &a_info); | ||
Expr b = mutate(op->args[1], &b_info); | ||
|
||
Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type); | ||
if (unbroadcast.defined()) { | ||
return mutate(unbroadcast, info); | ||
} | ||
|
||
if (info && (op->type.is_int() || op->type.is_uint())) { | ||
auto bits_known = a_info.to_bits_known(op->type) | b_info.to_bits_known(op->type); | ||
info->from_bits_known(bits_known, op->type); | ||
if (bits_known.mask == (uint64_t)-1) { | ||
return make_const(op->type, bits_known.value); | ||
} | ||
} | ||
|
||
auto ia = as_const_int(a), ib = as_const_int(b); | ||
auto ua = as_const_uint(a), ub = as_const_uint(b); | ||
if (ia && ib) { | ||
|
@@ -238,13 +256,24 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { | |
return a | b; | ||
} | ||
} else if (op->is_intrinsic(Call::bitwise_not)) { | ||
Expr a = mutate(op->args[0], nullptr); | ||
ExprInfo a_info; | ||
Expr a = mutate(op->args[0], &a_info); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why not use the bits_known mask in the same way here that you do for and/or? Negated constants are constants. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I was, but then I changed it to not actually compute BitsKnown here. |
||
|
||
Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a}, op->call_type); | ||
if (unbroadcast.defined()) { | ||
return mutate(unbroadcast, info); | ||
} | ||
|
||
if (info && (op->type.is_int() || op->type.is_uint())) { | ||
// For the purpose of bounds and alignment, ~x can be treated as an | ||
// all-ones bit pattern minus x. | ||
Expr e = mutate(make_const(op->type, -1) - op->args[0], info); | ||
// If the result of this happens to be a constant, we can also just return it | ||
if (info->bounds.is_single_point()) { | ||
return e; | ||
} | ||
} | ||
|
||
if (auto ia = as_const_int(a)) { | ||
return make_const(op->type, ~(*ia)); | ||
} else if (auto ua = as_const_uint(a)) { | ||
|
@@ -255,14 +284,20 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { | |
return ~a; | ||
} | ||
} else if (op->is_intrinsic(Call::bitwise_xor)) { | ||
Expr a = mutate(op->args[0], nullptr); | ||
Expr b = mutate(op->args[1], nullptr); | ||
ExprInfo a_info, b_info; | ||
Expr a = mutate(op->args[0], &a_info); | ||
Expr b = mutate(op->args[1], &b_info); | ||
|
||
Expr unbroadcast = lift_elementwise_broadcasts(op->type, op->name, {a, b}, op->call_type); | ||
if (unbroadcast.defined()) { | ||
return mutate(unbroadcast, info); | ||
} | ||
|
||
if (info && (op->type.is_int() || op->type.is_uint())) { | ||
auto bits_known = a_info.to_bits_known(op->type) ^ b_info.to_bits_known(op->type); | ||
info->from_bits_known(bits_known, op->type); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same here: check for a constant? |
||
} | ||
|
||
auto ia = as_const_int(a), ib = as_const_int(b); | ||
auto ua = as_const_uint(a), ub = as_const_uint(b); | ||
if (ia && ib) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,6 +80,24 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> { | |
} | ||
} | ||
|
||
// Returns the largest power of two that divides x, i.e. x with every bit
// cleared except its lowest set bit. Returns zero when x is zero.
uint64_t largest_power_of_two_factor(uint64_t x) const {
    // Write x from MSB to LSB as some unknown high bits, then the lowest set
    // bit, then trailing zeros:
    //   a b c d 1 0 0 0
    // Flipping every bit gives
    //   ~a ~b ~c ~d 0 1 1 1
    // and adding one carries through the trailing ones:
    //   ~a ~b ~c ~d 1 0 0 0
    // This is the two's complement negation of x. And-ing it with x cancels
    // the unknown high bits and leaves just the lowest set bit (1000 here,
    // i.e. 8). Subtracting one from the result gives a mask of the trailing
    // zeros instead.
    const uint64_t negated = ~x + 1;
    return x & negated;
}
|
||
void cast_to(Type t) { | ||
if ((!t.is_int() && !t.is_uint()) || (t.is_int() && t.bits() >= 32)) { | ||
return; | ||
|
@@ -96,10 +114,8 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> { | |
// representable as any 64-bit integer type, so there's no | ||
// wraparound. | ||
if (alignment.modulus > 0) { | ||
// This masks off all bits except for the lowest set one, | ||
// giving the largest power-of-two factor of a number. | ||
alignment.modulus &= -alignment.modulus; | ||
alignment.remainder = mod_imp(alignment.remainder, alignment.modulus); | ||
alignment.modulus = largest_power_of_two_factor(alignment.modulus); | ||
alignment.remainder &= alignment.modulus - 1; | ||
} | ||
} else { | ||
// A narrowing integer cast that could possibly overflow adds | ||
|
@@ -125,6 +141,52 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> { | |
alignment = ModulusRemainder::intersect(alignment, other.alignment); | ||
trim_bounds_using_alignment(); | ||
} | ||
|
||
// An alternative representation for information about integers is that
// certain bits have known values in the 2s complement
// representation. This is a useful form for analyzing bitwise ops, so
// we provide conversions to and from that representation. For narrow
// types, this represents what the bits would be if they were sign or
// zero-extended to 64 bits, so for uints the high bits are known to be
// zero, and for ints it depends on whether or not we knew the high bit
// to begin with.
struct BitsKnown {
    // A mask which is 1 where we know the value of that bit. Defaults to
    // zero, so a default-constructed BitsKnown claims no knowledge rather
    // than holding indeterminate bits.
    uint64_t mask = 0;
    // The actual value of the known bits
    uint64_t value = 0;

    // Bits that are known to be zero.
    uint64_t known_zeros() const {
        return mask & ~value;
    }

    // Bits that are known to be one.
    uint64_t known_ones() const {
        return mask & value;
    }

    BitsKnown operator&(const BitsKnown &other) const {
        // Where either has known zeros, we have known zeros in the result
        // Where both have a known one, we have a known one in the result
        uint64_t zeros = known_zeros() | other.known_zeros();
        uint64_t ones = known_ones() & other.known_ones();
        return {zeros | ones, ones};
    }

    BitsKnown operator|(const BitsKnown &other) const {
        // Where both have a known zero, we have a known zero in the result
        // Where either has a known one, we have a known one in the result
        uint64_t zeros = known_zeros() & other.known_zeros();
        uint64_t ones = known_ones() | other.known_ones();
        return {zeros | ones, ones};
    }

    BitsKnown operator^(const BitsKnown &other) const {
        // Unlike & and |, we need to know both bits to know anything.
        uint64_t new_mask = mask & other.mask;
        return {new_mask, (value ^ other.value) & new_mask};
    }
};
|
||
BitsKnown to_bits_known(const Type &type) const; | ||
void from_bits_known(BitsKnown known, const Type &type); | ||
}; | ||
|
||
HALIDE_ALWAYS_INLINE | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,18 +21,25 @@ Expr Simplify::visit(const Select *op, ExprInfo *info) { | |
|
||
// clang-format off | ||
if (EVAL_IN_LAMBDA | ||
(rewrite(select(IRMatcher::likely(true), x, y), x) || | ||
rewrite(select(IRMatcher::likely(false), x, y), y) || | ||
rewrite(select(IRMatcher::likely_if_innermost(true), x, y), x) || | ||
rewrite(select(IRMatcher::likely_if_innermost(false), x, y), y) || | ||
rewrite(select(1, x, y), x) || | ||
rewrite(select(0, x, y), y) || | ||
rewrite(select(x, y, y), y) || | ||
(rewrite(select(IRMatcher::likely(true), x, y), true_value) || | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are these changes meant to be part of this PR? What's going on here? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Running the fuzzer with debugging turned on identified an issue here where information wasn't being propagated through a constant-folded select as aggressively as it could have been. It happened to have a bitwise op in it, which is why I was looking at the case and puzzling over the too-loose bounds. |
||
rewrite(select(IRMatcher::likely(false), x, y), false_value) || | ||
rewrite(select(IRMatcher::likely_if_innermost(true), x, y), true_value) || | ||
rewrite(select(IRMatcher::likely_if_innermost(false), x, y), false_value) || | ||
rewrite(select(1, x, y), true_value) || | ||
rewrite(select(0, x, y), false_value) || | ||
rewrite(select(x, y, y), false_value) || | ||
rewrite(select(x, likely(y), y), false_value) || | ||
rewrite(select(x, y, likely(y)), true_value) || | ||
rewrite(select(x, likely_if_innermost(y), y), false_value) || | ||
rewrite(select(x, y, likely_if_innermost(y)), true_value) || | ||
false)) { | ||
if (info) { | ||
if (rewrite.result.same_as(true_value)) { | ||
*info = t_info; | ||
} else if (rewrite.result.same_as(false_value)) { | ||
*info = f_info; | ||
} | ||
} | ||
return rewrite.result; | ||
} | ||
// clang-format on | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A lot of this bitwise math is a little tricky to follow. Have you thrown this in an SMT solver? I think this should be verified