# Copyright (c) 2024- Datalayer, Inc.
#
# BSD 3-Clause License

"""
Tool Cache Module

Provides caching for expensive jupyter-mcp-tools queries to improve performance.
"""

import asyncio
import time
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from jupyter_mcp_server.log import logger


@dataclass
class CacheEntry:
    """Represents a cached entry with timestamp and data."""
    data: List[Dict[str, Any]]
    timestamp: float

    def is_expired(self, ttl_seconds: int) -> bool:
        """Check if the cache entry has expired."""
        return time.time() - self.timestamp > ttl_seconds


class ToolCache:
    """
    Cache for jupyter-mcp-tools data with TTL support.

    This cache stores the complete tool data to avoid expensive get_tools() calls.
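
    Example (an illustrative sketch; ``fetch_tools`` is a placeholder for the
    real ``jupyter_mcp_tools.get_tools`` coroutine, or any coroutine accepting
    the same keyword arguments)::

        cache = ToolCache(default_ttl=60)
        tools = await cache.get_tools(
            base_url="http://localhost:8888",
            token="my-token",
            query="notebook",
            fetch_func=fetch_tools,
        )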
34+ """
35+
36+ def __init__ (self , default_ttl : int = 300 ): # 5 minutes default
37+ """
38+ Initialize the tool cache.
39+
40+ Args:
41+ default_ttl: Default time-to-live in seconds for cache entries
42+ """
43+ self ._cache : Dict [str , CacheEntry ] = {}
44+ self ._default_ttl = default_ttl
45+ self ._lock = asyncio .Lock ()
46+
47+ def _make_cache_key (self , base_url : str , query : str ) -> str :
48+ """Create a cache key from the request parameters."""
49+ # Use a simplified key based on base_url and query
50+ # Don't include token for security reasons
51+ return f"{ base_url } :{ query } "
52+
53+ async def get_tools (
54+ self ,
55+ base_url : str ,
56+ token : str ,
57+ query : str ,
58+ enabled_only : bool = False ,
59+ ttl_seconds : Optional [int ] = None ,
60+ fetch_func : Optional [Any ] = None
61+ ) -> List [Dict [str , Any ]]:
62+ """
63+ Get tools from cache or fetch them if not cached/expired.
64+
65+ Args:
66+ base_url: Jupyter server base URL
67+ token: Authentication token
68+ query: Search query for tools
69+ enabled_only: Whether to return only enabled tools
70+ ttl_seconds: Custom TTL for this request (overrides default)
71+ fetch_func: Function to call if cache miss (should be jupyter_mcp_tools.get_tools)
72+
73+ Returns:
74+ List of tool dictionaries
75+ """
76+ cache_key = self ._make_cache_key (base_url , query )
        ttl = ttl_seconds if ttl_seconds is not None else self._default_ttl

        async with self._lock:
            # Check if we have a valid cache entry
            if cache_key in self._cache:
                entry = self._cache[cache_key]
                if not entry.is_expired(ttl):
                    logger.debug(f"Cache HIT for {cache_key} (age: {time.time() - entry.timestamp:.1f}s)")
                    return entry.data
                else:
                    logger.debug(f"Cache EXPIRED for {cache_key} (age: {time.time() - entry.timestamp:.1f}s)")
                    del self._cache[cache_key]
            else:
                logger.debug(f"Cache MISS for {cache_key}")

        # Cache miss or expired - fetch fresh data
        if fetch_func is None:
            logger.warning("No fetch function provided for cache miss - returning empty list")
            return []

        try:
            logger.info(f"Fetching fresh tools from jupyter-mcp-tools (query: '{query}')")
            fresh_data = await fetch_func(
                base_url=base_url,
                token=token,
                query=query,
                enabled_only=enabled_only
            )

            # Store in cache
            async with self._lock:
                self._cache[cache_key] = CacheEntry(
                    data=fresh_data,
                    timestamp=time.time()
                )

            logger.info(f"Cached {len(fresh_data)} tools for key {cache_key}")
            return fresh_data

        except Exception as e:
            logger.error(f"Failed to fetch tools from jupyter-mcp-tools: {e}")
            # Return empty list on error to prevent cascading failures
            return []

    async def invalidate(self, base_url: str, query: Optional[str] = None):
122+ """
123+ Invalidate cache entries.
124+
125+ Args:
126+ base_url: Base URL to invalidate entries for
127+ query: Specific query to invalidate (if None, invalidates all for base_url)
128+ """
129+ async with self ._lock :
130+ if query is None :
131+ # Invalidate all entries for this base_url
132+ keys_to_remove = [
133+ key for key in self ._cache .keys ()
134+ if key .startswith (f"{ base_url } :" )
135+ ]
136+ for key in keys_to_remove :
137+ del self ._cache [key ]
138+ logger .info (f"Invalidated { len (keys_to_remove )} cache entries for { base_url } " )
139+ else :
140+ # Invalidate specific entry
141+ cache_key = self ._make_cache_key (base_url , query )
142+ if cache_key in self ._cache :
143+ del self ._cache [cache_key ]
144+ logger .info (f"Invalidated cache entry for { cache_key } " )
145+
146+ async def clear (self ):
147+ """Clear all cache entries."""
148+ async with self ._lock :
149+ count = len (self ._cache )
150+ self ._cache .clear ()
151+ logger .info (f"Cleared { count } cache entries" )
152+
153+ def get_cache_stats (self ) -> Dict [str , Any ]:
154+ """Get cache statistics."""
155+ return {
156+ "total_entries" : len (self ._cache ),
157+ "entries" : [
158+ {
159+ "key" : key ,
160+ "age_seconds" : time .time () - entry .timestamp ,
161+ "expired" : entry .is_expired (self ._default_ttl ),
162+ "data_count" : len (entry .data )
163+ }
164+ for key , entry in self ._cache .items ()
165+ ]
166+ }


# Global cache instance
_global_tool_cache: Optional[ToolCache] = None


def get_tool_cache() -> ToolCache:
    """Get the global tool cache instance."""
    global _global_tool_cache
    if _global_tool_cache is None:
        _global_tool_cache = ToolCache(default_ttl=300)  # 5 minutes
    return _global_tool_cache
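

# The block below is a minimal, illustrative smoke test, not part of the public
# API: `fake_fetch` is a hypothetical stand-in for jupyter_mcp_tools.get_tools
# that returns a canned tool list, so it runs without a Jupyter server.
if __name__ == "__main__":
    async def fake_fetch(base_url: str, token: str, query: str, enabled_only: bool = False):
        # Pretend to query the server and return a single dummy tool entry.
        return [{"name": "example_tool", "enabled": True}]

    async def _demo():
        cache = get_tool_cache()
        # The first call misses the cache and invokes fake_fetch; the second
        # call is served from the cache until the TTL expires.
        await cache.get_tools("http://localhost:8888", "", "example", fetch_func=fake_fetch)
        await cache.get_tools("http://localhost:8888", "", "example", fetch_func=fake_fetch)
        print(cache.get_cache_stats())

    asyncio.run(_demo())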