Skip to content

Commit 4739839

Browse files
authored
feat: expose DialContext in Dialer and propagate context (#165)
* feat: expose DialContext in Dialer and propagate context A dialer only exposed the Dial method, making it impossible to correctly propagate context along with timeouts and cancellations * fix: update windows specific file to comply with interface * build: fix windows compile errors * docs: add comment about contextDialer * feat: add ConnectContext to transport client * lint: fix spelling mistake
1 parent 0d9656d commit 4739839

File tree

10 files changed

+65
-34
lines changed

10 files changed

+65
-34
lines changed

transport/client.go

+8-3
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package transport
1919

2020
import (
21+
"context"
2122
"errors"
2223
"fmt"
2324
"net"
@@ -88,6 +89,10 @@ func NewClientWithDialer(d Dialer, c Config, network, host string, defaultPort i
8889
}
8990

9091
func (c *Client) Connect() error {
92+
return c.ConnectContext(context.Background())
93+
}
94+
95+
func (c *Client) ConnectContext(ctx context.Context) error {
9196
c.mutex.Lock()
9297
defer c.mutex.Unlock()
9398

@@ -96,7 +101,7 @@ func (c *Client) Connect() error {
96101
c.conn = nil
97102
}
98103

99-
conn, err := c.dialer.Dial(c.network, c.host)
104+
conn, err := c.dialer.DialContext(ctx, c.network, c.host)
100105
if err != nil {
101106
return err
102107
}
@@ -217,7 +222,7 @@ func (c *Client) Test(d testing.Driver) {
217222
d.Run("logstash: "+c.host, func(d testing.Driver) {
218223
d.Run("connection", func(d testing.Driver) {
219224
netDialer := TestNetDialer(d, c.config.Timeout)
220-
_, err := netDialer.Dial("tcp", c.host)
225+
_, err := netDialer.DialContext(context.Background(), "tcp", c.host)
221226
d.Fatal("dial up", err)
222227
})
223228

@@ -227,7 +232,7 @@ func (c *Client) Test(d testing.Driver) {
227232
d.Run("TLS", func(d testing.Driver) {
228233
netDialer := NetDialer(c.config.Timeout)
229234
tlsDialer := TestTLSDialer(d, netDialer, c.config.TLS, c.config.Timeout)
230-
_, err := tlsDialer.Dial("tcp", c.host)
235+
_, err := tlsDialer.DialContext(context.Background(), "tcp", c.host)
231236
d.Fatal("dial up", err)
232237
})
233238
}

transport/dialer/dialer_windows.go

+6-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package dialer
2121

2222
import (
23+
"context"
2324
"errors"
2425
"net"
2526
"strings"
@@ -60,10 +61,12 @@ func (t *NpipeDialerBuilder) String() string {
6061
func (t *NpipeDialerBuilder) Make(timeout time.Duration) (transport.Dialer, error) {
6162
to := timeout
6263
return transport.DialerFunc(
63-
func(_, _ string) (net.Conn, error) {
64-
return winio.DialPipe(
64+
func(ctx context.Context, _, _ string) (net.Conn, error) {
65+
ctx, cancel := context.WithTimeout(ctx, to)
66+
defer cancel()
67+
return winio.DialPipeContext(
68+
ctx,
6569
strings.TrimSuffix(npipe.TransformString(t.Path), "/"),
66-
&to,
6770
)
6871
},
6972
), nil

transport/httpcommon/httpcommon.go

+2-4
Original file line numberDiff line numberDiff line change
@@ -257,10 +257,8 @@ func (settings *HTTPTransportSettings) httpRoundTripper(
257257
opts ...TransportOption,
258258
) *http.Transport {
259259
t := http.DefaultTransport.(*http.Transport).Clone()
260-
t.DialContext = nil
261-
t.DialTLSContext = nil
262-
t.Dial = dialer.Dial //nolint:staticcheck // use deprecated function to preserve functionality
263-
t.DialTLS = tlsDialer.Dial //nolint:staticcheck // use deprecated function to preserve functionality
260+
t.DialContext = dialer.DialContext
261+
t.DialTLSContext = tlsDialer.DialContext
264262
t.TLSClientConfig = tls.ToConfig()
265263
t.ForceAttemptHTTP2 = false
266264
t.Proxy = settings.Proxy.ProxyFunc()

transport/logging.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package transport
1919

2020
import (
21+
"context"
2122
"errors"
2223
"io"
2324
"net"
@@ -31,9 +32,9 @@ type loggingConn struct {
3132
}
3233

3334
func LoggingDialer(d Dialer, logger *logp.Logger) Dialer {
34-
return DialerFunc(func(network, addr string) (net.Conn, error) {
35+
return DialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
3536
logger := logger.With("network", network, "address", addr)
36-
c, err := d.Dial(network, addr)
37+
c, err := d.DialContext(ctx, network, addr)
3738
if err != nil {
3839
logger.Errorf("Error dialing %v", err)
3940
return nil, err

transport/proxy.go

+14-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package transport
1919

2020
import (
21+
"context"
2122
"net"
2223
"net/url"
2324

@@ -68,7 +69,7 @@ func ProxyDialer(log *logp.Logger, config *ProxyConfig, forward Dialer) (Dialer,
6869
}
6970

7071
log.Infof("proxy host: '%s'", url.Host)
71-
return DialerFunc(func(network, address string) (net.Conn, error) {
72+
return DialerFunc(func(ctx context.Context, network, address string) (net.Conn, error) {
7273
var err error
7374
var addresses []string
7475

@@ -94,6 +95,17 @@ func ProxyDialer(log *logp.Logger, config *ProxyConfig, forward Dialer) (Dialer,
9495
if err != nil {
9596
return nil, err
9697
}
97-
return DialWith(dialer, network, host, addresses, port)
98+
99+
contextDialer, ok := dialer.(Dialer)
100+
// This will never be executed because the proxy package always returns
101+
// a ContextDialer but they didn't break the interface for backward compatibility.
102+
// See golang/go#58376
103+
if !ok {
104+
contextDialer = DialerFunc(func(ctx context.Context, network, address string) (net.Conn, error) {
105+
return dialer.Dial(network, address)
106+
})
107+
}
108+
109+
return DialWith(ctx, contextDialer, network, host, addresses, port)
98110
}), nil
99111
}

transport/tcp.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package transport
1919

2020
import (
21+
"context"
2122
"fmt"
2223
"net"
2324
"strings"
@@ -32,7 +33,7 @@ func NetDialer(timeout time.Duration) Dialer {
3233
}
3334

3435
func TestNetDialer(d testing.Driver, timeout time.Duration) Dialer {
35-
return DialerFunc(func(network, address string) (net.Conn, error) {
36+
return DialerFunc(func(ctx context.Context, network, address string) (net.Conn, error) {
3637
switch network {
3738
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
3839
default:
@@ -55,7 +56,7 @@ func TestNetDialer(d testing.Driver, timeout time.Duration) Dialer {
5556

5657
// dial via host IP by randomized iteration of known IPs
5758
dialer := &net.Dialer{Timeout: timeout}
58-
return DialWith(dialer, network, host, addresses, port)
59+
return DialWith(ctx, dialer, network, host, addresses, port)
5960
})
6061
}
6162

@@ -66,7 +67,7 @@ func UnixDialer(timeout time.Duration, sockFile string) Dialer {
6667

6768
// TestUnixDialer creates a Test Unix Dialer when using domain socket.
6869
func TestUnixDialer(d testing.Driver, timeout time.Duration, sockFile string) Dialer {
69-
return DialerFunc(func(network, address string) (net.Conn, error) {
70+
return DialerFunc(func(ctx context.Context, network, address string) (net.Conn, error) {
7071
d.Info("connecting using unix domain socket", sockFile)
7172
return net.DialTimeout("unix", sockFile, timeout)
7273
})

transport/tls.go

+11-9
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package transport
1919

2020
import (
21+
"context"
2122
"crypto/tls"
2223
"errors"
2324
"fmt"
@@ -44,7 +45,7 @@ func TestTLSDialer(
4445
var lastAddress string
4546
var m sync.Mutex
4647

47-
return DialerFunc(func(network, address string) (net.Conn, error) {
48+
return DialerFunc(func(ctx context.Context, network, address string) (net.Conn, error) {
4849
switch network {
4950
case "tcp", "tcp4", "tcp6":
5051
default:
@@ -69,18 +70,18 @@ func TestTLSDialer(
6970
}
7071
m.Unlock()
7172

72-
return tlsDialWith(d, forward, network, address, timeout, tlsConfig, config)
73+
return tlsDialWith(ctx, d, forward, network, address, timeout, tlsConfig, config)
7374
})
7475
}
7576

7677
type DialerH2 interface {
77-
Dial(network, address string, cfg *tls.Config) (net.Conn, error)
78+
DialContext(ctx context.Context, network, address string, cfg *tls.Config) (net.Conn, error)
7879
}
7980

80-
type DialerFuncH2 func(network, address string, cfg *tls.Config) (net.Conn, error)
81+
type DialerFuncH2 func(ctx context.Context, network, address string, cfg *tls.Config) (net.Conn, error)
8182

82-
func (d DialerFuncH2) Dial(network, address string, cfg *tls.Config) (net.Conn, error) {
83-
return d(network, address, cfg)
83+
func (d DialerFuncH2) DialContext(ctx context.Context, network, address string, cfg *tls.Config) (net.Conn, error) {
84+
return d(ctx, network, address, cfg)
8485
}
8586

8687
func TLSDialerH2(forward Dialer, config *tlscommon.TLSConfig, timeout time.Duration) (DialerH2, error) {
@@ -98,7 +99,7 @@ func TestTLSDialerH2(
9899
var lastAddress string
99100
var m sync.Mutex
100101

101-
return DialerFuncH2(func(network, address string, cfg *tls.Config) (net.Conn, error) {
102+
return DialerFuncH2(func(ctx context.Context, network, address string, cfg *tls.Config) (net.Conn, error) {
102103
switch network {
103104
case "tcp", "tcp4", "tcp6":
104105
default:
@@ -126,19 +127,20 @@ func TestTLSDialerH2(
126127
// NextProtos must be set from the passed h2 connection or it will fail
127128
tlsConfig.NextProtos = cfg.NextProtos
128129

129-
return tlsDialWith(d, forward, network, address, timeout, tlsConfig, config)
130+
return tlsDialWith(ctx, d, forward, network, address, timeout, tlsConfig, config)
130131
}), nil
131132
}
132133

133134
func tlsDialWith(
135+
ctx context.Context,
134136
d testing.Driver,
135137
dialer Dialer,
136138
network, address string,
137139
timeout time.Duration,
138140
tlsConfig *tls.Config,
139141
config *tlscommon.TLSConfig,
140142
) (net.Conn, error) {
141-
socket, err := dialer.Dial(network, address)
143+
socket, err := dialer.DialContext(ctx, network, address)
142144
if err != nil {
143145
return nil, err
144146
}

transport/transport.go

+10-4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package transport
1919

2020
import (
21+
"context"
2122
"errors"
2223
"net"
2324

@@ -26,24 +27,29 @@ import (
2627

2728
type Dialer interface {
2829
Dial(network, address string) (net.Conn, error)
30+
DialContext(ctx context.Context, network, address string) (net.Conn, error)
2931
}
3032

31-
type DialerFunc func(network, address string) (net.Conn, error)
33+
type DialerFunc func(ctx context.Context, network, address string) (net.Conn, error)
3234

3335
var (
3436
ErrNotConnected = errors.New("client is not connected")
3537
)
3638

3739
func (d DialerFunc) Dial(network, address string) (net.Conn, error) {
38-
return d(network, address)
40+
return d(context.Background(), network, address)
3941
}
4042

41-
func Dial(c Config, network, address string) (net.Conn, error) {
43+
func (d DialerFunc) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
44+
return d(ctx, network, address)
45+
}
46+
47+
func DialContext(ctx context.Context, c Config, network, address string) (net.Conn, error) {
4248
d, err := MakeDialer(c)
4349
if err != nil {
4450
return nil, err
4551
}
46-
return d.Dial(network, address)
52+
return d.DialContext(ctx, network, address)
4753
}
4854

4955
func MakeDialer(c Config) (Dialer, error) {

transport/util.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package transport
1919

2020
import (
21+
"context"
2122
"fmt"
2223
"math/rand"
2324
"net"
@@ -43,6 +44,7 @@ func fullAddress(host string, defaultPort int) string {
4344
//
4445
// Use this to select and dial one IP being known for one host name.
4546
func DialWith(
47+
ctx context.Context,
4648
dialer Dialer,
4749
network, host string,
4850
addresses []string,
@@ -52,7 +54,7 @@ func DialWith(
5254
case 0:
5355
return nil, fmt.Errorf("no route to host %v", host)
5456
case 1:
55-
return dialer.Dial(network, net.JoinHostPort(addresses[0], port))
57+
return dialer.DialContext(ctx, network, net.JoinHostPort(addresses[0], port))
5658
}
5759

5860
// Use randomization on DNS reported addresses combined with timeout and ACKs
@@ -69,7 +71,7 @@ func DialWith(
6971
// > "Clients, of course, may reorder this information" - with respect to
7072
// > handling order of dns records in a response.forwarded. Really required?
7173
for _, i := range rand.Perm(len(addresses)) {
72-
c, err = dialer.Dial(network, net.JoinHostPort(addresses[i], port))
74+
c, err = dialer.DialContext(ctx, network, net.JoinHostPort(addresses[i], port))
7375
if err == nil && c != nil {
7476
return c, err
7577
}

transport/wrap.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
package transport
1919

2020
import (
21+
"context"
2122
"net"
2223
)
2324

2425
func ConnWrapper(d Dialer, w func(net.Conn) net.Conn) Dialer {
25-
return DialerFunc(func(network, addr string) (net.Conn, error) {
26-
c, err := d.Dial(network, addr)
26+
return DialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
27+
c, err := d.DialContext(ctx, network, addr)
2728
if err != nil {
2829
return nil, err
2930
}

0 commit comments

Comments
 (0)