Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,14 @@ jobs:
touch coverage.txt
docker run --rm \
-v "$(pwd)/coverage.txt:/tmp/gobuild/coverage.txt" \
test-container
test-container \
go test -race -coverprofile=coverage.txt ./...

- name: Run fuzz tests in test container
run: |
docker run --rm \
test-container \
go test -fuzz=Fuzz_Pool_compact -fuzztime 10s ./internal/pool

- name: Build final image
run: docker build -t final-image .
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/configs/mlc-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"fallbackRetryDelay": "30s",
"aliveStatusCodes": [
200,
403,
429
]
}
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
.vscode
.vscode
testdata/
1 change: 0 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ FROM --platform=${BUILDPLATFORM} base AS test
# - we set CGO_ENABLED=1 to have it enabled
# - we installed g++ to support the race detector
ENV CGO_ENABLED=1
ENTRYPOINT go test -race -coverprofile=coverage.txt ./...

FROM --platform=${BUILDPLATFORM} base AS lint
COPY .golangci.yml ./
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,12 @@ For example, the environment variable `UPSTREAM_TYPE` corresponds to the CLI fla
| `BLOCK_MALICIOUS` | `on` | `on` or `off`, to block malicious IP addresses and malicious hostnames from being resolved |
| `BLOCK_SURVEILLANCE` | `off` | `on` or `off`, to block surveillance IP addresses and hostnames from being resolved |
| `BLOCK_ADS` | `off` | `on` or `off`, to block ads IP addresses and hostnames from being resolved |
| `BLOCK_HOSTNAMES` | | comma separated list of hostnames to block from being resolved |
| `BLOCK_HOSTNAMES` | | comma separated list of hostnames to block from being resolved |
| `ALLOWED_HOSTNAMES` | | comma separated list of hostnames to leave unblocked |
| `ALLOWED_IPS` | | comma separated list of IP addresses to leave unblocked |
| `ALLOWED_CIDRS` | | comma separated list of IP networks (CIDRs) to leave unblocked |
| `BLOCK_IPS` | | comma separated list of IPs to block from being returned to clients |
| `BLOCK_CIDRS` | | comma separated list of IP networks (CIDRs) to block from being returned to clients |
| `BLOCK_IPS` | | comma separated list of IPs to block from being returned to clients |
| `BLOCK_CIDRS` | | comma separated list of IP networks (CIDRs) to block from being returned to clients |
| `REBINDING_PROTECTION_EXEMPT_HOSTNAMES` | | comma separated list of hostnames to exempt from DNS rebinding protection |
| `LOG_LEVEL` | `info` | `debug`, `info`, `warning` or `error` |
| `LOG_CALLER` | `hidden` | `hidden` or `short` |
Expand Down
108 changes: 100 additions & 8 deletions internal/exchanger/exchanger.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,106 @@ package exchanger

import (
"context"
"errors"
"fmt"
"hash/maphash"
"io"
"math/rand/v2"
"net"
"strings"
"syscall"
"time"

"github.com/miekg/dns"
"github.com/qdm12/dns/v2/internal/pool"
)

type Exchanger struct {
client *dns.Client
dialer Dialer
warner Warner
client *dns.Client
dialer Dialer
warner Warner
reuseConns bool
pool *pool.Pool
rand *rand.Rand
addresses []string
}

func New(dialer Dialer, warner Warner) *Exchanger {
func New(dialer Dialer, poolMetrics PoolMetrics, warner Warner) *Exchanger {
reuseConns := dialer.ReusableConnsSupported()
addresses := dialer.Addresses()
if len(addresses) == 0 {
panic("dialer " + dialer.String() + " has no addresses")
}
return &Exchanger{
client: &dns.Client{},
dialer: dialer,
warner: warner,
client: &dns.Client{},
dialer: dialer,
warner: warner,
reuseConns: reuseConns,
pool: pool.New(dialer, poolMetrics),
rand: rand.New(newMaphashSource()), //nolint:gosec
addresses: addresses,
}
}

var ErrDialFailed = errors.New("dial failed")

func (e *Exchanger) Exchange(ctx context.Context, network string, request *dns.Msg) (
response *dns.Msg, err error,
) {
netConn, err := e.dialer.Dial(ctx, network, "")
if e.reuseConns {
return e.exchangeWithPool(ctx, network, request) // dot, doh
}
return e.exchangeWithRand(ctx, network, request) // plain
}

func (e *Exchanger) exchangeWithPool(ctx context.Context, network string, request *dns.Msg) (
response *dns.Msg, err error,
) {
netConn, err := e.pool.Get(ctx, network)
if err != nil {
return nil, fmt.Errorf("getting %s connection for request %s: %w",
e.dialer, extractRequestQuestion(request), err)
}

dnsConn := &dns.Conn{Conn: netConn}
response, roundTripDuration, err := e.client.ExchangeWithConnContext(ctx, request, dnsConn)
if err == nil {
e.pool.Put(netConn)
return response, nil
}

if !isClosedConnErr(err) {
e.pool.PutDead(netConn)
roundTripMilliseconds := roundTripDuration.Round(time.Millisecond).Milliseconds()
return nil, fmt.Errorf("exchanging over %s connection (%dms) for request %s: %w",
e.dialer, roundTripMilliseconds, extractRequestQuestion(request), err)
}

// Connection is closed, try to renew it
_ = dnsConn.Close()
netConn, err = e.pool.Renew(ctx, network, netConn)
if err != nil {
return nil, fmt.Errorf("renewing %s connection for request %s: %w",
e.dialer, extractRequestQuestion(request), err)
}
dnsConn = &dns.Conn{Conn: netConn}
response, roundTripDuration, err = e.client.ExchangeWithConnContext(ctx, request, dnsConn)
if err != nil {
e.pool.PutDead(netConn)
roundTripMilliseconds := roundTripDuration.Round(time.Millisecond).Milliseconds()
return nil, fmt.Errorf("exchanging over %s connection (%dms) for request %s: %w",
e.dialer, roundTripMilliseconds, extractRequestQuestion(request), err)
}

e.pool.Put(netConn)
return response, nil
}

func (e *Exchanger) exchangeWithRand(ctx context.Context, network string, request *dns.Msg) (
response *dns.Msg, err error,
) {
addrOrURL := e.addresses[e.rand.IntN(len(e.addresses))]
netConn, err := e.dialer.Dial(ctx, network, addrOrURL)
if err != nil {
return nil, fmt.Errorf("dialing %s server for request %s: %w",
e.dialer, extractRequestQuestion(request), err)
Expand Down Expand Up @@ -58,3 +133,20 @@ func extractRequestQuestion(request *dns.Msg) (s string) {
dns.TypeToString[question.Qtype] + " " +
strings.ToLower(question.Name)
}

func isClosedConnErr(err error) bool {
return errors.Is(err, net.ErrClosed) ||
errors.Is(err, io.EOF) ||
errors.Is(err, syscall.EPIPE) ||
errors.Is(err, syscall.ECONNRESET)
}

func newMaphashSource() *mapHashSource {
return &mapHashSource{}
}

type mapHashSource struct{}

func (s *mapHashSource) Uint64() uint64 {
return new(maphash.Hash).Sum64()
}
Loading
Loading