Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions run-with-google-adk/google_mcp_security_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk.agents.llm_agent import LlmAgent
from google.adk.tools.mcp_tool.mcp_toolset import StdioServerParameters, StdioConnectionParams
from google.adk.tools.mcp_tool.mcp_toolset import StdioServerParameters
import os
import logging

Expand Down Expand Up @@ -62,24 +62,21 @@ def get_all_tools():

if os.environ.get("LOAD_SCC_MCP") == "Y":
scc_tools = MCPToolSetWithSchemaAccess(
connection_params=StdioConnectionParams(
server_params=StdioServerParameters(
connection_params=StdioServerParameters(
command='uv',
args=[ "--directory",
uv_dir_prefix + "/scc",
"run",
"scc_mcp.py"
]
),
timeout=timeout),
tool_set_name="scc",
errlog=errlog_ae
)

if os.environ.get("LOAD_SECOPS_MCP") == "Y":
secops_tools = MCPToolSetWithSchemaAccess(
connection_params=StdioConnectionParams(
server_params=StdioServerParameters(
connection_params=StdioServerParameters(
command='uv',
args=[ "--directory",
uv_dir_prefix + "/secops/secops_mcp",
Expand All @@ -89,15 +86,13 @@ def get_all_tools():
"server.py"
]
),
timeout=timeout),
tool_set_name="secops_mcp",
errlog=errlog_ae
)

if os.environ.get("LOAD_GTI_MCP") == "Y":
gti_tools = MCPToolSetWithSchemaAccess(
connection_params=StdioConnectionParams(
server_params=StdioServerParameters(
connection_params=StdioServerParameters(
command='uv',
args=[ "--directory",
uv_dir_prefix + "/gti/gti_mcp",
Expand All @@ -107,16 +102,14 @@ def get_all_tools():
"server.py"
]
),
timeout=timeout),
tool_set_name="gti_mcp",
errlog=errlog_ae
)


if os.environ.get("LOAD_SECOPS_SOAR_MCP") == "Y":
secops_soar_tools = MCPToolSetWithSchemaAccess(
connection_params=StdioConnectionParams(
server_params=StdioServerParameters(
connection_params=StdioServerParameters(
command='uv',
args=[ "--directory",
uv_dir_prefix + "/secops-soar/secops_soar_mcp",
Expand All @@ -128,7 +121,6 @@ def get_all_tools():
os.environ.get("SECOPS_INTEGRATIONS","CSV,OKTA")
]
),
timeout=timeout),
tool_set_name="secops_soar_mcp",
errlog=errlog_ae
)
Expand Down
2 changes: 1 addition & 1 deletion run-with-google-adk/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
google-cloud-aiplatform==1.97.0
markdown
uv
google-adk[eval]==1.3.0
google-adk[eval]==1.14.0
google-genai==1.20.0
pandas
51 changes: 14 additions & 37 deletions run-with-google-adk/utils_extensions_cbs_tools/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,13 @@
# limitations under the License.


# imports for overriding `get_tools`
from typing_extensions import override
from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_closed_resource
from typing import List
from typing import Optional, Union, TextIO
from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.tools.mcp_tool.mcp_tool import MCPTool, BaseTool
from google.adk.tools.mcp_tool.mcp_session_manager import StdioServerParameters, StdioConnectionParams, SseConnectionParams,StreamableHTTPConnectionParams
from mcp.types import ListToolsResult
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.mcp_tool.mcp_toolset import StdioServerParameters
from .cache import tools_cache
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset, ToolPredicate
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
import sys
import logging

Expand All @@ -42,31 +38,23 @@ def __init__(
self,
*,
tool_set_name: str, # <-- new parameter
connection_params: Union[
StdioServerParameters,
StdioConnectionParams,
SseConnectionParams,
StreamableHTTPConnectionParams,
],
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
connection_params: StdioServerParameters,
tool_filter: Optional[List[str]] = None,
errlog: TextIO = sys.stderr,
):
super().__init__(
connection_params=connection_params,
tool_filter=tool_filter,
errlog=errlog
)
self.tool_set_name = tool_set_name
logging.info(f"MCPToolSetWithSchemaAccess initialized with tool_set_name: '{self.tool_set_name}'")
self._session = None

@retry_on_closed_resource("_reinitialize_session")
@override
async def get_tools(
self,
readonly_context: Optional[ReadonlyContext] = None,
) -> List[BaseTool]:
"""Return all tools in the toolset based on the provided context.
"""Return all tools in the toolset based on the provided context with caching.

Args:
readonly_context: Context used to filter tools available to the agent.
Expand All @@ -75,28 +63,17 @@ async def get_tools(
Returns:
List[BaseTool]: A list of tools available under the specified context.
"""
# Get session from session manager
if not self._session:
self._session = await self._mcp_session_manager.create_session()

# Check cache first
if self.tool_set_name in tools_cache.keys():
logging.info(f"Tools found in cache for toolset {self.tool_set_name}, returning them")
return tools_cache[self.tool_set_name]
else:
logging.info(f"No tools found in cache for toolset {self.tool_set_name}, loading")

tools_response: ListToolsResult = await self._session.list_tools()

# Apply filtering based on context and tool_filter
tools = []
for tool in tools_response.tools:
mcp_tool = MCPTool(
mcp_tool=tool,
mcp_session_manager=self._mcp_session_manager,
)

if self._is_tool_selected(mcp_tool, readonly_context):
tools.append(mcp_tool)
logging.info(f"No tools found in cache for toolset {self.tool_set_name}, loading from parent")

# Get tools from parent class
tools = await super().get_tools(readonly_context)

# Cache the tools
tools_cache[self.tool_set_name] = tools
return tools
logging.info(f"Cached {len(tools)} tools for toolset {self.tool_set_name}")
return tools