Skip to content

Commit cc43c3b

Browse files
committed
Session Clean ups.
1 parent de84665 commit cc43c3b

5 files changed

Lines changed: 250 additions & 13 deletions

File tree

internal/config/server.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ type ServerConfig struct {
3030
DNSRequestWorkers int `toml:"DNS_REQUEST_WORKERS"`
3131
MaxPacketSize int `toml:"MAX_PACKET_SIZE"`
3232
DropLogIntervalSecs float64 `toml:"DROP_LOG_INTERVAL_SECONDS"`
33+
SessionTimeoutSecs float64 `toml:"SESSION_TIMEOUT_SECONDS"`
34+
SessionCleanupIntervalSecs float64 `toml:"SESSION_CLEANUP_INTERVAL_SECONDS"`
35+
ClosedSessionRetentionSecs float64 `toml:"CLOSED_SESSION_RETENTION_SECONDS"`
3336
Domain []string `toml:"DOMAIN"`
3437
MinVPNLabelLength int `toml:"MIN_VPN_LABEL_LENGTH"`
3538
SupportedUploadCompressionTypes []int `toml:"SUPPORTED_UPLOAD_COMPRESSION_TYPES"`
@@ -53,6 +56,9 @@ func defaultServerConfig() ServerConfig {
5356
DNSRequestWorkers: workers,
5457
MaxPacketSize: 65535,
5558
DropLogIntervalSecs: 2.0,
59+
SessionTimeoutSecs: 300.0,
60+
SessionCleanupIntervalSecs: 30.0,
61+
ClosedSessionRetentionSecs: 600.0,
5662
Domain: nil,
5763
MinVPNLabelLength: 3,
5864
SupportedUploadCompressionTypes: []int{0, 3},
@@ -112,6 +118,15 @@ func LoadServerConfig(filename string) (ServerConfig, error) {
112118
if cfg.DropLogIntervalSecs <= 0 {
113119
cfg.DropLogIntervalSecs = 2.0
114120
}
121+
if cfg.SessionTimeoutSecs <= 0 {
122+
cfg.SessionTimeoutSecs = 300.0
123+
}
124+
if cfg.SessionCleanupIntervalSecs <= 0 {
125+
cfg.SessionCleanupIntervalSecs = 30.0
126+
}
127+
if cfg.ClosedSessionRetentionSecs <= 0 {
128+
cfg.ClosedSessionRetentionSecs = 600.0
129+
}
115130

116131
if cfg.MinVPNLabelLength <= 0 {
117132
cfg.MinVPNLabelLength = 3
@@ -142,6 +157,18 @@ func (c ServerConfig) DropLogInterval() time.Duration {
142157
return time.Duration(c.DropLogIntervalSecs * float64(time.Second))
143158
}
144159

160+
func (c ServerConfig) SessionTimeout() time.Duration {
161+
return time.Duration(c.SessionTimeoutSecs * float64(time.Second))
162+
}
163+
164+
func (c ServerConfig) SessionCleanupInterval() time.Duration {
165+
return time.Duration(c.SessionCleanupIntervalSecs * float64(time.Second))
166+
}
167+
168+
func (c ServerConfig) ClosedSessionRetention() time.Duration {
169+
return time.Duration(c.ClosedSessionRetentionSecs * float64(time.Second))
170+
}
171+
145172
func (c ServerConfig) EncryptionKeyPath() string {
146173
if c.EncryptionKeyFile == "" {
147174
return filepath.Join(c.ConfigDir, "encrypt_key.txt")

internal/udpserver/server.go

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"net"
1616
"sync"
1717
"sync/atomic"
18+
"time"
1819

1920
"masterdnsvpn-go/internal/compression"
2021
"masterdnsvpn-go/internal/config"
@@ -71,6 +72,9 @@ func New(cfg config.ServerConfig, log *logger.Logger, codec *security.Codec) *Se
7172
}
7273

7374
func (s *Server) Run(ctx context.Context) error {
75+
runCtx, cancel := context.WithCancel(ctx)
76+
defer cancel()
77+
7478
conn, err := net.ListenUDP("udp", &net.UDPAddr{
7579
IP: net.ParseIP(s.cfg.UDPHost),
7680
Port: s.cfg.UDPPort,
@@ -97,17 +101,23 @@ func (s *Server) Run(ctx context.Context) error {
97101

98102
reqCh := make(chan request, s.cfg.MaxConcurrentRequests)
99103
var workerWG sync.WaitGroup
104+
cleanupDone := make(chan struct{})
105+
106+
go func() {
107+
defer close(cleanupDone)
108+
s.sessionCleanupLoop(runCtx)
109+
}()
100110

101111
for i := range s.cfg.DNSRequestWorkers {
102112
workerWG.Add(1)
103113
go func(workerID int) {
104114
defer workerWG.Done()
105-
s.worker(ctx, conn, reqCh, workerID)
115+
s.worker(runCtx, conn, reqCh, workerID)
106116
}(i + 1)
107117
}
108118

109119
go func() {
110-
<-ctx.Done()
120+
<-runCtx.Done()
111121
_ = conn.Close()
112122
}()
113123

@@ -117,7 +127,7 @@ func (s *Server) Run(ctx context.Context) error {
117127
readerWG.Add(1)
118128
go func(readerID int) {
119129
defer readerWG.Done()
120-
if err := s.readLoop(ctx, conn, reqCh, readerID); err != nil {
130+
if err := s.readLoop(runCtx, conn, reqCh, readerID); err != nil {
121131
select {
122132
case readErrCh <- err:
123133
default:
@@ -129,6 +139,8 @@ func (s *Server) Run(ctx context.Context) error {
129139
readerWG.Wait()
130140
close(reqCh)
131141
workerWG.Wait()
142+
cancel()
143+
<-cleanupDone
132144

133145
if ctx.Err() != nil {
134146
return ctx.Err()
@@ -142,6 +154,32 @@ func (s *Server) Run(ctx context.Context) error {
142154
}
143155
}
144156

157+
func (s *Server) sessionCleanupLoop(ctx context.Context) {
158+
interval := s.cfg.SessionCleanupInterval()
159+
if interval <= 0 {
160+
interval = 30 * time.Second
161+
}
162+
163+
ticker := time.NewTicker(interval)
164+
defer ticker.Stop()
165+
166+
for {
167+
select {
168+
case <-ctx.Done():
169+
return
170+
case now := <-ticker.C:
171+
expired := s.sessions.Cleanup(now, s.cfg.SessionTimeout(), s.cfg.ClosedSessionRetention())
172+
if len(expired) == 0 {
173+
continue
174+
}
175+
s.log.Infof(
176+
"🧹 <green>Expired Sessions Cleaned</green> <magenta>|</magenta> <blue>Count</blue>: <cyan>%d</cyan>",
177+
len(expired),
178+
)
179+
}
180+
}
181+
}
182+
145183
func (s *Server) readLoop(ctx context.Context, conn *net.UDPConn, reqCh chan<- request, readerID int) error {
146184
for {
147185
buffer := s.packetPool.Get().([]byte)

internal/udpserver/server_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,59 @@ func TestSessionStoreExpiresReuseSignatureWithoutDroppingSession(t *testing.T) {
312312
store.mu.Unlock()
313313
}
314314

315+
func TestSessionStoreCleanupMovesExpiredSessionToRecentClosed(t *testing.T) {
316+
store := newSessionStore()
317+
payload := []byte{1, 0x21, 0x00, 0x96, 0x00, 0xC8, 0x44, 0x33, 0x22, 0x11}
318+
319+
record, reused, err := store.findOrCreate(payload, 3, 0)
320+
if err != nil {
321+
t.Fatalf("findOrCreate returned error: %v", err)
322+
}
323+
if reused || record == nil {
324+
t.Fatal("expected a new session record")
325+
}
326+
327+
store.mu.Lock()
328+
record.LastActivityAt = time.Now().Add(-2 * time.Minute)
329+
expectedCookie := record.Cookie
330+
store.mu.Unlock()
331+
332+
expired := store.Cleanup(time.Now(), time.Minute, 10*time.Minute)
333+
if len(expired) != 1 || expired[0] != record.ID {
334+
t.Fatalf("unexpected expired sessions: %#v", expired)
335+
}
336+
if _, ok := store.Active(record.ID); ok {
337+
t.Fatal("expired session should no longer be active")
338+
}
339+
if cookie, ok := store.ExpectedCookie(record.ID); !ok || cookie != expectedCookie {
340+
t.Fatalf("recently closed cookie missing: ok=%v cookie=%d expected=%d", ok, cookie, expectedCookie)
341+
}
342+
}
343+
344+
func TestSessionStoreTouchRefreshesActivity(t *testing.T) {
345+
store := newSessionStore()
346+
payload := []byte{1, 0x21, 0x00, 0x96, 0x00, 0xC8, 0x44, 0x33, 0x22, 0x11}
347+
348+
record, _, err := store.findOrCreate(payload, 0, 0)
349+
if err != nil {
350+
t.Fatalf("findOrCreate returned error: %v", err)
351+
}
352+
353+
old := record.LastActivityAt
354+
time.Sleep(5 * time.Millisecond)
355+
if !store.Touch(record.ID, time.Now()) {
356+
t.Fatal("Touch returned false")
357+
}
358+
359+
active, ok := store.Active(record.ID)
360+
if !ok {
361+
t.Fatal("Active returned false")
362+
}
363+
if !active.LastActivityAt.After(old) {
364+
t.Fatal("last activity timestamp was not updated")
365+
}
366+
}
367+
315368
func buildServerTestQuery(id uint16, name string, qtype uint16) []byte {
316369
qname := encodeServerTestName(name)
317370
packet := make([]byte, 12+len(qname)+4)

internal/udpserver/session.go

Lines changed: 124 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,27 @@ type sessionRecord struct {
3838
VerifyCode [4]byte
3939
Signature [sessionInitDataSize]byte
4040
CreatedAt time.Time
41+
LastActivityAt time.Time
4142
ReuseUntil time.Time
4243
}
4344

45+
type closedSessionRecord struct {
46+
Cookie uint8
47+
ExpiresAt time.Time
48+
}
49+
4450
type sessionStore struct {
45-
mu sync.Mutex
46-
nextID uint16
47-
byID [maxServerSessions]*sessionRecord
48-
bySig map[[sessionInitDataSize]byte]uint8
51+
mu sync.Mutex
52+
nextID uint16
53+
byID [maxServerSessions]*sessionRecord
54+
bySig map[[sessionInitDataSize]byte]uint8
55+
recentClosed map[uint8]closedSessionRecord
4956
}
5057

5158
func newSessionStore() *sessionStore {
5259
return &sessionStore{
53-
bySig: make(map[[sessionInitDataSize]byte]uint8, 64),
60+
bySig: make(map[[sessionInitDataSize]byte]uint8, 64),
61+
recentClosed: make(map[uint8]closedSessionRecord, 32),
5462
}
5563
}
5664

@@ -71,6 +79,7 @@ func (s *sessionStore) findOrCreate(payload []byte, uploadCompressionType uint8,
7179
if sessionID, ok := s.bySig[signature]; ok {
7280
if existing := s.byID[sessionID]; existing != nil {
7381
if now.Before(existing.ReuseUntil) || now.Equal(existing.ReuseUntil) {
82+
existing.LastActivityAt = now
7483
return existing, true, nil
7584
}
7685
}
@@ -83,11 +92,12 @@ func (s *sessionStore) findOrCreate(payload []byte, uploadCompressionType uint8,
8392
}
8493

8594
record := &sessionRecord{
86-
ID: uint8(slot),
87-
ResponseMode: payload[0],
88-
CreatedAt: now,
89-
ReuseUntil: now.Add(sessionInitTTL),
90-
Signature: signature,
95+
ID: uint8(slot),
96+
ResponseMode: payload[0],
97+
CreatedAt: now,
98+
LastActivityAt: now,
99+
ReuseUntil: now.Add(sessionInitTTL),
100+
Signature: signature,
91101
}
92102
record.UploadCompression = compression.NormalizeType(uploadCompressionType)
93103
record.DownloadCompression = compression.NormalizeType(downloadCompressionType)
@@ -98,6 +108,7 @@ func (s *sessionStore) findOrCreate(payload []byte, uploadCompressionType uint8,
98108

99109
s.byID[slot] = record
100110
s.bySig[signature] = uint8(slot)
111+
delete(s.recentClosed, uint8(slot))
101112
s.nextID = uint16((slot + 1) % maxServerSessions)
102113
return record, false, nil
103114
}
@@ -111,6 +122,109 @@ func (s *sessionStore) expireReuseLocked(now time.Time) {
111122
}
112123
}
113124

125+
func (s *sessionStore) Touch(sessionID uint8, now time.Time) bool {
126+
s.mu.Lock()
127+
defer s.mu.Unlock()
128+
129+
record := s.byID[sessionID]
130+
if record == nil {
131+
return false
132+
}
133+
record.LastActivityAt = now
134+
return true
135+
}
136+
137+
func (s *sessionStore) Active(sessionID uint8) (*sessionRecord, bool) {
138+
s.mu.Lock()
139+
defer s.mu.Unlock()
140+
141+
record := s.byID[sessionID]
142+
if record == nil {
143+
return nil, false
144+
}
145+
copyRecord := *record
146+
return &copyRecord, true
147+
}
148+
149+
func (s *sessionStore) ExpectedCookie(sessionID uint8) (uint8, bool) {
150+
s.mu.Lock()
151+
defer s.mu.Unlock()
152+
153+
if record := s.byID[sessionID]; record != nil {
154+
return record.Cookie, true
155+
}
156+
if record, ok := s.recentClosed[sessionID]; ok {
157+
return record.Cookie, true
158+
}
159+
return 0, false
160+
}
161+
162+
func (s *sessionStore) ValidateCookie(sessionID uint8, cookie uint8) bool {
163+
expected, ok := s.ExpectedCookie(sessionID)
164+
return ok && expected == cookie
165+
}
166+
167+
func (s *sessionStore) Close(sessionID uint8, now time.Time, retention time.Duration) bool {
168+
s.mu.Lock()
169+
defer s.mu.Unlock()
170+
171+
record := s.byID[sessionID]
172+
if record == nil {
173+
return false
174+
}
175+
176+
delete(s.bySig, record.Signature)
177+
s.byID[sessionID] = nil
178+
if retention > 0 {
179+
s.recentClosed[sessionID] = closedSessionRecord{
180+
Cookie: record.Cookie,
181+
ExpiresAt: now.Add(retention),
182+
}
183+
} else {
184+
delete(s.recentClosed, sessionID)
185+
}
186+
return true
187+
}
188+
189+
func (s *sessionStore) Cleanup(now time.Time, idleTimeout time.Duration, closedRetention time.Duration) []uint8 {
190+
s.mu.Lock()
191+
defer s.mu.Unlock()
192+
193+
s.expireReuseLocked(now)
194+
195+
for sessionID, record := range s.recentClosed {
196+
if !now.Before(record.ExpiresAt) {
197+
delete(s.recentClosed, sessionID)
198+
}
199+
}
200+
201+
if idleTimeout <= 0 {
202+
return nil
203+
}
204+
205+
expired := make([]uint8, 0, 8)
206+
for sessionID, record := range s.byID {
207+
if record == nil {
208+
continue
209+
}
210+
if now.Sub(record.LastActivityAt) < idleTimeout {
211+
continue
212+
}
213+
214+
delete(s.bySig, record.Signature)
215+
s.byID[sessionID] = nil
216+
if closedRetention > 0 {
217+
s.recentClosed[uint8(sessionID)] = closedSessionRecord{
218+
Cookie: record.Cookie,
219+
ExpiresAt: now.Add(closedRetention),
220+
}
221+
}
222+
expired = append(expired, uint8(sessionID))
223+
}
224+
225+
return expired
226+
}
227+
114228
func (s *sessionStore) allocateSlotLocked() int {
115229
for i := range maxServerSessions {
116230
slot := int((s.nextID + uint16(i)) % maxServerSessions)

server_config.toml.simple

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ MAX_PACKET_SIZE = 65535
3131
# Minimum interval between overload/drop warning logs.
3232
DROP_LOG_INTERVAL_SECONDS = 2.0
3333

34+
# Active session lifecycle.
35+
SESSION_TIMEOUT_SECONDS = 300.0
36+
SESSION_CLEANUP_INTERVAL_SECONDS = 30.0
37+
CLOSED_SESSION_RETENTION_SECONDS = 600.0
38+
3439
# Allowed tunnel domains.
3540
DOMAIN = ["v.domain.com"]
3641

0 commit comments

Comments
 (0)