diff --git a/CHANGELOG.md b/CHANGELOG.md index 942e25a4..3f05d517 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), ### Added +- Added `--disable-builtin-tools` flag to the `serve` command that allows disabling all built-in tools (load-component, unload-component, list-components, get-policy, grant/revoke permissions, search-components, reset-permission). When enabled, only loaded component tools will be available through the MCP server - Comprehensive Docker documentation and Dockerfile for running Wassette in containers with enhanced security isolation, including examples for mounting components, secrets, configuration files, and production deployment patterns with Docker Compose - `rust-toolchain.toml` file specifying Rust 1.90 as the stable toolchain version, ensuring consistent Rust version across development environments and CI/CD pipelines - AI agent development guides (`AGENTS.md` and `Claude.md`) that consolidate development guidelines from `.github/instructions/` into accessible documentation for AI agents working on the project diff --git a/crates/mcp-server/src/tools.rs b/crates/mcp-server/src/tools.rs index 4669e7bf..99ae6d7d 100644 --- a/crates/mcp-server/src/tools.rs +++ b/crates/mcp-server/src/tools.rs @@ -21,11 +21,16 @@ const COMPONENT_LIST: &str = include_str!("../../../component-registry.json"); /// Handles a request to list available tools. #[instrument(skip(lifecycle_manager))] -pub async fn handle_tools_list(lifecycle_manager: &LifecycleManager) -> Result { +pub async fn handle_tools_list( + lifecycle_manager: &LifecycleManager, + disable_builtin_tools: bool, +) -> Result { debug!("Handling tools list request"); let mut tools = get_component_tools(lifecycle_manager).await?; - tools.extend(get_builtin_tools()); + if !disable_builtin_tools { + tools.extend(get_builtin_tools()); + } debug!(num_tools = %tools.len(), "Retrieved tools"); let response = rmcp::model::ListToolsResult { @@ -36,41 +41,79 @@ pub async fn handle_tools_list(lifecycle_manager: &LifecycleManager) -> Result bool { + matches!( + name, + "load-component" + | "unload-component" + | "list-components" + | "get-policy" + | "grant-storage-permission" + | "grant-network-permission" + | "grant-environment-variable-permission" + | "revoke-storage-permission" + | "revoke-network-permission" + | "revoke-environment-variable-permission" + | "search-components" + | "reset-permission" + ) +} + /// Handles a tool call request. #[instrument(skip_all, fields(method_name = %req.name))] pub async fn handle_tools_call( req: CallToolRequestParam, lifecycle_manager: &LifecycleManager, server_peer: Peer, + disable_builtin_tools: bool, ) -> Result { info!("Handling tool call"); - let result = match req.name.as_ref() { - "load-component" => handle_load_component(&req, lifecycle_manager, server_peer).await, - "unload-component" => handle_unload_component(&req, lifecycle_manager, server_peer).await, - "list-components" => handle_list_components(lifecycle_manager).await, - "get-policy" => handle_get_policy(&req, lifecycle_manager).await, - "grant-storage-permission" => { - handle_grant_storage_permission(&req, lifecycle_manager).await - } - "grant-network-permission" => { - handle_grant_network_permission(&req, lifecycle_manager).await - } - "grant-environment-variable-permission" => { - handle_grant_environment_variable_permission(&req, lifecycle_manager).await - } - "revoke-storage-permission" => { - handle_revoke_storage_permission(&req, lifecycle_manager).await - } - "revoke-network-permission" => { - handle_revoke_network_permission(&req, lifecycle_manager).await - } - "revoke-environment-variable-permission" => { - handle_revoke_environment_variable_permission(&req, lifecycle_manager).await + let result = if disable_builtin_tools && is_builtin_tool(req.name.as_ref()) { + // When builtin tools are disabled, reject calls to builtin tools + Err(anyhow::anyhow!("Built-in tools are disabled")) + } else { + // Handle builtin tools (if enabled) or component calls + match req.name.as_ref() { + "load-component" if !disable_builtin_tools => { + handle_load_component(&req, lifecycle_manager, server_peer).await + } + "unload-component" if !disable_builtin_tools => { + handle_unload_component(&req, lifecycle_manager, server_peer).await + } + "list-components" if !disable_builtin_tools => { + handle_list_components(lifecycle_manager).await + } + "get-policy" if !disable_builtin_tools => { + handle_get_policy(&req, lifecycle_manager).await + } + "grant-storage-permission" if !disable_builtin_tools => { + handle_grant_storage_permission(&req, lifecycle_manager).await + } + "grant-network-permission" if !disable_builtin_tools => { + handle_grant_network_permission(&req, lifecycle_manager).await + } + "grant-environment-variable-permission" if !disable_builtin_tools => { + handle_grant_environment_variable_permission(&req, lifecycle_manager).await + } + "revoke-storage-permission" if !disable_builtin_tools => { + handle_revoke_storage_permission(&req, lifecycle_manager).await + } + "revoke-network-permission" if !disable_builtin_tools => { + handle_revoke_network_permission(&req, lifecycle_manager).await + } + "revoke-environment-variable-permission" if !disable_builtin_tools => { + handle_revoke_environment_variable_permission(&req, lifecycle_manager).await + } + "search-components" if !disable_builtin_tools => { + handle_search_component(&req, lifecycle_manager).await + } + "reset-permission" if !disable_builtin_tools => { + handle_reset_permission(&req, lifecycle_manager).await + } + _ => handle_component_call(&req, lifecycle_manager).await, } - "search-components" => handle_search_component(&req, lifecycle_manager).await, - "reset-permission" => handle_reset_permission(&req, lifecycle_manager).await, - _ => handle_component_call(&req, lifecycle_manager).await, }; if let Err(ref e) = result { diff --git a/src/commands.rs b/src/commands.rs index 3f6616cf..01f56eb3 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -70,6 +70,11 @@ pub struct Serve { #[arg(long = "env-file")] #[serde(skip)] pub env_file: Option, + + /// Disable built-in tools (load-component, unload-component, list-components, etc.) + #[arg(long)] + #[serde(default)] + pub disable_builtin_tools: bool, } #[derive(Args, Debug, Clone, Serialize, Deserialize, Default)] diff --git a/src/config.rs b/src/config.rs index 8398513a..9e737296 100644 --- a/src/config.rs +++ b/src/config.rs @@ -134,6 +134,7 @@ mod tests { transport: Default::default(), env_vars: vec![], env_file: None, + disable_builtin_tools: false, } } @@ -143,6 +144,7 @@ mod tests { transport: Default::default(), env_vars: vec![], env_file: None, + disable_builtin_tools: false, } } diff --git a/src/main.rs b/src/main.rs index 81dbc4de..34c0db10 100644 --- a/src/main.rs +++ b/src/main.rs @@ -194,6 +194,7 @@ const BIND_ADDRESS: &str = "127.0.0.1:9001"; pub struct McpServer { lifecycle_manager: LifecycleManager, peer: Arc>>>, + disable_builtin_tools: bool, } /// Handle CLI tool commands by creating appropriate tool call requests @@ -271,6 +272,7 @@ async fn create_lifecycle_manager(plugin_dir: Option) -> Result Self { + /// * `disable_builtin_tools` - Whether to disable built-in tools + pub fn new(lifecycle_manager: LifecycleManager, disable_builtin_tools: bool) -> Self { Self { lifecycle_manager, peer: Arc::new(Mutex::new(None)), + disable_builtin_tools, } } @@ -354,8 +358,15 @@ Key points: // Store peer on first request self.store_peer_if_empty(peer_clone.clone()); + let disable_builtin_tools = self.disable_builtin_tools; Box::pin(async move { - let result = handle_tools_call(params, &self.lifecycle_manager, peer_clone).await; + let result = handle_tools_call( + params, + &self.lifecycle_manager, + peer_clone, + disable_builtin_tools, + ) + .await; match result { Ok(value) => serde_json::from_value(value).map_err(|e| { ErrorData::parse_error(format!("Failed to parse result: {e}"), None) @@ -373,8 +384,9 @@ Key points: // Store peer on first request self.store_peer_if_empty(ctx.peer.clone()); + let disable_builtin_tools = self.disable_builtin_tools; Box::pin(async move { - let result = handle_tools_list(&self.lifecycle_manager).await; + let result = handle_tools_list(&self.lifecycle_manager, disable_builtin_tools).await; match result { Ok(value) => serde_json::from_value(value).map_err(|e| { ErrorData::parse_error(format!("Failed to parse result: {e}"), None) @@ -519,7 +531,7 @@ async fn main() -> Result<()> { .build() .await?; - let server = McpServer::new(lifecycle_manager.clone()); + let server = McpServer::new(lifecycle_manager.clone(), cfg.disable_builtin_tools); // Start background component loading let server_clone = server.clone(); diff --git a/tests/transport_integration_test.rs b/tests/transport_integration_test.rs index 82b2e7ed..26663f47 100644 --- a/tests/transport_integration_test.rs +++ b/tests/transport_integration_test.rs @@ -1002,3 +1002,164 @@ async fn test_grant_permission_network_basic() -> Result<()> { Ok(()) } + +#[test(tokio::test)] +async fn test_disable_builtin_tools() -> Result<()> { + // Create a temporary directory for this test to avoid loading existing components + let temp_dir = tempfile::tempdir()?; + let plugin_dir_arg = format!("--plugin-dir={}", temp_dir.path().display()); + + // Get the path to the built binary + let binary_path = std::env::current_dir() + .context("Failed to get current directory")? + .join("target/debug/wassette"); + + // Start the server with stdio transport and disable-builtin-tools flag + let mut child = tokio::process::Command::new(&binary_path) + .args(["serve", &plugin_dir_arg, "--disable-builtin-tools"]) + .env("RUST_LOG", "off") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .context("Failed to start wassette with disabled builtin tools")?; + + let stdin = child.stdin.take().context("Failed to get stdin handle")?; + let stdout = child.stdout.take().context("Failed to get stdout handle")?; + let stderr = child.stderr.take().context("Failed to get stderr handle")?; + + let mut stdin = stdin; + let mut stdout = BufReader::new(stdout); + let mut stderr = BufReader::new(stderr); + + // Give the server time to start + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Check if the process is still running + if let Ok(Some(status)) = child.try_wait() { + let mut stderr_output = String::new(); + let _ = stderr.read_line(&mut stderr_output).await; + return Err(anyhow::anyhow!( + "Server process exited with status: {:?}, stderr: {}", + status, + stderr_output + )); + } + + // Send MCP initialize request + let initialize_request = r#"{"jsonrpc": "2.0", "method": "initialize", "params": {"protocolVersion": "2024-11-05", "capabilities": {}, "clientInfo": {"name": "test-client", "version": "1.0.0"}}, "id": 1} +"#; + + stdin.write_all(initialize_request.as_bytes()).await?; + stdin.flush().await?; + + // Read and verify response + let mut response_line = String::new(); + tokio::time::timeout( + Duration::from_secs(10), + stdout.read_line(&mut response_line), + ) + .await + .context("Timeout waiting for initialize response")? + .context("Failed to read initialize response")?; + + let response: serde_json::Value = + serde_json::from_str(&response_line).context("Failed to parse initialize response")?; + + assert_eq!(response["jsonrpc"], "2.0"); + assert_eq!(response["id"], 1); + assert!(response["result"].is_object()); + + // Send initialized notification + let initialized_notification = r#"{"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}} +"#; + + stdin.write_all(initialized_notification.as_bytes()).await?; + stdin.flush().await?; + + // Send list_tools request + let list_tools_request = r#"{"jsonrpc": "2.0", "method": "tools/list", "params": {}, "id": 2} +"#; + + stdin.write_all(list_tools_request.as_bytes()).await?; + stdin.flush().await?; + + // Read and verify tools list response + let mut tools_response_line = String::new(); + tokio::time::timeout( + Duration::from_secs(10), + stdout.read_line(&mut tools_response_line), + ) + .await + .context("Timeout waiting for tools/list response")? + .context("Failed to read tools/list response")?; + + let tools_response: serde_json::Value = serde_json::from_str(&tools_response_line) + .context("Failed to parse tools/list response")?; + + // Verify the tools response structure + assert_eq!(tools_response["jsonrpc"], "2.0"); + assert_eq!(tools_response["id"], 2); + assert!(tools_response["result"].is_object()); + assert!(tools_response["result"]["tools"].is_array()); + + // Verify that built-in tools are NOT present when disabled + let tools = &tools_response["result"]["tools"].as_array().unwrap(); + let tool_names: Vec = tools + .iter() + .map(|tool| tool["name"].as_str().unwrap_or("").to_string()) + .collect(); + + assert!( + !tool_names.contains(&"load-component".to_string()), + "load-component should not be present when builtin tools are disabled" + ); + assert!( + !tool_names.contains(&"unload-component".to_string()), + "unload-component should not be present when builtin tools are disabled" + ); + assert!( + !tool_names.contains(&"list-components".to_string()), + "list-components should not be present when builtin tools are disabled" + ); + assert!( + !tool_names.contains(&"get-policy".to_string()), + "get-policy should not be present when builtin tools are disabled" + ); + + // Try to call a builtin tool and verify it fails + let call_tool_request = r#"{"jsonrpc": "2.0", "method": "tools/call", "params": {"name": "list-components", "arguments": {}}, "id": 3} +"#; + + stdin.write_all(call_tool_request.as_bytes()).await?; + stdin.flush().await?; + + // Read and verify call tool response + let mut call_response_line = String::new(); + tokio::time::timeout( + Duration::from_secs(10), + stdout.read_line(&mut call_response_line), + ) + .await + .context("Timeout waiting for tools/call response")? + .context("Failed to read tools/call response")?; + + let call_response: serde_json::Value = + serde_json::from_str(&call_response_line).context("Failed to parse tools/call response")?; + + // Verify that the tool call failed + assert_eq!(call_response["jsonrpc"], "2.0"); + assert_eq!(call_response["id"], 3); + assert!(call_response["result"].is_object()); + let result = &call_response["result"]; + assert_eq!( + result["isError"].as_bool().unwrap_or(false), + true, + "Tool call should have failed" + ); + + // Clean up + child.kill().await.ok(); + + Ok(()) +}