Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ func TestEndToEnd(t *testing.T) {
var ct, st Transport = NewInMemoryTransports()

// Channels to check if notification callbacks happened.
// These test asynchronous sending of notifications after a small delay (see
// Server.sendNotification).
notificationChans := map[string]chan int{}
for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client", "resource_updated", "subscribe", "unsubscribe", "elicitation_complete"} {
notificationChans[name] = make(chan int, 1)
Expand Down Expand Up @@ -1671,14 +1673,15 @@ func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) {
}

func TestSynchronousNotifications(t *testing.T) {
var toolsChanged atomic.Bool
var toolsChanged atomic.Int32
clientOpts := &ClientOptions{
ToolListChangedHandler: func(ctx context.Context, req *ToolListChangedRequest) {
toolsChanged.Store(true)
toolsChanged.Add(1)
},
CreateMessageHandler: func(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) {
if !toolsChanged.Load() {
return nil, fmt.Errorf("didn't get a tools changed notification")
// See the comment after "from server" below.
if n := toolsChanged.Load(); n != 1 {
return nil, fmt.Errorf("got %d tools-changed notification, wanted 1", n)
}
// TODO(rfindley): investigate the error returned from this test if
// CreateMessageResult is new(CreateMessageResult): it's a mysterious
Expand All @@ -1695,14 +1698,15 @@ func TestSynchronousNotifications(t *testing.T) {
},
}
server := NewServer(testImpl, serverOpts)
cs, ss, cleanup := basicClientServerConnection(t, client, server, func(s *Server) {
addTool := func(s *Server) {
AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) {
if !rootsChanged.Load() {
return nil, nil, fmt.Errorf("didn't get root change notification")
}
return new(CallToolResult), nil, nil
})
})
}
cs, ss, cleanup := basicClientServerConnection(t, client, server, addTool)
defer cleanup()

t.Run("from client", func(t *testing.T) {
Expand All @@ -1717,7 +1721,13 @@ func TestSynchronousNotifications(t *testing.T) {
})

t.Run("from server", func(t *testing.T) {
server.RemoveTools("tool")
// Despite all this tool-changed activity, we expect only one notification.
for range 10 {
server.RemoveTools("tool")
addTool(server)
}

time.Sleep(notificationDelay * 2) // Wait for delayed notification.
if _, err := ss.CreateMessage(context.Background(), new(CreateMessageParams)); err != nil {
t.Errorf("CreateMessage failed: %v", err)
}
Expand Down
62 changes: 41 additions & 21 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type Server struct {
sendingMethodHandler_ MethodHandler
receivingMethodHandler_ MethodHandler
resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool
pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send
}

// ServerOptions is used to configure behavior of the server.
Expand Down Expand Up @@ -149,6 +150,7 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server {
sendingMethodHandler_: defaultSendingMethodHandler[*ServerSession],
receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession],
resourceSubscriptions: make(map[string]map[*ServerSession]bool),
pendingNotifications: make(map[string]*time.Timer),
}
}

Expand All @@ -158,15 +160,13 @@ func (s *Server) AddPrompt(p *Prompt, h PromptHandler) {
// (It's possible an item was replaced with an identical one, but not worth checking.)
s.changeAndNotify(
notificationPromptListChanged,
&PromptListChangedParams{},
func() bool { s.prompts.add(&serverPrompt{p, h}); return true })
}

// RemovePrompts removes the prompts with the given names.
// It is not an error to remove a nonexistent prompt.
func (s *Server) RemovePrompts(names ...string) {
s.changeAndNotify(notificationPromptListChanged, &PromptListChangedParams{},
func() bool { return s.prompts.remove(names...) })
s.changeAndNotify(notificationPromptListChanged, func() bool { return s.prompts.remove(names...) })
}

// AddTool adds a [Tool] to the server, or replaces one with the same name.
Expand Down Expand Up @@ -235,8 +235,7 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) {
// (It's possible a tool was replaced with an identical one, but not worth checking.)
// TODO: Batch these changes by size and time? The typescript SDK doesn't.
// TODO: Surface notify error here? best not, in case we need to batch.
s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{},
func() bool { s.tools.add(st); return true })
s.changeAndNotify(notificationToolListChanged, func() bool { s.tools.add(st); return true })
}

func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) {
Expand Down Expand Up @@ -419,14 +418,13 @@ func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) {
// RemoveTools removes the tools with the given names.
// It is not an error to remove a nonexistent tool.
func (s *Server) RemoveTools(names ...string) {
s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{},
func() bool { return s.tools.remove(names...) })
s.changeAndNotify(notificationToolListChanged, func() bool { return s.tools.remove(names...) })
}

