Skip to content

Commit 9790dc2

Browse files
add get_tool caching (#177)
1 parent eecfd69 commit 9790dc2

File tree

5 files changed

+206
-13
lines changed

5 files changed

+206
-13
lines changed

docs/docs/jupyter/streamable-http/jupyter-extension/index.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Make sure you have the following packages installed in your environment. The col
99
```bash
1010
pip install "jupyter-mcp-server>=0.15.0" "jupyterlab==4.4.1" "jupyter-collaboration==4.0.2" "ipykernel"
1111
pip uninstall -y pycrdt datalayer_pycrdt
12-
pip install datalayer_pycrdt==0.12.17 jupyter_mcp_tools
12+
pip install datalayer_pycrdt==0.12.17
1313
```
1414

1515
:::tip

jupyter_mcp_server/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44

55
"""Jupyter MCP Server."""
66

7-
__version__ = "0.20.1"
7+
__version__ = "0.21.0"

jupyter_mcp_server/jupyter_extension/handlers.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ async def post(self):
133133
jupyter_tools_data = []
134134
try:
135135
from jupyter_mcp_tools import get_tools
136+
from jupyter_mcp_server.tool_cache import get_tool_cache
136137

137138
# Get the server's base URL dynamically from ServerApp
138139
context = get_server_context()
@@ -163,21 +164,31 @@ async def post(self):
163164

164165
logger.info(f"Looking for specific jupyter-mcp-tools: {allowed_jupyter_tools}")
165166

166-
# Try querying with broader terms since specific IDs don't work
167+
# Try querying with caching to avoid expensive repeated calls
167168
try:
168169
search_query = ",".join(allowed_jupyter_tools)
169170
logger.info(f"Searching jupyter-mcp-tools with query: '{search_query}' (allowed_tools: {allowed_jupyter_tools})")
170171

171-
# Query for notebook-related tools with shorter timeout
172-
# Note: jupyter-mcp-tools requires JupyterLab frontend to load and register tools via WebSocket
173-
jupyter_tools_data = await get_tools(
172+
# Use cached get_tools to avoid expensive repeated calls
173+
tool_cache = get_tool_cache()
174+
175+
# Create wrapper function that matches the expected signature
176+
async def get_tools_wrapper(**kwargs):
177+
# Add wait_timeout for handlers.py compatibility
178+
return await get_tools(
179+
wait_timeout=5, # Shorter timeout - if frontend isn't loaded, don't wait long
180+
**kwargs
181+
)
182+
183+
jupyter_tools_data = await tool_cache.get_tools(
174184
base_url=base_url,
175185
token=token,
176186
query=search_query,
177187
enabled_only=False,
178-
wait_timeout=5 # Shorter timeout - if frontend isn't loaded, don't wait long
188+
ttl_seconds=180, # 3 minutes for handlers (shorter than server.py)
189+
fetch_func=get_tools_wrapper # Use wrapper that includes wait_timeout
179190
)
180-
logger.info(f"Query returned {len(jupyter_tools_data)} tools")
191+
logger.info(f"Query returned {len(jupyter_tools_data)} tools (from cache or fresh)")
181192

182193
# Use the tools directly since query should return only what we want
183194
for tool in jupyter_tools_data:

jupyter_mcp_server/server.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -577,9 +577,10 @@ async def get_registered_tools():
577577
if server_context.is_jupyterlab_mode():
578578
logger.info("JupyterLab mode enabled, loading selected jupyter-mcp-tools")
579579

580-
# Get tools from jupyter-mcp-tools extension
580+
# Get tools from jupyter-mcp-tools extension with caching
581581
try:
582582
from jupyter_mcp_tools import get_tools
583+
from jupyter_mcp_server.tool_cache import get_tool_cache
583584

584585
# Get the base_url and token from server context
585586
# In JUPYTER_SERVER mode, we should use the actual serverapp URL, not hardcoded localhost
@@ -602,18 +603,21 @@ async def get_registered_tools():
602603
"notebook_run-all-cells", # Run all cells in current notebook
603604
]
604605

605-
# Try querying with broader terms since specific IDs don't work
606+
# Try querying with caching to avoid expensive repeated calls
606607
try:
607608
search_query = ",".join(allowed_jupyter_tools)
608609
logger.info(f"Searching jupyter-mcp-tools with query: '{search_query}' (allowed_tools: {allowed_jupyter_tools})")
609610

610-
tools_data = await get_tools(
611+
# Use cached get_tools to avoid expensive repeated calls
612+
tool_cache = get_tool_cache()
613+
tools_data = await tool_cache.get_tools(
611614
base_url=base_url,
612615
token=token,
613616
query=search_query,
614-
enabled_only=False
617+
enabled_only=False,
618+
fetch_func=get_tools # Pass the actual get_tools function for cache misses
615619
)
616-
logger.info(f"Query returned {len(tools_data)} tools")
620+
logger.info(f"Query returned {len(tools_data)} tools (from cache or fresh)")
617621

618622
# Use the tools directly since query should return only what we want
619623
for tool in tools_data:

jupyter_mcp_server/tool_cache.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# Copyright (c) 2024- Datalayer, Inc.
2+
#
3+
# BSD 3-Clause License
4+
5+
"""
6+
Tool Cache Module
7+
8+
Provides caching for expensive jupyter-mcp-tools queries to improve performance.
9+
"""
10+
11+
import asyncio
12+
import time
13+
from typing import Dict, List, Optional, Any
14+
from dataclasses import dataclass
15+
from jupyter_mcp_server.log import logger
16+
17+
18+
@dataclass
19+
class CacheEntry:
20+
"""Represents a cached entry with timestamp and data."""
21+
data: List[Dict[str, Any]]
22+
timestamp: float
23+
24+
def is_expired(self, ttl_seconds: int) -> bool:
25+
"""Check if the cache entry has expired."""
26+
return time.time() - self.timestamp > ttl_seconds
27+
28+
29+
class ToolCache:
30+
"""
31+
Cache for jupyter-mcp-tools data with TTL support.
32+
33+
This cache stores the complete tool data to avoid expensive get_tools() calls.
34+
"""
35+
36+
def __init__(self, default_ttl: int = 300): # 5 minutes default
37+
"""
38+
Initialize the tool cache.
39+
40+
Args:
41+
default_ttl: Default time-to-live in seconds for cache entries
42+
"""
43+
self._cache: Dict[str, CacheEntry] = {}
44+
self._default_ttl = default_ttl
45+
self._lock = asyncio.Lock()
46+
47+
def _make_cache_key(self, base_url: str, query: str) -> str:
48+
"""Create a cache key from the request parameters."""
49+
# Use a simplified key based on base_url and query
50+
# Don't include token for security reasons
51+
return f"{base_url}:{query}"
52+
53+
async def get_tools(
54+
self,
55+
base_url: str,
56+
token: str,
57+
query: str,
58+
enabled_only: bool = False,
59+
ttl_seconds: Optional[int] = None,
60+
fetch_func: Optional[Any] = None
61+
) -> List[Dict[str, Any]]:
62+
"""
63+
Get tools from cache or fetch them if not cached/expired.
64+
65+
Args:
66+
base_url: Jupyter server base URL
67+
token: Authentication token
68+
query: Search query for tools
69+
enabled_only: Whether to return only enabled tools
70+
ttl_seconds: Custom TTL for this request (overrides default)
71+
fetch_func: Function to call if cache miss (should be jupyter_mcp_tools.get_tools)
72+
73+
Returns:
74+
List of tool dictionaries
75+
"""
76+
cache_key = self._make_cache_key(base_url, query)
77+
ttl = ttl_seconds or self._default_ttl
78+
79+
async with self._lock:
80+
# Check if we have a valid cache entry
81+
if cache_key in self._cache:
82+
entry = self._cache[cache_key]
83+
if not entry.is_expired(ttl):
84+
logger.debug(f"Cache HIT for {cache_key} (age: {time.time() - entry.timestamp:.1f}s)")
85+
return entry.data
86+
else:
87+
logger.debug(f"Cache EXPIRED for {cache_key} (age: {time.time() - entry.timestamp:.1f}s)")
88+
del self._cache[cache_key]
89+
else:
90+
logger.debug(f"Cache MISS for {cache_key}")
91+
92+
# Cache miss or expired - fetch fresh data
93+
if fetch_func is None:
94+
logger.warning("No fetch function provided for cache miss - returning empty list")
95+
return []
96+
97+
try:
98+
logger.info(f"Fetching fresh tools from jupyter-mcp-tools (query: '{query}')")
99+
fresh_data = await fetch_func(
100+
base_url=base_url,
101+
token=token,
102+
query=query,
103+
enabled_only=enabled_only
104+
)
105+
106+
# Store in cache
107+
async with self._lock:
108+
self._cache[cache_key] = CacheEntry(
109+
data=fresh_data,
110+
timestamp=time.time()
111+
)
112+
113+
logger.info(f"Cached {len(fresh_data)} tools for key {cache_key}")
114+
return fresh_data
115+
116+
except Exception as e:
117+
logger.error(f"Failed to fetch tools from jupyter-mcp-tools: {e}")
118+
# Return empty list on error to prevent cascading failures
119+
return []
120+
121+
async def invalidate(self, base_url: str, query: str = None):
122+
"""
123+
Invalidate cache entries.
124+
125+
Args:
126+
base_url: Base URL to invalidate entries for
127+
query: Specific query to invalidate (if None, invalidates all for base_url)
128+
"""
129+
async with self._lock:
130+
if query is None:
131+
# Invalidate all entries for this base_url
132+
keys_to_remove = [
133+
key for key in self._cache.keys()
134+
if key.startswith(f"{base_url}:")
135+
]
136+
for key in keys_to_remove:
137+
del self._cache[key]
138+
logger.info(f"Invalidated {len(keys_to_remove)} cache entries for {base_url}")
139+
else:
140+
# Invalidate specific entry
141+
cache_key = self._make_cache_key(base_url, query)
142+
if cache_key in self._cache:
143+
del self._cache[cache_key]
144+
logger.info(f"Invalidated cache entry for {cache_key}")
145+
146+
async def clear(self):
147+
"""Clear all cache entries."""
148+
async with self._lock:
149+
count = len(self._cache)
150+
self._cache.clear()
151+
logger.info(f"Cleared {count} cache entries")
152+
153+
def get_cache_stats(self) -> Dict[str, Any]:
154+
"""Get cache statistics."""
155+
return {
156+
"total_entries": len(self._cache),
157+
"entries": [
158+
{
159+
"key": key,
160+
"age_seconds": time.time() - entry.timestamp,
161+
"expired": entry.is_expired(self._default_ttl),
162+
"data_count": len(entry.data)
163+
}
164+
for key, entry in self._cache.items()
165+
]
166+
}
167+
168+
169+
# Global cache instance
170+
_global_tool_cache = None
171+
172+
173+
def get_tool_cache() -> ToolCache:
174+
"""Get the global tool cache instance."""
175+
global _global_tool_cache
176+
if _global_tool_cache is None:
177+
_global_tool_cache = ToolCache(default_ttl=300) # 5 minutes
178+
return _global_tool_cache

0 commit comments

Comments
 (0)