@@ -2,31 +2,109 @@ package exchanger
22
33import (
44 "context"
5+ "errors"
56 "fmt"
7+ "hash/maphash"
8+ "io"
9+ "math/rand/v2"
10+ "net"
611 "strings"
12+ "syscall"
713 "time"
814
915 "github.com/miekg/dns"
16+ "github.com/qdm12/dns/v2/internal/pool"
1017)
1118
1219type Exchanger struct {
13- client * dns.Client
14- dialer Dialer
15- warner Warner
20+ client * dns.Client
21+ dialer Dialer
22+ warner Warner
23+ reuseConns bool
24+ pool * pool.Pool
25+ rand * rand.Rand
26+ addresses []string
1627}
1728
18- func New (dialer Dialer , warner Warner ) * Exchanger {
29+ func New (dialer Dialer , poolMetrics PoolMetrics , warner Warner ) * Exchanger {
30+ reuseConns := dialer .ReusableConnsSupported ()
31+ addresses := dialer .Addresses ()
32+ if len (addresses ) == 0 {
33+ panic ("dialer " + dialer .String () + " has no addresses" )
34+ }
1935 return & Exchanger {
20- client : & dns.Client {},
21- dialer : dialer ,
22- warner : warner ,
36+ client : & dns.Client {},
37+ dialer : dialer ,
38+ warner : warner ,
39+ reuseConns : reuseConns ,
40+ pool : pool .New (dialer , poolMetrics ),
41+ rand : rand .New (newMaphashSource ()), //nolint:gosec
42+ addresses : addresses ,
2343 }
2444}
2545
46+ var ErrDialFailed = errors .New ("dial failed" )
47+
2648func (e * Exchanger ) Exchange (ctx context.Context , network string , request * dns.Msg ) (
2749 response * dns.Msg , err error ,
2850) {
29- netConn , err := e .dialer .Dial (ctx , network , "" )
51+ if e .reuseConns {
52+ return e .exchangeWithPool (ctx , network , request ) // dot, doh
53+ }
54+ return e .exchangeWithRand (ctx , network , request ) // plain
55+ }
56+
57+ func (e * Exchanger ) exchangeWithPool (ctx context.Context , network string , request * dns.Msg ) (
58+ response * dns.Msg , err error ,
59+ ) {
60+ netConn , err := e .pool .Get (ctx , network )
61+ if err != nil {
62+ return nil , fmt .Errorf ("getting %s connection for request %s: %w" ,
63+ e .dialer , extractRequestQuestion (request ), err )
64+ }
65+
66+ defer func () {
67+ if err != nil {
68+ e .pool .PutDead (netConn )
69+ } else {
70+ e .pool .Put (netConn )
71+ }
72+ }()
73+
74+ dnsConn := & dns.Conn {Conn : netConn }
75+ response , roundTripDuration , err := e .client .ExchangeWithConnContext (ctx , request , dnsConn )
76+ if err == nil {
77+ return response , nil
78+ }
79+ if ! isClosedConnErr (err ) {
80+ roundTripMilliseconds := roundTripDuration .Round (time .Millisecond ).Milliseconds ()
81+ return nil , fmt .Errorf ("exchanging over %s connection (%dms) for request %s: %w" ,
82+ e .dialer , roundTripMilliseconds , extractRequestQuestion (request ), err )
83+ }
84+
85+ // Connection is closed, try to renew it
86+ _ = dnsConn .Close ()
87+ netConn , err = e .pool .Renew (ctx , network , netConn )
88+ if err != nil {
89+ return nil , fmt .Errorf ("renewing %s connection for request %s: %w" ,
90+ e .dialer , extractRequestQuestion (request ), err )
91+ }
92+ dnsConn = & dns.Conn {Conn : netConn }
93+ response , roundTripDuration , err = e .client .ExchangeWithConnContext (ctx , request , dnsConn )
94+ if err != nil {
95+ roundTripMilliseconds := roundTripDuration .Round (time .Millisecond ).Milliseconds ()
96+ return nil , fmt .Errorf ("exchanging over %s connection (%dms) for request %s: %w" ,
97+ e .dialer , roundTripMilliseconds , extractRequestQuestion (request ), err )
98+ }
99+
100+ return response , nil
101+ }
102+
103+ func (e * Exchanger ) exchangeWithRand (ctx context.Context , network string , request * dns.Msg ) (
104+ response * dns.Msg , err error ,
105+ ) {
106+ addrOrURL := e .addresses [e .rand .IntN (len (e .addresses ))]
107+ netConn , err := e .dialer .Dial (ctx , network , addrOrURL )
30108 if err != nil {
31109 return nil , fmt .Errorf ("dialing %s server for request %s: %w" ,
32110 e .dialer , extractRequestQuestion (request ), err )
@@ -58,3 +136,20 @@ func extractRequestQuestion(request *dns.Msg) (s string) {
58136 dns .TypeToString [question .Qtype ] + " " +
59137 strings .ToLower (question .Name )
60138}
139+
140+ func isClosedConnErr (err error ) bool {
141+ return errors .Is (err , net .ErrClosed ) ||
142+ errors .Is (err , io .EOF ) ||
143+ errors .Is (err , syscall .EPIPE ) ||
144+ errors .Is (err , syscall .ECONNRESET )
145+ }
146+
147+ func newMaphashSource () * mapHashSource {
148+ return & mapHashSource {}
149+ }
150+
151+ type mapHashSource struct {}
152+
153+ func (s * mapHashSource ) Uint64 () uint64 {
154+ return new (maphash.Hash ).Sum64 ()
155+ }
0 commit comments