Skip to content

Commit 272600b

Browse files
authored
Avoid allocations when building a tree (#29)
* Avoid allocations when building a tree * Review fixes
1 parent 7f6264f commit 272600b

File tree

4 files changed

+67
-30
lines changed

4 files changed

+67
-30
lines changed

merkle.go

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ type node struct {
3737
}
3838

3939
func (n node) IsEmpty() bool {
40-
return n.value == nil
40+
return len(n.value) == 0
4141
}
4242

4343
// layer is a layer in the merkle tree.
@@ -98,6 +98,7 @@ type Tree struct {
9898
leavesToProve *sparseBoolStack
9999
cacheWriter CacheWriter
100100
minHeight uint
101+
parentBuf []byte
101102
}
102103

103104
// AddLeaf incorporates a new leaf to the state of the tree. It updates the state required to eventually determine the
@@ -108,7 +109,6 @@ func (t *Tree) AddLeaf(value []byte) error {
108109
OnProvenPath: t.leavesToProve.Pop(),
109110
}
110111
l := t.baseLayer
111-
var parent, lChild, rChild node
112112
var lastCachingError error
113113

114114
// Loop through the layers, starting from the base layer.
@@ -124,26 +124,29 @@ func (t *Tree) AddLeaf(value []byte) error {
124124
// If no node is pending, then this node is a left sibling,
125125
// pending for its right sibling before its parent can be calculated.
126126
if l.parking.IsEmpty() {
127-
l.parking = n
127+
// Copy the byte slice as we will keep it for a while.
128+
l.parking.value = append(l.parking.value[:0], n.value...)
129+
l.parking.OnProvenPath = n.OnProvenPath
128130
break
129131
} else {
130132
// This node is a right sibling.
131-
lChild, rChild = l.parking, n
132-
parent = t.calcParent(lChild, rChild)
133+
lChild, rChild := l.parking, n
133134

134135
// A given node is required in the proof if and only if its parent is an ancestor
135136
// of a leaf whose membership in the tree is being proven, but the given node isn't.
136-
if parent.OnProvenPath {
137-
if !lChild.OnProvenPath {
138-
t.proof = append(t.proof, lChild.value)
139-
}
140-
if !rChild.OnProvenPath {
141-
t.proof = append(t.proof, rChild.value)
142-
}
137+
if rChild.OnProvenPath && !lChild.OnProvenPath {
138+
copy := append([]byte(nil), lChild.value...)
139+
t.proof = append(t.proof, copy)
143140
}
141+
if lChild.OnProvenPath && !rChild.OnProvenPath {
142+
copy := append([]byte(nil), rChild.value...)
143+
t.proof = append(t.proof, copy)
144+
}
145+
146+
n = t.calcParent(t.parentBuf[:0], lChild, rChild)
147+
t.parentBuf = n.value
144148

145-
l.parking.value = nil
146-
n = parent
149+
l.parking.value = l.parking.value[:0]
147150
err := l.ensureNextLayerExists(t.cacheWriter)
148151
if err != nil {
149152
return err
@@ -264,18 +267,21 @@ func (t *Tree) calcEphemeralParent(parking, ephemeralNode node) (parent, lChild,
264267
default: // both are empty
265268
return EmptyNode, EmptyNode, EmptyNode
266269
}
267-
return t.calcParent(lChild, rChild), lChild, rChild
270+
return t.calcParent(nil, lChild, rChild), lChild, rChild
268271
}
269272

270-
// calcParent returns the parent node of two child nodes.
271-
func (t *Tree) calcParent(lChild, rChild node) node {
273+
// calcParent calculates the parent node of two child nodes.
274+
// The buf can be used to reuse memory for hashing.
275+
func (t *Tree) calcParent(buf []byte, lChild, rChild node) node {
272276
return node{
273-
value: t.hash(lChild.value, rChild.value),
277+
value: t.hash(buf, lChild.value, rChild.value),
274278
OnProvenPath: lChild.OnProvenPath || rChild.OnProvenPath,
275279
}
276280
}
277281

278-
func GetSha256Parent(lChild, rChild []byte) []byte {
279-
res := sha256.Sum256(append(lChild, rChild...))
280-
return res[:]
282+
func GetSha256Parent(buf, lChild, rChild []byte) []byte {
283+
hasher := sha256.New()
284+
hasher.Write(lChild)
285+
hasher.Write(rChild)
286+
return hasher.Sum(buf)
281287
}

merkle_test.go

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func TestNewTree(t *testing.T) {
6363
r.Equal(expectedRoot, root)
6464
}
6565

66-
func concatLeaves(lChild, rChild []byte) []byte {
66+
func concatLeaves(_, lChild, rChild []byte) []byte {
6767
if len(lChild) == NodeSize {
6868
lChild = lChild[:1]
6969
}
@@ -192,8 +192,7 @@ func TestNewTreeUnbalancedProof(t *testing.T) {
192192
expectedProof[3], _ = NewNodeFromHex("0600000000000000000000000000000000000000000000000000000000000000")
193193
expectedProof[4], _ = NewNodeFromHex("bc68417a8495de6e22d95b980fca5a1183f29eff0e2a9b7ddde91ed5bcbea952")
194194

195-
var proof nodes
196-
proof = tree.Proof()
195+
proof := tree.Proof()
197196
r.EqualValues(expectedProof, proof)
198197
}
199198

@@ -314,6 +313,38 @@ func TestNewProvingTreeMultiProof(t *testing.T) {
314313
***************************************************/
315314
}
316315

316+
// TestNewProvingTreeMultiProofReuseLeafBytes verifies if the user of Tree
317+
// can safely reuse the memory passed into Tree::AddLeaf.
318+
func TestNewProvingTreeMultiProofReuseLeafBytes(t *testing.T) {
319+
r := require.New(t)
320+
tree, err := NewProvingTree(setOf(1, 4))
321+
r.NoError(err)
322+
var leaf [32]byte
323+
for i := uint64(0); i < 8; i++ {
324+
binary.LittleEndian.PutUint64(leaf[:], i)
325+
r.NoError(tree.AddLeaf(leaf[:]))
326+
}
327+
expectedRoot, _ := NewNodeFromHex("89a0f1577268cc19b0a39c7a69f804fd140640c699585eb635ebb03c06154cce")
328+
root := tree.Root()
329+
r.Equal(expectedRoot, root)
330+
331+
expectedProof := make([][]byte, 4)
332+
expectedProof[0], _ = NewNodeFromHex("0000000000000000000000000000000000000000000000000000000000000000")
333+
expectedProof[1], _ = NewNodeFromHex("0094579cfc7b716038d416a311465309bea202baa922b224a7b08f01599642fb")
334+
expectedProof[2], _ = NewNodeFromHex("0500000000000000000000000000000000000000000000000000000000000000")
335+
expectedProof[3], _ = NewNodeFromHex("fa670379e5c2212ed93ff09769622f81f98a91e1ec8fb114d607dd25220b9088")
336+
337+
proof := tree.Proof()
338+
r.EqualValues(expectedProof, proof)
339+
340+
/***************************************************
341+
| 89a0 |
342+
| ba94 633b |
343+
| cb59 .0094. bd50 .fa67. |
344+
| .0000.=0100= 0200 0300 =0400=.0500. 0600 0700 |
345+
***************************************************/
346+
}
347+
317348
func TestNewProvingTreeMultiProof2(t *testing.T) {
318349
r := require.New(t)
319350
tree, err := NewProvingTree(setOf(0, 1, 4))
@@ -442,7 +473,7 @@ func TestTree_GetParkedNodes(t *testing.T) {
442473

443474
r.NoError(tree.AddLeaf([]byte{1}))
444475
r.EqualValues(
445-
[][]byte{nil, decode(r, "b413f47d13ee2fe6c845b2ee141af81de858df4ec549a58b7970bb96645bc8d2")},
476+
[][]byte{{}, decode(r, "b413f47d13ee2fe6c845b2ee141af81de858df4ec549a58b7970bb96645bc8d2")},
446477
tree.GetParkedNodes(nil))
447478

448479
r.NoError(tree.AddLeaf([]byte{2}))
@@ -452,7 +483,7 @@ func TestTree_GetParkedNodes(t *testing.T) {
452483

453484
r.NoError(tree.AddLeaf([]byte{3}))
454485
r.EqualValues(
455-
[][]byte{nil, nil, decode(r, "7699a4fdd6b8b6908a344f73b8f05c8e1400f7253f544602c442ff5c65504b24")},
486+
[][]byte{{}, {}, decode(r, "7699a4fdd6b8b6908a344f73b8f05c8e1400f7253f544602c442ff5c65504b24")},
456487
tree.GetParkedNodes(nil))
457488
}
458489

@@ -463,7 +494,7 @@ func TestTree_SetParkedNodes(t *testing.T) {
463494
r.NoError(err)
464495
r.NoError(tree.SetParkedNodes([][]byte{{0}}))
465496
r.NoError(tree.AddLeaf([]byte{1}))
466-
parkedNodes := [][]byte{nil, decode(r, "b413f47d13ee2fe6c845b2ee141af81de858df4ec549a58b7970bb96645bc8d2")}
497+
parkedNodes := [][]byte{{}, decode(r, "b413f47d13ee2fe6c845b2ee141af81de858df4ec549a58b7970bb96645bc8d2")}
467498
r.EqualValues(parkedNodes, tree.GetParkedNodes(nil))
468499

469500
tree, err = NewTreeBuilder().Build()
@@ -477,7 +508,7 @@ func TestTree_SetParkedNodes(t *testing.T) {
477508
r.NoError(err)
478509
r.NoError(tree.SetParkedNodes(parkedNodes))
479510
r.NoError(tree.AddLeaf([]byte{3}))
480-
parkedNodes = [][]byte{nil, nil, decode(r, "7699a4fdd6b8b6908a344f73b8f05c8e1400f7253f544602c442ff5c65504b24")}
511+
parkedNodes = [][]byte{{}, {}, decode(r, "7699a4fdd6b8b6908a344f73b8f05c8e1400f7253f544602c442ff5c65504b24")}
481512
r.EqualValues(parkedNodes, tree.GetParkedNodes(nil))
482513
}
483514

shared/types.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package shared
22

3-
type HashFunc func(lChild, rChild []byte) []byte
3+
type HashFunc func(buf, lChild, rChild []byte) []byte
44

55
// LayerReadWriter is a combined reader-writer. Note that the Seek() method only belongs to the LayerReader interface
66
// and does not affect the LayerWriter.

validation.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ func (v *Validator) CalcRoot(stopAtLayer uint) ([]byte, []ParkingSnapshot, error
104104
subTreeSnapshots = nil
105105
}
106106
}
107-
activeNode = v.Hash(lChild, rChild)
107+
activeNode = v.Hash(nil, lChild, rChild)
108108
activePos = activePos.parent()
109109
}
110110
return activeNode, parkingSnapshots, nil

0 commit comments

Comments
 (0)