// AddResource adds a [Resource] to the server, or replaces one with the same URI.
// AddResource panics if the resource URI is invalid or not absolute (has an empty scheme).
func (s *Server) AddResource(r *Resource, h ResourceHandler) {
s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{},
s.changeAndNotify(notificationResourceListChanged,
func() bool {
if _, err := url.Parse(r.URI); err != nil {
panic(err) // url.Parse includes the URI in the error
Expand All @@ -439,14 +437,13 @@ func (s *Server) AddResource(r *Resource, h ResourceHandler) {
// RemoveResources removes the resources with the given URIs.
// It is not an error to remove a nonexistent resource.
func (s *Server) RemoveResources(uris ...string) {
s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{},
func() bool { return s.resources.remove(uris...) })
s.changeAndNotify(notificationResourceListChanged, func() bool { return s.resources.remove(uris...) })
}

// AddResourceTemplate adds a [ResourceTemplate] to the server, or replaces one with the same URI.
// AddResourceTemplate panics if a URI template is invalid or not absolute (has an empty scheme).
func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) {
s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{},
s.changeAndNotify(notificationResourceListChanged,
func() bool {
// Validate the URI template syntax
_, err := uritemplate.New(t.URITemplate)
Expand All @@ -461,8 +458,7 @@ func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) {
// RemoveResourceTemplates removes the resource templates with the given URI templates.
// It is not an error to remove a nonexistent resource.
func (s *Server) RemoveResourceTemplates(uriTemplates ...string) {
s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{},
func() bool { return s.resourceTemplates.remove(uriTemplates...) })
s.changeAndNotify(notificationResourceListChanged, func() bool { return s.resourceTemplates.remove(uriTemplates...) })
}

func (s *Server) capabilities() *ServerCapabilities {
Expand Down Expand Up @@ -497,18 +493,43 @@ func (s *Server) complete(ctx context.Context, req *CompleteRequest) (*CompleteR
return s.opts.CompletionHandler(ctx, req)
}

// Map from notification name to its corresponding params. The params have no fields,
// so a single struct can be reused.
var changeNotificationParams = map[string]Params{
notificationToolListChanged: &ToolListChangedParams{},
notificationPromptListChanged: &PromptListChangedParams{},
notificationResourceListChanged: &ResourceListChangedParams{},
}

// How long to wait before sending a change notification.
const notificationDelay = 10 * time.Millisecond

// changeAndNotify is called when a feature is added or removed.
// It calls change, which should do the work and report whether a change actually occurred.
// If there was a change, it notifies a snapshot of the sessions.
func (s *Server) changeAndNotify(notification string, params Params, change func() bool) {
var sessions []*ServerSession
// Lock for the change, but not for the notification.
// If there was a change, it sets a timer to send a notification.
// This debounces change notifications: a single notification is sent after
// multiple changes occur in close proximity.
func (s *Server) changeAndNotify(notification string, change func() bool) {
s.mu.Lock()
defer s.mu.Unlock()
if change() {
sessions = slices.Clone(s.sessions)
// Stop the outstanding delayed call, if any.
if t := s.pendingNotifications[notification]; t != nil {
t.Stop()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can just call t.Reset() and probably avoid some allocation, no?

}
//
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cruft?

s.pendingNotifications[notification] = time.AfterFunc(notificationDelay, func() { s.notifySessions(notification) })
}
s.mu.Unlock()
notifySessions(sessions, notification, params)
}

// notifySessions sends the notification n to all existing sessions.
// It is called asynchronously by changeAndNotify.
func (s *Server) notifySessions(n string) {
s.mu.Lock()
sessions := slices.Clone(s.sessions)
s.pendingNotifications[n] = nil
s.mu.Unlock() // Don't hold the lock during notification: it causes deadlock.
notifySessions(sessions, n, changeNotificationParams[n])
}

// Sessions returns an iterator that yields the current set of server sessions.
Expand Down Expand Up @@ -1068,7 +1089,6 @@ func (ss *ServerSession) Elicit(ctx context.Context, params *ElicitParams) (*Eli

resolved, err := schema.Resolve(nil)
if err != nil {
fmt.Printf(" resolve err: %s", err)
return nil, err
}
if err := resolved.Validate(res.Content); err != nil {
Expand Down
2 changes: 2 additions & 0 deletions mcp/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,8 @@ func notifySessions[S Session, P Params](sessions []S, method string, params P)
if sessions == nil {
return
}
// Notify with the background context, so the messages are sent on the
// standalone stream.
// TODO: make this timeout configurable, or call handleNotify asynchronously.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this TODO now obsolete? Are all notifications asynchronous?

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
Expand Down