diff --git a/pkg/sip/inbound.go b/pkg/sip/inbound.go
index a7fe57b7..35061803 100644
--- a/pkg/sip/inbound.go
+++ b/pkg/sip/inbound.go
@@ -42,6 +42,7 @@ import (
 	"github.com/livekit/protocol/rpc"
 	lksip "github.com/livekit/protocol/sip"
 	"github.com/livekit/protocol/tracer"
+	"github.com/livekit/protocol/utils"
 	"github.com/livekit/protocol/utils/traceid"
 	"github.com/livekit/psrpc"
 	lksdk "github.com/livekit/server-sdk-go/v2"
@@ -64,6 +65,8 @@ const (
 	inviteOKRetryAttempts      = 5
 	inviteOKRetryAttemptsNoACK = 2
 	inviteOkAckLateTimeout     = inviteOkRetryIntervalMax
+
+	inviteCredentialValidity = 60 * time.Minute // Allow reuse of credentials for 1h
 )
 
 var errNoACK = errors.New("no ACK received for 200 OK")
@@ -134,23 +137,66 @@ func (s *Server) getCallInfo(id string) *inboundCallInfo {
 	return c
 }
 
-func (s *Server) getInvite(sipCallID string) *inProgressInvite {
-	s.imu.Lock()
-	defer s.imu.Unlock()
-	for i := range s.inProgressInvites {
-		if s.inProgressInvites[i].sipCallID == sipCallID {
-			return s.inProgressInvites[i]
+// cleanupInvites runs until the server's closing fuse fires. Every five minutes
+// it removes in-progress invite state that has not been refreshed within
+// inviteCredentialValidity.
+func (s *Server) cleanupInvites() {
+	ticker := time.NewTicker(5 * time.Minute) // Periodic cleanup every 5 minutes
+	defer ticker.Stop()
+	for {
+		select {
+		case <-s.closing.Watch():
+			return
+		case <-ticker.C:
+			s.imu.Lock()
+			for it := s.inviteTimeoutQueue.IterateRemoveAfter(inviteCredentialValidity); it.Next(); {
+				key := it.Item().Value
+				delete(s.inProgressInvites, key)
+			}
+			s.imu.Unlock()
 		}
 	}
-	if len(s.inProgressInvites) >= digestLimit {
-		s.inProgressInvites = s.inProgressInvites[1:]
+}
+
+// getInvite returns the in-progress invite state for the dialog identified by
+// sipCallID, toTag and fromTag, creating it on first use. Every call also
+// refreshes the entry's position in inviteTimeoutQueue.
+//
+// The refresh happens while s.imu is held: cleanupInvites evicts expired
+// entries from both the queue and the map under the write lock, so resetting
+// after unlocking could hand back an entry that was evicted in between and is
+// no longer reachable through inProgressInvites.
+func (s *Server) getInvite(sipCallID, toTag, fromTag string) *inProgressInvite {
+	key := dialogKey{
+		sipCallID: sipCallID,
+		toTag:     toTag,
+		fromTag:   fromTag,
+	}
+
+	// Fast path: the dialog is already known, a read lock is enough.
+	s.imu.RLock()
+	is, exists := s.inProgressInvites[key]
+	if exists {
+		s.inviteTimeoutQueue.Reset(&is.timeoutLink)
+	}
+	s.imu.RUnlock()
+	if exists {
+		return is
 	}
-	is := &inProgressInvite{sipCallID: sipCallID}
-	s.inProgressInvites = append(s.inProgressInvites, is)
+
+	// Re-check under the write lock; another caller may have created the entry.
+	s.imu.Lock()
+	defer s.imu.Unlock()
+	is, exists = s.inProgressInvites[key]
+	if !exists {
+		is = &inProgressInvite{sipCallID: sipCallID, timeoutLink: utils.TimeoutQueueItem[dialogKey]{Value: key}}
+		s.inProgressInvites[key] = is
+	}
+	s.inviteTimeoutQueue.Reset(&is.timeoutLink)
 	return is
 }
 
-func (s *Server) handleInviteAuth(tid traceid.ID, log logger.Logger, req *sip.Request, tx sip.ServerTransaction, from, username, password string) (ok bool) {
+func (s *Server) handleInviteAuth(tid traceid.ID, log logger.Logger, req *sip.Request, tx sip.ServerTransaction, from, username, password string, inviteState *inProgressInvite) (ok bool) {
 	log = log.WithValues(
 		"username", username,
 		"passwordHash", hashPassword(password),
@@ -171,14 +201,6 @@ func (s *Server) handleInviteAuth(tid traceid.ID, log logger.Logger, req *sip.Re
 		_ = tx.Respond(sip.NewResponseFromRequest(req, 100, "Processing", nil))
 	}
 
-	// Extract SIP Call ID for tracking in-progress invites
-	sipCallID := ""
-	if h := req.CallID(); h != nil {
-		sipCallID = h.Value()
-	}
-	inviteState := s.getInvite(sipCallID)
-	log = log.WithValues("inviteStateSipCallID", sipCallID)
-
 	h := req.GetHeader("Proxy-Authorization")
 	if h == nil {
 		inviteState.challenge = digest.Challenge{
@@ -230,7 +252,6 @@ func (s *Server) handleInviteAuth(tid traceid.ID, log logger.Logger, req *sip.Re
 	// Check if we have a valid challenge state
 	if inviteState.challenge.Realm == "" {
 		log.Warnw("No challenge state found for authentication attempt", errors.New("missing challenge state"),
-			"sipCallID", sipCallID,
 			"expectedRealm", UserAgent,
 		)
 		_ = tx.Respond(sip.NewResponseFromRequest(req, 401, "Bad credentials", nil))
@@ -305,20 +326,18 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE
 		s.log.Errorw("cannot parse source IP", err, "fromIP", src)
 		return psrpc.NewError(psrpc.MalformedRequest, errors.Wrap(err, "cannot parse source IP"))
 	}
-	callID := lksip.NewCallID()
-	tid := traceid.FromGUID(callID)
+	sipCallID := legCallIDFromReq(req)
 	tr := 
callTransportFromReq(req) legTr := legTransportFromReq(req) log := s.log.WithValues( - "callID", callID, - "traceID", tid.String(), + "sipCallID", sipCallID, "fromIP", src.Addr(), "toIP", req.Destination(), "transport", tr, ) var call *inboundCall - cc := s.newInbound(log, LocalTag(callID), s.ContactURI(legTr), req, tx, func(headers map[string]string) map[string]string { + cc := s.newInbound(log, s.ContactURI(legTr), req, tx, func(headers map[string]string) map[string]string { c := call if c == nil || len(c.attrsToHdr) == 0 { return headers @@ -331,10 +350,9 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE }) log = LoggerWithParams(log, cc) log = LoggerWithHeaders(log, cc) - cc.log = log - log.Infow("processing invite") if err := cc.ValidateInvite(); err != nil { + log.Errorw("invalid invite", err) if s.conf.HideInboundPort { cc.Drop() } else { @@ -342,14 +360,43 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE } return psrpc.NewError(psrpc.InvalidArgument, errors.Wrap(err, "invite validation failed")) } + + // Establish ID + fromTag, _ := req.From().Params.Get("tag") // always exists, via ValidateInvite() check + toParams := req.To().Params // To() always exists, via ValidateInvite() check + if toParams == nil { + toParams = sip.NewParams() + req.To().Params = toParams + } + toTag, ok := toParams.Get("tag") + if !ok { + // No to-tag on the invite means we need to generate one per RFC 3261 section 12. + // Generate a new to-tag early, to make sure both INVITES have the same ID. 
+ toTag = utils.NewGuid("") + toParams.Add("tag", toTag) + } + inviteProgress := s.getInvite(sipCallID, toTag, fromTag) + callID := inviteProgress.lkCallID + if callID == "" { + callID = lksip.NewCallID() + inviteProgress.lkCallID = callID + } + cc.id = LocalTag(callID) + tid := traceid.FromGUID(sipCallID) + + log = log.WithValues("callID", callID) + log = log.WithValues("traceID", tid.String()) + cc.log = log + log.Infow("processing invite") + ctx, span := tracer.Start(ctx, "Server.onInvite") defer span.End() s.cmu.RLock() - existing := s.byCallID[cc.SIPCallID()] + existing := s.byCallID[sipCallID] s.cmu.RUnlock() if existing != nil && existing.cc.InviteCSeq() < cc.InviteCSeq() { - log.Infow("accepting reinvite", "sipCallID", existing.cc.ID(), "content-type", req.ContentType(), "content-length", req.ContentLength()) + log.Infow("accepting reinvite", "content-type", req.ContentType(), "content-length", req.ContentLength()) existing.log().Infow("reinvite", "content-type", req.ContentType(), "content-length", req.ContentLength(), "cseq", cc.InviteCSeq()) cc.AcceptAsKeepAlive(existing.cc.OwnSDP()) return nil @@ -376,7 +423,7 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE callInfo := &rpc.SIPCall{ LkCallId: callID, - SipCallId: cc.SIPCallID(), + SipCallId: sipCallID, SourceIp: src.Addr().String(), Address: ToSIPUri("", cc.Address()), From: ToSIPUri("", from), @@ -447,15 +494,15 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE // We will send password request anyway, so might as well signal that the progress is made. 
cc.Processing()
 		}
-		s.getCallInfo(cc.SIPCallID()).countInvite(log, req)
-		if !s.handleInviteAuth(tid, log, req, tx, from.User, r.Username, r.Password) {
+		s.getCallInfo(sipCallID).countInvite(log, req)
+		if !s.handleInviteAuth(tid, log, req, tx, from.User, r.Username, r.Password, inviteProgress) {
 			cmon.InviteErrorShort("unauthorized")
 			// handleInviteAuth will generate the SIP Response as needed
 			return psrpc.NewErrorf(psrpc.PermissionDenied, "invalid credentials were provided")
 		}
 		// ok
 	case AuthAccept:
-		s.getCallInfo(cc.SIPCallID()).countInvite(log, req)
+		s.getCallInfo(sipCallID).countInvite(log, req)
 		// ok
 	}
 
@@ -1366,11 +1413,10 @@ func (c *inboundCall) transferCall(ctx context.Context, transferTo string, heade
 
 }
 
-func (s *Server) newInbound(log logger.Logger, id LocalTag, contact URI, invite *sip.Request, inviteTx sip.ServerTransaction, getHeaders setHeadersFunc) *sipInbound {
+func (s *Server) newInbound(log logger.Logger, contact URI, invite *sip.Request, inviteTx sip.ServerTransaction, getHeaders setHeadersFunc) *sipInbound {
 	c := &sipInbound{
 		log:      log,
 		s:        s,
-		id:       id,
 		invite:   invite,
 		inviteTx: inviteTx,
 		legTr:    legTransportFromReq(invite),
diff --git a/pkg/sip/outbound.go b/pkg/sip/outbound.go
index 27b2cd6a..62107a3b 100644
--- a/pkg/sip/outbound.go
+++ b/pkg/sip/outbound.go
@@ -843,7 +843,7 @@ authLoop:
 			if err != nil {
 				return nil, fmt.Errorf("invalid challenge %q: %w", challengeStr, err)
 			}
-			toHeader := resp.To()
+			toHeader = resp.To()
 			if toHeader == nil {
 				return nil, errors.New("no 'To' header on Response")
 			}
diff --git a/pkg/sip/protocol.go b/pkg/sip/protocol.go
index 3a233fa2..ae5ee5ed 100644
--- a/pkg/sip/protocol.go
+++ b/pkg/sip/protocol.go
@@ -173,6 +173,16 @@ func legTransportFromReq(req *sip.Request) Transport {
 	return ""
 }
 
+// legCallIDFromReq returns the value of the request's Call-ID header, or ""
+// when the request carries none.
+func legCallIDFromReq(req *sip.Request) string {
+	hdr := req.CallID()
+	if hdr == nil {
+		return ""
+	}
+	return hdr.Value()
+}
+
 func transportPort(c *config.Config, t Transport) int {
 	if t == TransportTLS {
 		if tc := 
c.TLS; tc != nil {
diff --git a/pkg/sip/server.go b/pkg/sip/server.go
index bedfcd9e..5212f1d6 100644
--- a/pkg/sip/server.go
+++ b/pkg/sip/server.go
@@ -35,6 +35,7 @@ import (
 	"github.com/livekit/protocol/livekit"
 	"github.com/livekit/protocol/logger"
 	"github.com/livekit/protocol/rpc"
+	"github.com/livekit/protocol/utils"
 	"github.com/livekit/protocol/utils/traceid"
 	"github.com/livekit/sipgo"
 	"github.com/livekit/sipgo/sip"
@@ -44,8 +45,7 @@ import (
 )
 
 const (
-	UserAgent   = "LiveKit"
-	digestLimit = 500
+	UserAgent = "LiveKit"
 )
 
 const (
@@ -127,18 +127,25 @@ type Handler interface {
 	OnSessionEnd(ctx context.Context, callIdentifier *CallIdentifier, callInfo *livekit.SIPCallInfo, reason string)
 }
 
+type dialogKey struct { // identifies one SIP dialog; key type of Server.inProgressInvites
+	sipCallID string // value of the INVITE's Call-ID header
+	toTag     string // "tag" parameter of the To header (generated by processInvite when the INVITE has none)
+	fromTag   string // "tag" parameter of the From header
+}
+
 type Server struct {
-	log          logger.Logger
-	mon          *stats.Monitor
-	region       string
-	sipSrv       *sipgo.Server
-	getIOClient  GetIOInfoClient
-	getRoom      GetRoomFunc
-	sipListeners []io.Closer
-	sipUnhandled RequestHandler
-
-	imu               sync.Mutex
-	inProgressInvites []*inProgressInvite
+	log                logger.Logger
+	mon                *stats.Monitor
+	region             string
+	sipSrv             *sipgo.Server
+	getIOClient        GetIOInfoClient
+	getRoom            GetRoomFunc
+	sipListeners       []io.Closer
+	sipUnhandled       RequestHandler
+	inviteTimeoutQueue utils.TimeoutQueue[dialogKey] // expiry tracking for inProgressInvites; drained by cleanupInvites
+
+	imu               sync.RWMutex                    // held by getInvite and cleanupInvites around inProgressInvites access
+	inProgressInvites map[dialogKey]*inProgressInvite // per-dialog invite/auth state, created lazily by getInvite
 
 	closing core.Fuse
 	cmu     sync.RWMutex
@@ -159,8 +166,10 @@ type Server struct {
 }
 
 type inProgressInvite struct {
-	sipCallID string
-	challenge digest.Challenge
+	sipCallID   string           // SIP Call-ID this state was created for
+	challenge   digest.Challenge // digest challenge state; read and written by handleInviteAuth
+	lkCallID    string // SCL_* LiveKit call ID assigned to this dialog
+	timeoutLink utils.TimeoutQueueItem[dialogKey] // this entry's item in Server.inviteTimeoutQueue; Value holds its dialogKey
 }
 
 type ServerOption func(s *Server)
@@ -178,15 +187,16 @@ func NewServer(region string, conf *config.Config, log logger.Logger, mon *stats
 		log = logger.GetLogger()
 	}
 	s := &Server{
-		log:         log,
-		conf:        conf,
-		region:      region,
-		mon:         mon,
-		getIOClient: getIOClient,
-		getRoom:     DefaultGetRoomFunc,
-		byRemoteTag: 
make(map[RemoteTag]*inboundCall),
-		byLocalTag:  make(map[LocalTag]*inboundCall),
-		byCallID:    make(map[string]*inboundCall),
+		log:               log,
+		conf:              conf,
+		region:            region,
+		mon:               mon,
+		getIOClient:       getIOClient,
+		getRoom:           DefaultGetRoomFunc,
+		inProgressInvites: make(map[dialogKey]*inProgressInvite),
+		byRemoteTag:       make(map[RemoteTag]*inboundCall),
+		byLocalTag:        make(map[LocalTag]*inboundCall),
+		byCallID:          make(map[string]*inboundCall),
 	}
 	for _, option := range options {
 		option(s)
@@ -330,6 +340,9 @@ func (s *Server) Start(agent *sipgo.UserAgent, sc *ServiceConfig, tlsConf *tls.C
 		}
 	}
 
+	// Start the cleanup task
+	go s.cleanupInvites()
+
 	return nil
 }
 
diff --git a/pkg/sip/service_test.go b/pkg/sip/service_test.go
index 5b21fb7c..8cf137e8 100644
--- a/pkg/sip/service_test.go
+++ b/pkg/sip/service_test.go
@@ -782,3 +782,171 @@ func TestCANCELSendsBothResponses(t *testing.T) {
 	// Verify we received the critical 487 response
 	require.True(t, invite487Received, "Should have received 487 Request Terminated response to INVITE when CANCEL is sent")
 }
+
+// TestSameCallIDForAuthFlow verifies that the same LiveKit call ID is assigned to both
+// the initial INVITE (without auth) and the subsequent INVITE (with auth)
+func TestSameCallIDForAuthFlow(t *testing.T) {
+	const (
+		fromUser = "test@example.com"
+		toUser   = "agent@example.com"
+		username = "testuser"
+		password = "testpass"
+		callID   = "same-call-id@test.com"
+		fromTag  = "fixed-from-tag-12345"
+	)
+
+	var capturedCallIDs []string
+	var mu sync.Mutex
+
+	h := &TestHandler{
+		GetAuthCredentialsFunc: func(ctx context.Context, call *rpc.SIPCall) (AuthInfo, error) {
+			// Record the LiveKit call ID of every request that reaches the auth handler
+			mu.Lock()
+			capturedCallIDs = append(capturedCallIDs, call.LkCallId)
+			mu.Unlock()
+
+			return AuthInfo{
+				Result:   AuthPassword,
+				Username: username,
+				Password: password,
+			}, nil
+		},
+		DispatchCallFunc: func(ctx context.Context, info *CallInfo) CallDispatch {
+			return CallDispatch{
+				Result: DispatchNoRuleReject,
+				// No room config needed for reject
+			}
+		},
+		OnSessionEndFunc: func(ctx context.Context, callIdentifier *CallIdentifier, callInfo *livekit.SIPCallInfo, reason string) {
+			// No-op for tests to avoid async logging issues
+		},
+	}
+
+	// Create service with authentication enabled
+	sipPort := rand.Intn(testPortSIPMax-testPortSIPMin) + testPortSIPMin
+	localIP, err := config.GetLocalIP()
+	require.NoError(t, err)
+
+	sipServerAddress := fmt.Sprintf("%s:%d", localIP, sipPort)
+
+	mon, err := stats.NewMonitor(&config.Config{MaxCpuUtilization: 0.9})
+	require.NoError(t, err)
+
+	// Use a no-op logger to avoid panics from async logging after test completion
+	log := logger.LogRLogger(logr.Discard())
+	s, err := NewService("", &config.Config{
+		HideInboundPort: false, // Enable authentication
+		SIPPort:         sipPort,
+		SIPPortListen:   sipPort,
+		RTPPort:         rtcconfig.PortRange{Start: testPortRTPMin, End: testPortRTPMax},
+	}, mon, log, func(projectID string) rpc.IOInfoClient { return nil })
+	require.NoError(t, err)
+	require.NotNil(t, s)
+	t.Cleanup(s.Stop)
+
+	s.SetHandler(h)
+	require.NoError(t, s.Start())
+
+	sipUserAgent, err := sipgo.NewUA(
+		sipgo.WithUserAgent(fromUser),
+		sipgo.WithUserAgentLogger(slog.New(logger.ToSlogHandler(s.log))),
+	)
+	require.NoError(t, err)
+
+	sipClient, err := sipgo.NewClient(sipUserAgent)
+	require.NoError(t, err)
+
+	offer, err := sdp.NewOffer(localIP, 0xB0B, sdp.EncryptionNone)
+	require.NoError(t, err)
+	offerData, err := offer.SDP.Marshal()
+	require.NoError(t, err)
+
+	inviteFromHeader := sip.FromHeader{
+		DisplayName: fromUser,
+		Address:     sip.Uri{User: fromUser, Host: sipServerAddress},
+		Params:      sip.NewParams().Add("tag", fromTag), // Key bit here
+	}
+
+	// Create first INVITE request (without auth)
+	inviteRecipient := sip.Uri{User: toUser, Host: sipServerAddress}
+	inviteRequest1 := sip.NewRequest(sip.INVITE, inviteRecipient)
+	inviteRequest1.SetDestination(sipServerAddress)
+	inviteRequest1.SetBody(offerData)
+	inviteRequest1.AppendHeader(sip.NewHeader("Content-Type", "application/sdp"))
+	inviteRequest1.AppendHeader(sip.NewHeader("Call-ID", callID))
+	inviteRequest1.AppendHeader(&inviteFromHeader)
+
+	tx1, err := sipClient.TransactionRequest(inviteRequest1)
+	require.NoError(t, err)
+	t.Cleanup(tx1.Terminate)
+
+	// Should receive 100 Trying first, then 407 Unauthorized
+	res1 := getResponseOrFail(t, tx1)
+	require.Equal(t, sip.StatusCode(100), res1.StatusCode, "First request should receive 100 Trying")
+	res1 = getResponseOrFail(t, tx1)
+	require.Equal(t, sip.StatusCode(407), res1.StatusCode, "First request should receive 407 Unauthorized")
+
+	// Get the To tag from the 407 response
+	toHeader := res1.To()
+	require.NotNil(t, toHeader, "407 response should have To header")
+	_, ok := toHeader.Params.Get("tag")
+	require.True(t, ok, "407 response To header should have tag parameter")
+
+	// Get the challenge from first response
+	authHeader1 := res1.GetHeader("Proxy-Authenticate")
+	require.NotNil(t, authHeader1, "First response should have Proxy-Authenticate header")
+	challenge1 := authHeader1.Value()
+
+	// Parse the challenge to extract nonce and realm
+	challenge, err := digest.ParseChallenge(challenge1)
+	require.NoError(t, err, "Should be able to parse challenge")
+
+	// Compute the digest response using the challenge and credentials
+	cred, err := digest.Digest(challenge, digest.Options{
+		Method:   "INVITE",
+		URI:      inviteRecipient.String(),
+		Username: username,
+		Password: password,
+	})
+	require.NoError(t, err, "Should be able to compute digest response")
+
+	// Create second INVITE request (with auth) using the SAME Call-ID, From tag, and To tag
+	inviteRequest2 := sip.NewRequest(sip.INVITE, inviteRecipient)
+	inviteRequest2.SetDestination(sipServerAddress)
+	inviteRequest2.SetBody(offerData)
+	inviteRequest2.AppendHeader(sip.NewHeader("Content-Type", "application/sdp"))
+	inviteRequest2.AppendHeader(sip.NewHeader("Call-ID", callID))
+	inviteRequest2.AppendHeader(sip.NewHeader("Proxy-Authorization", cred.String()))
+	inviteRequest2.AppendHeader(&inviteFromHeader)
+	inviteRequest2.AppendHeader(toHeader)
+
+	tx2, err := sipClient.TransactionRequest(inviteRequest2)
+	require.NoError(t, err)
+	t.Cleanup(tx2.Terminate)
+
+	// Should receive 100 Trying first, then proceed with authentication
+	res2 := getResponseOrFail(t, tx2)
+	require.Equal(t, sip.StatusCode(100), res2.StatusCode, "Second request should receive 100 Trying")
+
+	// Poll until the auth handler has seen both INVITEs rather than relying on a
+	// fixed sleep, which is flaky on slow machines.
+	snapshot := func() []string {
+		mu.Lock()
+		defer mu.Unlock()
+		return append([]string(nil), capturedCallIDs...)
+	}
+	require.Eventually(t, func() bool { return len(snapshot()) >= 2 },
+		5*time.Second, 10*time.Millisecond, "auth handler should be called for both INVITEs")
+
+	// Assert on a copy taken under the lock: a failing require stops the test
+	// immediately, which must not happen while mu is held, and the handler may
+	// still append while the IDs are being logged.
+	ids := snapshot()
+	require.Len(t, ids, 2, "Should have captured 2 call IDs")
+	require.Equal(t, ids[0], ids[1], "Both requests should have the same LiveKit call ID")
+	require.NotEmpty(t, ids[0], "Call ID should not be empty")
+	require.Contains(t, ids[0], "SCL_", "Call ID should have SCL_ prefix")
+
+	t.Logf("First call ID: %s", ids[0])
+	t.Logf("Second call ID: %s", ids[1])
+}