Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
321 changes: 228 additions & 93 deletions app/dns/cache_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,37 @@ package dns
import (
"context"
go_errors "errors"
"runtime"
"sync"
"time"

"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/signal/pubsub"
"github.com/xtls/xray-core/common/task"
dns_feature "github.com/xtls/xray-core/features/dns"

"golang.org/x/net/dns/dnsmessage"
"sync"
"time"
"golang.org/x/sync/singleflight"
)

const (
minSizeForEmptyRebuild = 512
shrinkAbsoluteThreshold = 10240
shrinkRatioThreshold = 0.65
migrationBatchSize = 4096
)

type CacheController struct {
sync.RWMutex
ips map[string]*record
pub *pubsub.Service
cacheCleanup *task.Periodic
name string
disableCache bool
ips map[string]*record
dirtyips map[string]*record
pub *pubsub.Service
cacheCleanup *task.Periodic
name string
disableCache bool
highWatermark int
requestGroup singleflight.Group
}

func NewCacheController(name string, disableCache bool) *CacheController {
Expand All @@ -32,139 +45,261 @@ func NewCacheController(name string, disableCache bool) *CacheController {
}

c.cacheCleanup = &task.Periodic{
Interval: time.Minute,
Interval: 300 * time.Second,
Execute: c.CacheCleanup,
}
return c
}

// CacheCleanup clears expired items from cache
func (c *CacheController) CacheCleanup() error {
expiredKeys, err := c.collectExpiredKeys()
if err != nil {
return err
}
if len(expiredKeys) == 0 {
return nil
}
c.writeAndShrink(expiredKeys)
return nil
}

func (c *CacheController) collectExpiredKeys() ([]string, error) {
c.RLock()
defer c.RUnlock()

if len(c.ips) == 0 {
return nil, errors.New("nothing to do. stopping...")
}

// skip collection if a migration is in progress
if c.dirtyips != nil {
return nil, nil
}

now := time.Now()
expiredKeys := make([]string, 0, len(c.ips)/4) // pre-allocate

for domain, rec := range c.ips {
if (rec.A != nil && rec.A.Expire.Before(now)) ||
(rec.AAAA != nil && rec.AAAA.Expire.Before(now)) {
expiredKeys = append(expiredKeys, domain)
}
}

return expiredKeys, nil
}

