-
Notifications
You must be signed in to change notification settings - Fork 90
/
Copy pathIncrementalMerkleTree.sol
192 lines (161 loc) · 5.16 KB
/
IncrementalMerkleTree.sol
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
/**
 * @title IncrementalMerkleTree
 * @notice Merkle tree that supports appending, removing, and updating leaves,
 *         stored as one array of nodes per level
 * @dev nodes[0] holds the leaves and nodes[height - 1] holds the single root node.
 *      A node with no right sibling is carried up to the next level unchanged
 *      (it is not hashed with itself or with zero).
 */
library IncrementalMerkleTree {
    using IncrementalMerkleTree for Tree;

    struct Tree {
        // nodes[row][col]: row 0 is the leaf level; the last row holds the root
        bytes32[][] nodes;
    }

    /**
     * @notice query number of elements contained in tree
     * @param t Tree struct storage reference
     * @return treeSize size of tree (zero if the tree is empty)
     */
    function size(Tree storage t) internal view returns (uint256 treeSize) {
        if (t.height() > 0) {
            treeSize = t.nodes[0].length;
        }
    }

    /**
     * @notice query one-indexed height of tree
     * @dev conventional zero-indexed height would require the use of signed integers, so height is one-indexed instead
     * @param t Tree struct storage reference
     * @return one-indexed height of tree (zero if the tree is empty)
     */
    function height(Tree storage t) internal view returns (uint256) {
        return t.nodes.length;
    }

    /**
     * @notice query Merkle root
     * @param t Tree struct storage reference
     * @return hash root hash (zero if the tree is empty)
     */
    function root(Tree storage t) internal view returns (bytes32 hash) {
        uint256 treeHeight = t.height();

        if (treeHeight > 0) {
            unchecked {
                hash = t.nodes[treeHeight - 1][0];
            }
        }
    }

    /**
     * @notice query element at given index
     * @dev reverts with Solidity's built-in out-of-bounds panic if index >= size
     *      (including when the tree is empty)
     * @param t Tree struct storage reference
     * @param index index of element to query
     * @return hash element stored at index
     */
    function at(Tree storage t, uint256 index)
        internal
        view
        returns (bytes32 hash)
    {
        hash = t.nodes[0][index];
    }

    /**
     * @notice add new element to tree
     * @param t Tree struct storage reference
     * @param hash to add
     */
    function push(Tree storage t, bytes32 hash) internal {
        unchecked {
            uint256 treeHeight = t.height();
            uint256 treeSize = t.size();

            // add new layer if tree is at capacity: a tree of height h >= 1 holds
            // at most 2^(h-1) elements, which is what (1 << h) >> 1 evaluates to
            // (and it is 0 for an empty tree, so the first push creates row 0)
            if (treeSize == (1 << treeHeight) >> 1) {
                t.nodes.push();
                treeHeight++;
            }

            // add new columns if rows are full: at row r the new element's
            // ancestor sits at column treeSize >> r
            uint256 row;
            uint256 col = treeSize;

            while (row < treeHeight && t.nodes[row].length <= col) {
                t.nodes[row].push();
                row++;
                col >>= 1;
            }

            // add hash to tree (size() already reflects the new leaf slot)
            t.set(treeSize, hash);
        }
    }

    /**
     * @notice remove last element from tree
     * @dev reverts if the tree is empty
     * @param t Tree struct storage reference
     */
    function pop(Tree storage t) internal {
        uint256 oldSize = t.size();
        require(oldSize > 0, "IncrementalMerkleTree: tree is empty");

        // all arithmetic below is bounded by the non-empty check above
        unchecked {
            uint256 treeHeight = t.height();
            uint256 treeSize = oldSize - 1;

            // remove layer if tree has excess capacity: (1 << h) >> 2 is the
            // capacity of a tree one level shorter (0 when h == 1)
            if (treeSize == (1 << treeHeight) >> 2) {
                treeHeight--;
                t.nodes.pop();
            }

            // remove columns if rows are too long: row r keeps
            // ceil(treeSize / 2^r) nodes
            uint256 row;
            uint256 col = treeSize;

            while (row < treeHeight && t.nodes[row].length > col) {
                t.nodes[row].pop();
                row++;
                col = (col + 1) >> 1;
            }

            // recalculate hashes along the path of the new last element,
            // whose right-hand siblings may have just been removed
            if (treeSize > 0) {
                t.set(treeSize - 1, t.at(treeSize - 1));
            }
        }
    }

    /**
     * @notice update existing element in tree
     * @dev reverts if index >= size. The bound check is mandatory: _set writes
     *      with a raw sstore at keccak256(slot) + index, so an unchecked index
     *      could write to an arbitrary storage slot.
     * @param t Tree struct storage reference
     * @param index index to update
     * @param hash new hash to add
     */
    function set(
        Tree storage t,
        uint256 index,
        bytes32 hash
    ) internal {
        uint256 treeSize = t.size();
        require(index < treeSize, "IncrementalMerkleTree: index out of bounds");
        _set(t.nodes, 0, index, treeSize, hash);
    }

    /**
     * @notice update element in tree and recursively recalculate hashes
     * @dev recursion depth equals the tree height; callers must guarantee
     *      colIndex < rowLength because storage writes bypass bounds checks
     * @param nodes internal tree structure storage reference
     * @param rowIndex index of current row to update
     * @param colIndex index of current column to update
     * @param rowLength length of row at rowIndex
     * @param hash hash to store at current position
     */
    function _set(
        bytes32[][] storage nodes,
        uint256 rowIndex,
        uint256 colIndex,
        uint256 rowLength,
        bytes32 hash
    ) private {
        bytes32[] storage row = nodes[rowIndex];

        // store hash in array via assembly to avoid array length sload;
        // element i of a dynamic storage array lives at keccak256(slot) + i.
        // Memory 0x00-0x3f is Solidity's scratch space, so clobbering it is safe.
        assembly {
            mstore(0x00, row.slot)
            sstore(add(keccak256(0x00, 0x20), colIndex), hash)
        }

        // a row of length one is the root row
        if (rowLength == 1) return;

        unchecked {
            if ((colIndex & 1) == 1) {
                // sibling is on the left
                assembly {
                    mstore(0x00, row.slot)
                    let sibling := sload(
                        add(keccak256(0x00, 0x20), sub(colIndex, 1))
                    )
                    mstore(0x00, sibling)
                    mstore(0x20, hash)
                    hash := keccak256(0x00, 0x40)
                }
            } else if (colIndex < rowLength - 1) {
                // sibling is on the right (and sibling exists)
                assembly {
                    mstore(0x00, row.slot)
                    let sibling := sload(
                        add(keccak256(0x00, 0x20), add(colIndex, 1))
                    )
                    mstore(0x00, hash)
                    mstore(0x20, sibling)
                    hash := keccak256(0x00, 0x40)
                }
            }
            // otherwise: rightmost node without sibling is carried up unchanged

            // parent row length is ceil(rowLength / 2)
            _set(
                nodes,
                rowIndex + 1,
                colIndex >> 1,
                (rowLength + 1) >> 1,
                hash
            );
        }
    }
}