Skip to content

Commit 307b3dc

Browse files
authored
Merge pull request #87 from renproject/fix/dht
Fix a potential panic in dht.
2 parents 2b85600 + 2c8b0df commit 307b3dc

File tree

2 files changed

+45
-13
lines changed

2 files changed

+45
-13
lines changed

dht/table.go

+6-12
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ type Table interface {
6767
type InMemTable struct {
6868
self id.Signatory
6969

70-
sortedMu *sync.Mutex
70+
sortedMu *sync.RWMutex
7171
sorted []id.Signatory
7272

7373
addrsBySignatoryMu *sync.Mutex
@@ -86,7 +86,7 @@ func NewInMemTable(self id.Signatory) *InMemTable {
8686
return &InMemTable{
8787
self: self,
8888

89-
sortedMu: new(sync.Mutex),
89+
sortedMu: new(sync.RWMutex),
9090
sorted: []id.Signatory{},
9191

9292
addrsBySignatoryMu: new(sync.Mutex),
@@ -166,8 +166,8 @@ func (table *InMemTable) PeerAddress(peerID id.Signatory) (wire.Address, bool) {
166166

167167
// Peers returns the n closest peer IDs.
168168
func (table *InMemTable) Peers(n int) []id.Signatory {
169-
table.sortedMu.Lock()
170-
defer table.sortedMu.Unlock()
169+
table.sortedMu.RLock()
170+
defer table.sortedMu.RUnlock()
171171

172172
if n <= 0 {
173173
// For values of n that are less than, or equal to, zero, return an
@@ -183,9 +183,9 @@ func (table *InMemTable) Peers(n int) []id.Signatory {
183183

184184
// RandomPeers returns n random peer IDs
185185
func (table *InMemTable) RandomPeers(n int) []id.Signatory {
186-
table.sortedMu.Lock()
186+
table.sortedMu.RLock()
187+
defer table.sortedMu.RUnlock()
187188
m := len(table.sorted)
188-
table.sortedMu.Unlock()
189189

190190
if n <= 0 {
191191
// For values of n that are less than, or equal to, zero, return an
@@ -195,8 +195,6 @@ func (table *InMemTable) RandomPeers(n int) []id.Signatory {
195195
}
196196
if n >= m {
197197
sigs := make([]id.Signatory, m)
198-
table.sortedMu.Lock()
199-
defer table.sortedMu.Unlock()
200198
copy(sigs, table.sorted)
201199
return sigs
202200
}
@@ -208,8 +206,6 @@ func (table *InMemTable) RandomPeers(n int) []id.Signatory {
208206
if m <= 10000 || n >= m/50.0 {
209207
shuffled := make([]id.Signatory, n)
210208
indexPerm := rand.Perm(m)
211-
table.sortedMu.Lock()
212-
defer table.sortedMu.Unlock()
213209
for i := 0; i < n; i++ {
214210
shuffled[i] = table.sorted[indexPerm[i]]
215211
}
@@ -219,8 +215,6 @@ func (table *InMemTable) RandomPeers(n int) []id.Signatory {
219215
// Otherwise, use Floyd's sampling algorithm to select n random elements
220216
set := make(map[int]struct{}, n)
221217
randomSelection := make([]id.Signatory, 0, n)
222-
table.sortedMu.Lock()
223-
defer table.sortedMu.Unlock()
224218
for i := m - n; i < m; i++ {
225219
index := table.randObj.Intn(i)
226220
if _, ok := set[index]; !ok {

dht/table_test.go

+39-1
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@ package dht_test
22

33
import (
44
"fmt"
5-
"github.com/renproject/aw/wire"
5+
"log"
66
"math/rand"
77
"strconv"
88
"testing/quick"
99
"time"
1010

1111
"github.com/renproject/aw/dht"
1212
"github.com/renproject/aw/dht/dhtutil"
13+
"github.com/renproject/aw/wire"
1314
"github.com/renproject/id"
1415

1516
. "github.com/onsi/ginkgo"
@@ -204,6 +205,43 @@ var _ = Describe("DHT", func() {
204205
}
205206
}
206207
})
208+
209+
It("should work while deleting peers from the table", func() {
210+
table, _ := initDHT()
211+
numAddrs := rand.Intn(100)
212+
numRandAddrs := rand.Intn(numAddrs)
213+
214+
// Insert `numAddrs` random addresses into the store.
215+
deletedPeers := make([]id.Signatory, 0, 50)
216+
for i := 0; i < numAddrs; i++ {
217+
privKey := id.NewPrivKey()
218+
sig := privKey.Signatory()
219+
addr := wire.NewUnsignedAddress(wire.TCP, "172.16.254.1:3000", uint64(time.Now().UnixNano()))
220+
table.AddPeer(sig, addr)
221+
if i < numAddrs/2 {
222+
deletedPeers = append(deletedPeers, sig)
223+
}
224+
}
225+
226+
done := make(chan struct{}, 1)
227+
go func() {
228+
defer close(done)
229+
230+
for i := range deletedPeers{
231+
table.DeletePeer(deletedPeers[i])
232+
}
233+
}()
234+
235+
total := time.Duration(0)
236+
for i := 0; i <50 ; i ++ {
237+
start := time.Now()
238+
table.RandomPeers(numRandAddrs)
239+
duration := time.Now().Sub(start)
240+
total += duration
241+
}
242+
log.Printf("RandomPeers takes %v on average", total/50)
243+
<- done
244+
})
207245
})
208246

209247
Context("when querying the number of addresses", func() {

0 commit comments

Comments
 (0)