Skip to content

Commit 9eb0b91

Browse files
committed
feat: implement proper session restore
1 parent 5bc1368 commit 9eb0b91

2 files changed

Lines changed: 268 additions & 9 deletions

File tree

main.go

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ import (
88
mathrand "math/rand"
99
"net/http"
1010
"os"
11+
"os/signal"
1112
"strings"
1213
"sync"
14+
"syscall"
1315
"time"
1416
"unicode/utf8"
1517

@@ -348,20 +350,23 @@ type Suggestion struct {
348350

349351
// Server is the main WebSocket server
350352
type Server struct {
351-
rooms map[string]*Room
352-
sessions map[string]*Session // sessionToken -> Session
353-
clients map[*Client]bool
354-
upgrader websocket.Upgrader
355-
mu sync.RWMutex
356-
logger *zap.Logger
357-
rng *mathrand.Rand
353+
rooms map[string]*Room
354+
sessions map[string]*Session // sessionToken -> Session
355+
clients map[*Client]bool
356+
upgrader websocket.Upgrader
357+
mu sync.RWMutex
358+
logger *zap.Logger
359+
rng *mathrand.Rand
360+
startTime time.Time // Track when server started for room retention logic
358361
}
359362

360363
const (
361364
// Grace period for reconnection (increased from 5 to 15 minutes for better recovery)
362365
ReconnectGracePeriod = 15 * time.Minute
363366
// How often to clean up expired sessions
364367
SessionCleanupInterval = 1 * time.Minute
368+
// Minimum time to keep empty rooms after server restart (for reconnection)
369+
MinRoomRetentionAfterRestart = 2 * time.Minute
365370
// Security limits
366371
MaxUsernameLength = 50
367372
MaxRoomCodeLength = 10
@@ -385,8 +390,9 @@ func NewServer(logger *zap.Logger) *Server {
385390
ReadBufferSize: 4096,
386391
WriteBufferSize: 4096,
387392
},
388-
logger: logger,
389-
rng: mathrand.New(mathrand.NewSource(time.Now().UnixNano())),
393+
logger: logger,
394+
rng: mathrand.New(mathrand.NewSource(time.Now().UnixNano())),
395+
startTime: time.Now(),
390396
}
391397

392398
// Start cleanup goroutines
@@ -402,6 +408,7 @@ func (s *Server) cleanupExpiredSessions() {
402408
for range ticker.C {
403409
s.mu.Lock()
404410
now := time.Now()
411+
minRetentionTime := s.startTime.Add(MinRoomRetentionAfterRestart)
405412
expiredTokens := make([]string, 0)
406413

407414
for token, session := range s.sessions {
@@ -431,6 +438,18 @@ func (s *Server) cleanupExpiredSessions() {
431438
}
432439
room.State.Users = newUsers
433440

441+
// Check if room should be deleted (no active clients, no disconnected users)
442+
// Only delete if we're past the minimum retention time after server start
443+
if len(room.Clients) == 0 && len(room.DisconnectedUsers) == 0 {
444+
if now.After(minRetentionTime) {
445+
room.mu.Unlock()
446+
delete(s.rooms, session.RoomCode)
447+
s.logger.Info("Deleted empty room",
448+
zap.String("room_code", session.RoomCode))
449+
continue
450+
}
451+
}
452+
434453
// Notify remaining users
435454
for _, client := range room.Clients {
436455
if client != nil {
@@ -2032,6 +2051,26 @@ func main() {
20322051

20332052
server := NewServer(logger)
20342053

2054+
// Load previous state if exists
2055+
if err := server.LoadState(); err != nil {
2056+
logger.Error("Failed to load previous state", zap.Error(err))
2057+
// Continue anyway - not fatal
2058+
}
2059+
2060+
// Set up graceful shutdown
2061+
shutdown := make(chan os.Signal, 1)
2062+
signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM)
2063+
2064+
go func() {
2065+
<-shutdown
2066+
logger.Info("Shutdown signal received, saving state...")
2067+
if err := server.SaveState(); err != nil {
2068+
logger.Error("Failed to save state", zap.Error(err))
2069+
}
2070+
logger.Info("State saved, shutting down")
2071+
os.Exit(0)
2072+
}()
2073+
20352074
http.HandleFunc("/ws", server.handleWebSocket)
20362075
http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
20372076
w.Header().Set("Content-Type", "application/json")

persistence.go

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
package main
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"os"
7+
"time"
8+
9+
"go.uber.org/zap"
10+
)
11+
12+
const StateFile = "server_state.json"
13+
14+
// PersistentState contains all data that needs to be saved across server restarts
15+
type PersistentState struct {
16+
ServerShutdownTime time.Time `json:"server_shutdown_time"`
17+
Rooms []PersistentRoom `json:"rooms"`
18+
Sessions []PersistentSession `json:"sessions"`
19+
}
20+
21+
// PersistentRoom is a serializable version of Room
22+
type PersistentRoom struct {
23+
Code string `json:"code"`
24+
HostID string `json:"host_id"`
25+
State *RoomState `json:"state"`
26+
DisconnectedUsers map[string]*Session `json:"disconnected_users"`
27+
PendingSuggestions []PersistentSuggestion `json:"pending_suggestions"`
28+
HostDisconnectedAt *time.Time `json:"host_disconnected_at,omitempty"`
29+
}
30+
31+
// PersistentSuggestion is a serializable version of Suggestion
32+
type PersistentSuggestion struct {
33+
ID string `json:"id"`
34+
FromUserID string `json:"from_user_id"`
35+
FromUsername string `json:"from_username"`
36+
Track *TrackInfo `json:"track"`
37+
}
38+
39+
// PersistentSession is a serializable version of Session with token
40+
type PersistentSession struct {
41+
Token string `json:"token"`
42+
UserID string `json:"user_id"`
43+
Username string `json:"username"`
44+
RoomCode string `json:"room_code"`
45+
IsHost bool `json:"is_host"`
46+
DisconnectAt time.Time `json:"disconnect_at"`
47+
}
48+
49+
// SaveState saves the current server state to disk
50+
func (s *Server) SaveState() error {
51+
s.mu.RLock()
52+
defer s.mu.RUnlock()
53+
54+
state := PersistentState{
55+
ServerShutdownTime: time.Now(),
56+
Rooms: make([]PersistentRoom, 0),
57+
Sessions: make([]PersistentSession, 0),
58+
}
59+
60+
// Save all rooms
61+
for _, room := range s.rooms {
62+
room.mu.RLock()
63+
64+
// Convert pending suggestions
65+
pendingSuggestions := make([]PersistentSuggestion, 0, len(room.PendingSuggestions))
66+
for _, suggestion := range room.PendingSuggestions {
67+
pendingSuggestions = append(pendingSuggestions, PersistentSuggestion{
68+
ID: suggestion.ID,
69+
FromUserID: suggestion.FromUserID,
70+
FromUsername: suggestion.FromUsername,
71+
Track: suggestion.Track,
72+
})
73+
}
74+
75+
persistentRoom := PersistentRoom{
76+
Code: room.Code,
77+
HostID: room.Host.ID,
78+
State: room.State,
79+
DisconnectedUsers: room.DisconnectedUsers,
80+
PendingSuggestions: pendingSuggestions,
81+
HostDisconnectedAt: room.HostDisconnectedAt,
82+
}
83+
84+
state.Rooms = append(state.Rooms, persistentRoom)
85+
room.mu.RUnlock()
86+
}
87+
88+
// Save all sessions
89+
for token, session := range s.sessions {
90+
state.Sessions = append(state.Sessions, PersistentSession{
91+
Token: token,
92+
UserID: session.UserID,
93+
Username: session.Username,
94+
RoomCode: session.RoomCode,
95+
IsHost: session.IsHost,
96+
DisconnectAt: session.DisconnectAt,
97+
})
98+
}
99+
100+
// Marshal to JSON
101+
data, err := json.MarshalIndent(state, "", " ")
102+
if err != nil {
103+
return fmt.Errorf("marshal state: %w", err)
104+
}
105+
106+
// Write to file
107+
if err := os.WriteFile(StateFile, data, 0644); err != nil {
108+
return fmt.Errorf("write state file: %w", err)
109+
}
110+
111+
s.logger.Info("Server state saved",
112+
zap.Int("rooms", len(state.Rooms)),
113+
zap.Int("sessions", len(state.Sessions)))
114+
115+
return nil
116+
}
117+
118+
// LoadState loads the server state from disk
119+
func (s *Server) LoadState() error {
120+
// Check if state file exists
121+
if _, err := os.Stat(StateFile); os.IsNotExist(err) {
122+
s.logger.Info("No previous state file found, starting fresh")
123+
return nil
124+
}
125+
126+
// Read state file
127+
data, err := os.ReadFile(StateFile)
128+
if err != nil {
129+
return fmt.Errorf("read state file: %w", err)
130+
}
131+
132+
var state PersistentState
133+
if err := json.Unmarshal(data, &state); err != nil {
134+
return fmt.Errorf("unmarshal state: %w", err)
135+
}
136+
137+
// Calculate time elapsed since shutdown
138+
shutdownDuration := time.Since(state.ServerShutdownTime)
139+
s.logger.Info("Loading previous state",
140+
zap.Duration("offline_duration", shutdownDuration),
141+
zap.Int("rooms", len(state.Rooms)),
142+
zap.Int("sessions", len(state.Sessions)))
143+
144+
// Restore rooms
145+
for _, persistentRoom := range state.Rooms {
146+
room := &Room{
147+
Code: persistentRoom.Code,
148+
Host: nil, // Will be nil initially - user needs to reconnect
149+
Clients: make(map[string]*Client),
150+
PendingJoins: make(map[string]*Client),
151+
PendingSuggestions: make(map[string]*Suggestion),
152+
DisconnectedUsers: persistentRoom.DisconnectedUsers,
153+
State: persistentRoom.State,
154+
BufferingUsers: make(map[string]bool),
155+
HostDisconnectedAt: persistentRoom.HostDisconnectedAt,
156+
}
157+
158+
// Restore pending suggestions
159+
for _, ps := range persistentRoom.PendingSuggestions {
160+
room.PendingSuggestions[ps.ID] = &Suggestion{
161+
ID: ps.ID,
162+
FromUserID: ps.FromUserID,
163+
FromUsername: ps.FromUsername,
164+
Track: ps.Track,
165+
}
166+
}
167+
168+
// Update disconnect times for all users to account for shutdown duration
169+
for _, session := range room.DisconnectedUsers {
170+
session.DisconnectAt = session.DisconnectAt.Add(shutdownDuration)
171+
}
172+
173+
// Update host disconnected time if applicable
174+
if room.HostDisconnectedAt != nil {
175+
newTime := room.HostDisconnectedAt.Add(shutdownDuration)
176+
room.HostDisconnectedAt = &newTime
177+
}
178+
179+
// Find and set the host reference
180+
hostSession, exists := room.DisconnectedUsers[persistentRoom.HostID]
181+
if exists {
182+
// Create a placeholder client for the host
183+
room.Host = &Client{
184+
ID: hostSession.UserID,
185+
Username: hostSession.Username,
186+
SessionToken: "", // Will be set on reconnection
187+
}
188+
}
189+
190+
s.rooms[room.Code] = room
191+
s.logger.Info("Restored room",
192+
zap.String("code", room.Code),
193+
zap.String("host_id", persistentRoom.HostID),
194+
zap.Int("disconnected_users", len(room.DisconnectedUsers)))
195+
}
196+
197+
// Restore sessions
198+
for _, ps := range state.Sessions {
199+
// Adjust disconnect time to account for shutdown duration
200+
session := &Session{
201+
UserID: ps.UserID,
202+
Username: ps.Username,
203+
RoomCode: ps.RoomCode,
204+
IsHost: ps.IsHost,
205+
DisconnectAt: ps.DisconnectAt.Add(shutdownDuration),
206+
}
207+
s.sessions[ps.Token] = session
208+
}
209+
210+
s.logger.Info("State restoration complete",
211+
zap.Int("rooms_restored", len(state.Rooms)),
212+
zap.Int("sessions_restored", len(state.Sessions)))
213+
214+
// Delete the state file after successful load
215+
if err := os.Remove(StateFile); err != nil {
216+
s.logger.Warn("Failed to remove state file", zap.Error(err))
217+
}
218+
219+
return nil
220+
}

0 commit comments

Comments
 (0)