-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bounds and alignment analysis through bitwise ops #8574
Open
abadams
wants to merge
10
commits into
main
Choose a base branch
from
abadams/bits_known
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
a8978b1
Analyze bitwise ops using a mask of bits known
abadams a377bc1
Bug fixes and better analysis of bitwise not
abadams 81c6d89
Reorder test to know one more bit about narrow positive ints
abadams ead954a
Slight simplification with no functional change
abadams ec7f7d7
Remove debugging print
abadams 0eb3162
Propagate info in more places
abadams 09d95f0
Merge remote-tracking branch 'origin/main' into abadams/bits_known
abadams dedd1c3
empty commit
abadams eea0042
Merge remote-tracking branch 'origin/main' into abadams/bits_known
abadams 19c3062
empty commit
abadams File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -91,7 +91,7 @@ void Simplify::ScopedFact::learn_false(const Expr &fact) { | |
Simplify::VarInfo info; | ||
info.old_uses = info.new_uses = 0; | ||
if (const Variable *v = fact.as<Variable>()) { | ||
info.replacement = const_false(fact.type().lanes()); | ||
info.replacement = Halide::Internal::const_false(fact.type().lanes()); | ||
simplify->var_info.push(v->name, info); | ||
pop_list.push_back(v); | ||
} else if (const NE *ne = fact.as<NE>()) { | ||
|
@@ -178,7 +178,7 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { | |
Simplify::VarInfo info; | ||
info.old_uses = info.new_uses = 0; | ||
if (const Variable *v = fact.as<Variable>()) { | ||
info.replacement = const_true(fact.type().lanes()); | ||
info.replacement = Halide::Internal::const_true(fact.type().lanes()); | ||
simplify->var_info.push(v->name, info); | ||
pop_list.push_back(v); | ||
} else if (const EQ *eq = fact.as<EQ>()) { | ||
|
@@ -496,5 +496,158 @@ bool can_prove(Expr e, const Scope<Interval> &bounds) { | |
return is_const_one(e); | ||
} | ||
|
||
// Summarize what this ExprInfo (its alignment and bounds members) implies | ||
// about the individual bits of a value of the given type, viewed as a 64-bit | ||
// pattern. In the result, a set bit in mask marks a bit whose value is known, | ||
// and the corresponding bit of value holds that known bit. | ||
// | ||
// - Non-integer types: nothing is claimed (mask and value are both zero). | ||
// - alignment.modulus == 0: the value is treated as a constant, so every bit | ||
// is marked known and value is alignment.remainder. | ||
// - Otherwise: low bits come from the largest power-of-two factor of the | ||
// modulus, and high bits come from the type width and the bounds. | ||
Simplify::ExprInfo::BitsKnown Simplify::ExprInfo::to_bits_known(const Type &type) const { | ||
BitsKnown result = {0, 0}; | ||
|
||
if (!(type.is_int() || type.is_uint())) { | ||
// Let's not claim we know anything about the bit patterns of | ||
// non-integer types for now. | ||
return result; | ||
} | ||
|
||
// Identify the largest power of two in the modulus to get some low bits | ||
if (alignment.modulus) { | ||
result.mask = largest_power_of_two_factor(alignment.modulus) - 1; | ||
result.value = result.mask & alignment.remainder; | ||
} else { | ||
// This value is just a constant | ||
result.mask = (uint64_t)(-1); | ||
result.value = alignment.remainder; | ||
return result; | ||
} | ||
|
||
// Compute a mask which is 1 for all the leading zeros of a uint64 | ||
auto leading_zeros_mask = [](uint64_t x) { | ||
if (x == 0) { | ||
// They're all leading zeros, but clz64 is UB on zero. Really we | ||
// should have returned early above, but it's hard to guarantee that | ||
// the alignment analysis catches constants at the same time as | ||
// bounds analysis does. | ||
return (uint64_t)-1; | ||
} | ||
// NOTE(review): presumably clz64 returns the number of leading zero bits. | ||
// If x had bit 63 set that count would be 0 and the shift amount below | ||
// would be 64, which is undefined for a 64-bit operand. The call sites in | ||
// this function appear to pass values with bit 63 clear - TODO confirm | ||
// for the uint case, where bounds.max is not checked for sign. | ||
return (uint64_t)(-1) << (64 - clz64(x)); | ||
}; | ||
|
||
// Leading ones of x are the leading zeros of its bitwise complement. | ||
auto leading_ones_mask = [=](uint64_t x) { | ||
return leading_zeros_mask(~x); | ||
}; | ||
|
||
// The bounds and the type tell us a bunch of high bits are zero or one | ||
if (type.is_uint()) { | ||
// Narrow uints are always zero-extended. | ||
if (type.bits() < 64) { | ||
result.mask |= (uint64_t)(-1) << type.bits(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't this true regardless of the bounds? |
||
} | ||
// A lower bound might tell us that there are some leading ones, and an | ||
// upper bound might tell us that there are some leading | ||
// zeros. Unfortunately we'll never learn about leading ones, because | ||
// uint64_ts that start with leading ones can't have a min represented | ||
// as an int64_t, which is what ConstantInterval uses, so | ||
// bounds.min_defined will never be true. | ||
if (bounds.max_defined) { | ||
result.mask |= leading_zeros_mask(bounds.max); | ||
} | ||
} else { | ||
internal_assert(type.is_int()); | ||
// A mask which is 1 for the sign bit and above. | ||
uint64_t sign_bit_and_above = (uint64_t)(-1) << (type.bits() - 1); | ||
if (bounds >= 0) { | ||
// We know this int is positive, so the sign bit and above are zero. | ||
result.mask |= sign_bit_and_above; | ||
if (bounds.max_defined) { | ||
// We also have an upper bound, so there may be more zero bits, | ||
// depending on how many leading zeros there are in the upper | ||
// bound. | ||
result.mask |= leading_zeros_mask(bounds.max); | ||
} | ||
} else if (bounds < 0) { | ||
// This int is negative, so the sign bit and above are one. | ||
result.mask |= sign_bit_and_above; | ||
result.value |= sign_bit_and_above; | ||
if (bounds.min_defined) { | ||
// We have a lower bound, so there may be more leading one bits, | ||
// depending on how many leading ones there are in the lower | ||
// bound. | ||
uint64_t leading_ones = leading_ones_mask(bounds.min); | ||
result.mask |= leading_ones; | ||
result.value |= leading_ones; | ||
} | ||
} | ||
} | ||
|
||
return result; | ||
} | ||
|
||
// Update this ExprInfo from a set of known bits for a value of the given | ||
// type. The known bits are first extended to 64 bits according to the type | ||
// (zero-extension for uints, sign-extension for ints when the sign bit is | ||
// known). Then alignment is always overwritten from the known trailing bits, | ||
// and bounds is overwritten only when bit 63 of the extended mask is known. | ||
// known is taken by value, so the adjustments below stay local. | ||
void Simplify::ExprInfo::from_bits_known(Simplify::ExprInfo::BitsKnown known, const Type &type) { | ||
// Normalize everything to 64-bits by sign- or zero-extending known bits for | ||
// the type. | ||
|
||
// A mask which is one for all the new bits resulting from sign or zero | ||
// extension. | ||
uint64_t missing_bits = 0; | ||
if (type.bits() < 64) { | ||
missing_bits = (uint64_t)(-1) << type.bits(); | ||
} | ||
|
||
if (missing_bits) { | ||
if (type.is_uint()) { | ||
// For a uint the high bits are known to be zero | ||
known.mask |= missing_bits; | ||
known.value &= ~missing_bits; | ||
} else if (type.is_int()) { | ||
// For an int we need to know the sign to know the high bits | ||
bool sign_bit_known = (known.mask >> (type.bits() - 1)) & 1; | ||
bool negative = (known.value >> (type.bits() - 1)) & 1; | ||
if (!sign_bit_known) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These cases need comments explaining each of them |
||
// We don't know the sign bit, so we don't know any of the | ||
// extended bits. Mark them as unknown in the mask and zero them | ||
// out in the value too just for ease of debugging. | ||
known.mask &= ~missing_bits; | ||
known.value &= ~missing_bits; | ||
} else if (negative) { | ||
// We know the sign bit is 1, so all of the extended bits are 1 | ||
// too. | ||
known.mask |= missing_bits; | ||
known.value |= missing_bits; | ||
} else if (!negative) { | ||
// NOTE(review): negative is necessarily false on this branch, so the | ||
// condition above always holds and this acts as a plain else. | ||
// We know the sign bit is zero, so all of the extended bits are | ||
// zero too. | ||
known.mask |= missing_bits; | ||
known.value &= ~missing_bits; | ||
} | ||
} | ||
} | ||
|
||
// We can get the trailing one bits by adding one and taking the largest | ||
// power of two factor. Note that this works out correctly when we know all | ||
// the bits - the modulus comes out as zero, and the remainder is the entire | ||
// number, which is how we represent constants in ModulusRemainder. | ||
alignment.modulus = largest_power_of_two_factor(known.mask + 1); | ||
alignment.remainder = known.value & (alignment.modulus - 1); | ||
|
||
// Reinterpreting the mask as signed makes this a test of bit 63, i.e. | ||
// whether the topmost bit of the 64-bit pattern is known. | ||
if ((int64_t)known.mask < 0) { | ||
// We know some leading bits | ||
|
||
// Set all unknown bits to zero | ||
uint64_t min_val = known.value & known.mask; | ||
// Set all unknown bits to one | ||
uint64_t max_val = known.value | ~known.mask; | ||
|
||
if (type.is_uint() && (int64_t)known.value < 0) { | ||
// We know it's out of range at the top end for our ConstantInterval | ||
// class. At the time of writing, to_bits_known can't produce this | ||
// directly, and bits_known is never propagated through other | ||
// operations, so this code is unreachable. Nonetheless we'll do the | ||
// best job we can at representing this case in case this code | ||
// becomes reachable in future. | ||
// (1ULL << 63) - 1 is the largest value an int64_t can hold; | ||
// bounded_below presumably yields an interval with only a lower | ||
// bound - TODO confirm against ConstantInterval. | ||
bounds = ConstantInterval::bounded_below((1ULL << 63) - 1); | ||
} else { | ||
// In all other cases, the bounds are representable as an int64 | ||
// and don't span zero (because we know the high bit). | ||
bounds = ConstantInterval{(int64_t)min_val, (int64_t)max_val}; | ||
} | ||
} | ||
} | ||
|
||
} // namespace Internal | ||
} // namespace Halide |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A lot of this bitwise math is a little tricky to follow. Have you thrown this in an SMT solver? I think this should be verified