Skip to content

Commit 00eda7c

Browse files
authored
Merge pull request #251 from solidstate-network/sqrt-max-fix
`Math#sqrt` fix for max `uint256` input
2 parents 530c10a + 704acf3 commit 00eda7c

File tree

2 files changed

+51
-15
lines changed

2 files changed

+51
-15
lines changed

contracts/utils/Math.sol

+20-9
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,27 @@ library Math {
4747

4848
/**
4949
* @notice estimate square root of number
50-
* @dev uses Babylonian method (https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method)
51-
* @param x input number
52-
* @return y square root
50+
* @dev uses Heron's method (https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Heron's_method)
51+
* @param n input number
52+
* @return root square root of input (rounded down to nearest uint256)
5353
*/
54-
function sqrt(uint256 x) internal pure returns (uint256 y) {
55-
uint256 z = (x + 1) >> 1;
56-
y = x;
57-
while (z < y) {
58-
y = z;
59-
z = (x / z + z) >> 1;
54+
function sqrt(uint256 n) internal pure returns (uint256 root) {
55+
unchecked {
56+
// begin with an upper bound, to be updated each time a better estimate is found
57+
// for inputs of 0 and 1, this will be returned as-is
58+
root = n;
59+
// calculate an overestimate
60+
// bitwise-or prevents zero division in the case of input of 1
61+
uint256 estimate = (n >> 1) | 1;
62+
63+
// as long as estimate continues to decrease, it is converging on the square root
64+
65+
while (estimate < root) {
66+
// track the new best estimate as the prospective output value
67+
root = estimate;
68+
// attempt to find a better estimate
69+
estimate = (root + n / root) >> 1;
70+
}
6071
}
6172
}
6273
}

test/utils/Math.ts

+31-6
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,42 @@ describe('Math', () => {
7070
});
7171

7272
describe('#sqrt(uint256)', () => {
73-
it('returns the sqrt of a positive integer from 0 to maxUint256', async () => {
74-
expect(await instance.sqrt.staticCall(16)).to.eq(4);
73+
it('returns the square root of 0', async () => {
74+
expect(await instance.sqrt.staticCall(0n)).to.eq(0n);
75+
});
76+
77+
it('returns the square root of 1', async () => {
78+
expect(await instance.sqrt.staticCall(1n)).to.eq(1n);
79+
});
80+
81+
it('returns the square root of 2', async () => {
82+
expect(await instance.sqrt.staticCall(2n)).to.eq(1n);
83+
});
7584

76-
for (let i = 10; i < 16; i++) {
77-
expect(await instance.sqrt.staticCall(i.toString())).to.eq(3);
85+
it('returns the square root of positive integers', async () => {
86+
for (let i = 2; i < 16; i++) {
87+
expect(await instance.sqrt.staticCall(BigInt(i))).to.eq(
88+
Math.floor(Math.sqrt(i)),
89+
);
7890
}
91+
});
7992

80-
expect(await instance.sqrt.staticCall(0)).to.eq(0);
93+
it('returns the square root of powers of 2', async () => {
94+
for (let i = 0; i < 256; i++) {
95+
const input = 2n ** BigInt(i);
96+
const output = await instance.sqrt.staticCall(input);
97+
expect(output ** 2n).to.be.lte(input);
98+
expect((output + 1n) ** 2n).to.be.gt(input);
99+
}
100+
});
81101

102+
it('returns the square root of max values', async () => {
82103
expect(await instance.sqrt.staticCall(ethers.MaxUint256 - 1n)).to.eq(
83-
BigInt('340282366920938463463374607431768211455'),
104+
340282366920938463463374607431768211455n,
105+
);
106+
107+
expect(await instance.sqrt.staticCall(ethers.MaxUint256)).to.eq(
108+
340282366920938463463374607431768211455n,
84109
);
85110
});
86111
});

0 commit comments

Comments
 (0)