Skip to content
This repository was archived by the owner on Jan 24, 2025. It is now read-only.

Commit 1be060b

Browse files
author
torch
committed
nd: better public interface
1 parent c09df2f commit 1be060b

File tree

3 files changed

+43
-55
lines changed

3 files changed

+43
-55
lines changed

nd/nd.go

+34-51
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,9 @@ func (nd *NetDog) Close() error {
5757
return nd.dirConn.Close()
5858
}
5959

60-
func (nd *NetDog) pick(needle *[32]byte) *directory.ShortNodeInfo {
60+
func Pick(routers []*directory.ShortNodeInfo, needle *[32]byte) *directory.ShortNodeInfo {
6161
var highestNotGreater *directory.ShortNodeInfo
62-
for _, ni := range nd.routers {
63-
if !ni.Fast || !ni.Stable || !ni.Running || !ni.Valid {
64-
continue
65-
}
62+
for _, ni := range routers {
6663
if bytes.Compare(ni.ID[:], needle[:]) <= 0 &&
6764
(highestNotGreater == nil || bytes.Compare(ni.ID[:], highestNotGreater.ID[:]) > 0) {
6865
highestNotGreater = ni
@@ -71,11 +68,8 @@ func (nd *NetDog) pick(needle *[32]byte) *directory.ShortNodeInfo {
7168
if highestNotGreater != nil {
7269
return highestNotGreater
7370
}
74-
highestOverall := nd.routers[0]
75-
for _, ni := range nd.routers {
76-
if !ni.Fast || !ni.Stable || !ni.Running {
77-
continue
78-
}
71+
highestOverall := routers[0]
72+
for _, ni := range routers {
7973
if bytes.Compare(ni.ID[:], highestOverall.ID[:]) > 0 {
8074
highestOverall = ni
8175
}
@@ -85,45 +79,25 @@ func (nd *NetDog) pick(needle *[32]byte) *directory.ShortNodeInfo {
8579

8680
// connect requires len(cookie) = 20 and len(sendPayload) == 148. Nondeterminisitcally, EITHER
8781
// the sendPayload is sent to the peer OR their sendPayload is returned here, not both.
88-
func connect(ctx context.Context, tc *torch.TorConn, cookie []byte, sendPayload []byte, tcID []byte, tcNtorPublic []byte) (*torch.TorConn, *torch.Circuit, []byte, error) {
89-
type circRet struct {
90-
*torch.Circuit
91-
error
92-
}
93-
ch := make(chan circRet)
94-
mkCirc := func() {
95-
circ, err := tc.CreateCircuit(ctx, tcID, tcNtorPublic)
96-
ch <- circRet{circ, err}
97-
}
98-
go mkCirc()
99-
go mkCirc()
100-
101-
cr := <-ch
102-
if cr.error != nil {
103-
<-ch
104-
return nil, nil, nil, cr.error
105-
}
106-
circ := cr.Circuit
107-
accept, err := circ.ListenRendezvousRaw(cookie)
108-
if err == nil {
82+
func connect(ctx context.Context, c1, c2 *torch.Circuit, cookie []byte, sendPayload []byte) (*torch.Circuit, []byte, error) {
83+
if accept, err := c1.ListenRendezvousRaw(cookie); err == nil {
84+
c2.Close()
85+
10986
recvPayload, err := accept()
11087
if err != nil {
111-
return nil, nil, nil, err
88+
c1.Close()
89+
return nil, nil, err
11290
}
113-
go func() { <-ch; close(ch) }()
114-
return tc, circ, recvPayload, nil
115-
}
116-
cr = <-ch
117-
close(ch)
118-
if cr.error != nil {
119-
<-ch
120-
return nil, nil, nil, cr.error
91+
92+
return c1, recvPayload, nil
12193
}
122-
circ = cr.Circuit
123-
if err := circ.DialRendezvousRaw(cookie, sendPayload); err != nil {
124-
return nil, nil, nil, err
94+
c1.Close()
95+
96+
if err := c2.DialRendezvousRaw(cookie, sendPayload); err != nil {
97+
c2.Close()
98+
return nil, nil, err
12599
}
126-
return tc, circ, nil, nil
100+
return c2, nil, nil
127101
}
128102

129103
func ND(ctx context.Context, needle *[32]byte, seed []byte) (*Conn, error) {
@@ -133,7 +107,7 @@ func ND(ctx context.Context, needle *[32]byte, seed []byte) (*Conn, error) {
133107
}
134108
defer nd.Close()
135109

136-
mid_, err := torch.DownloadMicrodescriptors(nd.dirClient, []*directory.ShortNodeInfo{nd.pick(needle)})
110+
mid_, err := torch.DownloadMicrodescriptors(nd.dirClient, []*directory.ShortNodeInfo{Pick(nd.routers, needle)})
137111
if err != nil {
138112
return nil, err
139113
}
@@ -147,11 +121,21 @@ func ND(ctx context.Context, needle *[32]byte, seed []byte) (*Conn, error) {
147121
return nil, err
148122
}
149123

150-
return Handshake(ctx, tc, seed, mid.ID[:], mid.NTorOnionKey[:])
124+
c1, err := tc.CreateCircuit(ctx, mid.ID[:], mid.NTorOnionKey)
125+
if err != nil {
126+
return nil, err
127+
}
128+
c2, err := tc.CreateCircuit(ctx, mid.ID[:], mid.NTorOnionKey)
129+
if err != nil {
130+
return nil, err
131+
}
132+
return Handshake(ctx, c1, c2, seed)
151133
}
152134

153-
func Handshake(ctx context.Context, tc *torch.TorConn, seed, routerID, routerNTorPublic []byte) (*Conn, error) {
154-
kdf := hkdf.New(sha256.New, seed, routerID, nil)
135+
func Handshake(ctx context.Context, c1, c2 *torch.Circuit, sharedSecret []byte) (*Conn, error) {
136+
// NOTE: one of the circuits is only needed after a couple of network round
137+
// trips, so it may be a good idea to pass in a <-chan *Circuit instead.
138+
kdf := hkdf.New(sha256.New, sharedSecret, nil, nil)
155139
var cookie [20]byte
156140
var authKeyAccept, authKeyDial, continueKey [32]byte
157141
if _, err := io.ReadFull(kdf, cookie[:]); err != nil {
@@ -173,7 +157,7 @@ func Handshake(ctx context.Context, tc *torch.TorConn, seed, routerID, routerNTo
173157
var theirPK [32]byte
174158
vouchDial := secretbox.Seal(nil, pk[:], &[24]byte{}, &authKeyDial)
175159

176-
tc, circ, theirVouchDial, err := connect(ctx, tc, cookie[:], vouchDial[:], routerID, routerNTorPublic)
160+
circ, theirVouchDial, err := connect(ctx, c1, c2, cookie[:], vouchDial[:])
177161
if err != nil {
178162
return nil, err
179163
}
@@ -204,9 +188,8 @@ func Handshake(ctx context.Context, tc *torch.TorConn, seed, routerID, routerNTo
204188
var sharedDH [32]byte
205189
curve25519.ScalarMult(&sharedDH, sk, &theirPK)
206190
ret := &Conn{
207-
TorConn: tc,
208191
Circuit: circ,
209-
KDF: hkdf.New(sha256.New, append(continueKey[:], sharedDH[:]...), routerID[:], nil),
192+
KDF: hkdf.New(sha256.New, append(continueKey[:], sharedDH[:]...), nil, nil),
210193
Bit: theirVouchDial != nil,
211194
}
212195
if ret.Bit {

nd/ndconn.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
)
1111

1212
type Conn struct {
13-
TorConn *torch.TorConn
1413
Circuit *torch.Circuit
1514
Bit bool
1615
KDF io.Reader
@@ -22,7 +21,7 @@ type Conn struct {
2221
}
2322

2423
func (ndc *Conn) Close() error {
25-
err := ndc.TorConn.Close()
24+
err := ndc.Circuit.Close()
2625

2726
ndc.writeMu.Lock()
2827
defer ndc.writeMu.Unlock()

torch.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,21 @@ func niAddr(ni *directory.NodeInfo) string {
187187
return fmt.Sprintf("%d.%d.%d.%d:%d", ni.IP[0], ni.IP[1], ni.IP[2], ni.IP[3], ni.Port)
188188
}
189189

190-
func (t *Torch) Pick(weighWith func(w *directory.BandwidthWeights, n *directory.NodeInfo) int64) *directory.NodeInfo {
190+
func (t *Torch) WithDirectory(f func(*directory.Directory) interface{}) interface{} {
191191
t.RLock()
192192
defer t.RUnlock()
193+
return f(t.cachedDir)
194+
}
193195

196+
func (t *Torch) Pick(weighWith func(w *directory.BandwidthWeights, n *directory.NodeInfo) int64) *directory.NodeInfo {
194197
weigh := func(n *directory.NodeInfo) int64 {
195198
return weighWith(&t.cachedDir.Consensus.BandwidthWeights, n)
196199
}
197200

198-
return directory.Pick(weigh, t.cachedDir.Routers, nil)
201+
return t.WithDirectory(func(d *directory.Directory) interface{} {
202+
return directory.Pick(weigh, d.Routers, nil)
203+
}).(*directory.NodeInfo)
204+
199205
}
200206

201207
func weighRelayWith(w *directory.BandwidthWeights, n *directory.NodeInfo) int64 {

0 commit comments

Comments
 (0)