func (c *CacheController) writeAndShrink(expiredKeys []string) {
c.Lock()
defer c.Unlock()

if len(c.ips) == 0 {
return errors.New("nothing to do. stopping...")
// double check to prevent upper call multiple cleanup tasks
if c.dirtyips != nil {
return
}

lenBefore := len(c.ips)
if lenBefore > c.highWatermark {
c.highWatermark = lenBefore
}

for domain, record := range c.ips {
if record.A != nil && record.A.Expire.Before(now) {
record.A = nil
now := time.Now()
for _, domain := range expiredKeys {
rec := c.ips[domain]
if rec == nil {
continue
}
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
record.AAAA = nil
if rec.A != nil && rec.A.Expire.Before(now) {
rec.A = nil
}

if record.A == nil && record.AAAA == nil {
errors.LogDebug(context.Background(), c.name, "cache cleanup ", domain)
if rec.AAAA != nil && rec.AAAA.Expire.Before(now) {
rec.AAAA = nil
}
if rec.A == nil && rec.AAAA == nil {
delete(c.ips, domain)
} else {
c.ips[domain] = record
}
}

if len(c.ips) == 0 {
c.ips = make(map[string]*record)
lenAfter := len(c.ips)

if lenAfter == 0 {
if c.highWatermark >= minSizeForEmptyRebuild {
errors.LogDebug(context.Background(), c.name,
" rebuilding empty cache map to reclaim memory.",
" size_before_cleanup=", lenBefore,
" peak_size_before_rebuild=", c.highWatermark,
)

c.ips = make(map[string]*record)
c.highWatermark = 0
}
return
}

if reductionFromPeak := c.highWatermark - lenAfter; reductionFromPeak > shrinkAbsoluteThreshold &&
float64(reductionFromPeak) > float64(c.highWatermark)*shrinkRatioThreshold {
errors.LogDebug(context.Background(), c.name,
" shrinking cache map to reclaim memory.",
" new_size=", lenAfter,
" peak_size_before_shrink=", c.highWatermark,
" reduction_since_peak=", reductionFromPeak,
)

c.dirtyips = c.ips
c.ips = make(map[string]*record, int(float64(lenAfter)*1.1))
c.highWatermark = lenAfter
go c.migrate()
}

return nil
}

func (c *CacheController) updateIP(req *dnsRequest, ipRec *IPRecord) {
elapsed := time.Since(req.start)
type migrationEntry struct {
key string
value *record
}

c.Lock()
rec, found := c.ips[req.domain]
if !found {
rec = &record{}
func (c *CacheController) migrate() {
defer func() {
if r := recover(); r != nil {
errors.LogError(context.Background(), c.name, " panic during cache migration: ", r)
c.Lock()
c.dirtyips = nil
// c.ips = make(map[string]*record)
// c.highWatermark = 0
c.Unlock()
}
}()

c.RLock()
dirtyips := c.dirtyips
c.RUnlock()

// double check to prevent upper call multiple cleanup tasks
if dirtyips == nil {
return
}

switch req.reqType {
case dnsmessage.TypeA:
rec.A = ipRec
case dnsmessage.TypeAAAA:
rec.AAAA = ipRec
errors.LogDebug(context.Background(), c.name, " starting background cache migration for ", len(dirtyips), " items.")

batch := make([]migrationEntry, 0, migrationBatchSize)
for domain, recD := range dirtyips {
batch = append(batch, migrationEntry{domain, recD})

if len(batch) >= migrationBatchSize {
c.flush(batch)
batch = batch[:0]
runtime.Gosched()
}
}
if len(batch) > 0 {
c.flush(batch)
}

errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed)
c.ips[req.domain] = rec
c.Lock()
c.dirtyips = nil
c.Unlock()

switch req.reqType {
case dnsmessage.TypeA:
c.pub.Publish(req.domain+"4", nil)
if !c.disableCache {
_, _, err := rec.AAAA.getIPs()
if !go_errors.Is(err, errRecordNotFound) {
c.pub.Publish(req.domain+"6", nil)
errors.LogDebug(context.Background(), c.name, " cache migration completed.")
}

func (c *CacheController) flush(batch []migrationEntry) {
c.Lock()
defer c.Unlock()

for _, dirty := range batch {
if cur := c.ips[dirty.key]; cur != nil {
merge := &record{}
if cur.A == nil {
merge.A = dirty.value.A
} else {
merge.A = cur.A
}
}
case dnsmessage.TypeAAAA:
c.pub.Publish(req.domain+"6", nil)
if !c.disableCache {
_, _, err := rec.A.getIPs()
if !go_errors.Is(err, errRecordNotFound) {
c.pub.Publish(req.domain+"4", nil)
if cur.AAAA == nil {
merge.AAAA = dirty.value.AAAA
} else {
merge.AAAA = cur.AAAA
}
c.ips[dirty.key] = merge
} else {
c.ips[dirty.key] = dirty.value
}
}

c.Unlock()
common.Must(c.cacheCleanup.Start())
}

func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
c.RLock()
record, found := c.ips[domain]
c.RUnlock()
func (c *CacheController) updateRecord(req *dnsRequest, rep *IPRecord) {
rtt := time.Since(req.start)

if !found {
return nil, 0, errRecordNotFound
switch req.reqType {
case dnsmessage.TypeA:
c.pub.Publish(req.domain+"4", rep)
case dnsmessage.TypeAAAA:
c.pub.Publish(req.domain+"6", rep)
}

var errs []error
var allIPs []net.IP
var rTTL uint32 = dns_feature.DefaultTTL
if c.disableCache {
errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", rep.IP, ", rtt: ", rtt)
return
}

mergeReq := option.IPv4Enable && option.IPv6Enable
c.Lock()
lockWait := time.Since(req.start) - rtt

if option.IPv4Enable {
ips, ttl, err := record.A.getIPs()
if !mergeReq || go_errors.Is(err, errRecordNotFound) {
return ips, ttl, err
}
if ttl < rTTL {
rTTL = ttl
}
if len(ips) > 0 {
allIPs = append(allIPs, ips...)
} else {
errs = append(errs, err)
}
newRec := &record{}
oldRec := c.ips[req.domain]
var dirtyRec *record
if c.dirtyips != nil {
dirtyRec = c.dirtyips[req.domain]
}

if option.IPv6Enable {
ips, ttl, err := record.AAAA.getIPs()
if !mergeReq || go_errors.Is(err, errRecordNotFound) {
return ips, ttl, err
}
if ttl < rTTL {
rTTL = ttl
var pubRecord *IPRecord
var pubSuffix string

switch req.reqType {
case dnsmessage.TypeA:
newRec.A = rep
if oldRec != nil && oldRec.AAAA != nil {
newRec.AAAA = oldRec.AAAA
pubRecord = oldRec.AAAA
} else if dirtyRec != nil && dirtyRec.AAAA != nil {
pubRecord = dirtyRec.AAAA
}
if len(ips) > 0 {
allIPs = append(allIPs, ips...)
} else {
errs = append(errs, err)
pubSuffix = "6"
case dnsmessage.TypeAAAA:
newRec.AAAA = rep
if oldRec != nil && oldRec.A != nil {
newRec.A = oldRec.A
pubRecord = oldRec.A
} else if dirtyRec != nil && dirtyRec.A != nil {
pubRecord = dirtyRec.A
}
pubSuffix = "4"
}

if len(allIPs) > 0 {
return allIPs, rTTL, nil
c.ips[req.domain] = newRec
c.Unlock()

if pubRecord != nil {
_, _ /*ttl*/, err := pubRecord.getIPs()
if /*ttl >= 0 &&*/ !go_errors.Is(err, errRecordNotFound) {
c.pub.Publish(req.domain+pubSuffix, pubRecord)
}
}
if go_errors.Is(errs[0], errs[1]) {
return nil, rTTL, errs[0]

errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", rep.IP, ", rtt: ", rtt, ", lock: ", lockWait)

common.Must(c.cacheCleanup.Start())
}

func (c *CacheController) findRecords(domain string) *record {
c.RLock()
defer c.RUnlock()

rec := c.ips[domain]
if rec == nil && c.dirtyips != nil {
rec = c.dirtyips[domain]
}
return nil, rTTL, errors.Combine(errs...)
return rec
}

func (c *CacheController) registerSubscribers(domain string, option dns_feature.IPOption) (sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) {
Expand Down
Loading